import inspect
import subprocess
import tempfile
import numpy as np
import re
"""
# DocTAPE #
------------
Documentation Testing and Automated Placement of Expressions
A collection of utility functions (and wrappers for Glue) that are useful
for automating the process of building and testing documentation to ensure that
documentation doesn't get stale.
expected_error is an exception that can be used in try/except blocks to allow desired errors to
pass while still raising unexpected errors.
gramatical_list combines the elements of a list into a string with proper punctuation
check_value is a simple function for comparing two values
check_contains confirms that all the elements of one iterable are contained in the other
check_args gets the signature of a function and compares it to the arguments you are expecting
run_command_no_file_error executes a CLI command but won't fail if a FileNotFoundError is raised
get_attribute_name gets the name of an object's attribute based on its value
get_all_keys recursively gets all of the keys from a dict of dicts
get_value recursively gets a value from a dict of dicts
glue_variable glues a variable for later use in markdown cells of notebooks (can auto-format it as code)
glue_keys recursively glues all of the keys from a dict of dicts
"""
[docs]
class expected_error(Exception):
    """
    Exception meant to be raised and caught deliberately in try/except blocks,
    so that anticipated errors are allowed to pass while unexpected errors
    are still raised.
    """
[docs]
def gramatical_list(list_of_strings: list, cc='and', add_accents=False) -> str:
    """
    Combines the elements of a list into a string with proper punctuation

    Two elements are joined with only the conjunction ("a and b"); three or more
    are comma separated with the conjunction before the last one ("a, b, and c").

    Parameters
    ----------
    list_of_strings : list
        A list of strings (or elements with a string representation).
        Any iterable is accepted; an empty one produces an empty string.
    cc : str, optional
        The coordinating conjunction to use with the list (default is `and`)
    add_accents : bool, optional
        Whether or not to wrap each element with ` characters (default is False)

    Returns
    -------
    str
        A string that combines the elements of the list into a string with proper punctuation

    Examples
    --------
    >>> gramatical_list(['a', 'b', 'c'])
    'a, b, and c'
    >>> gramatical_list(['x', 'y'], cc='or', add_accents=True)
    '`x` or `y`'
    """
    items = ['`' + str(item) + '`' if add_accents else str(item)
             for item in list_of_strings]
    if not items:
        # nothing to join, and no last element to attach the conjunction to
        return ''
    if len(items) == 1:
        return items[0]
    if len(items) == 2:
        return items[0] + ' ' + cc + ' ' + items[1]
    return ', '.join(items[:-1] + [cc + ' ' + items[-1]])
[docs]
def get_previous_line(n=1) -> str:
    """
    returns the previous n line(s) of code as a string

    "Previous" is relative to the line that calls this function, and only the
    caller's own source (its function body, or the whole file for module-level
    code) is searched.

    Parameters
    ----------
    n : int
        The number of lines to return (default is 1)

    Returns
    -------
    str or list
        For n <= 1, the line immediately before the call with surrounding
        whitespace stripped ('' if the call is on the first line of its source).
        For n > 1, a list of up to n unstripped source lines (newlines included)
        preceding the call; fewer are returned when fewer lines exist above it.
    """
    pframe = inspect.currentframe().f_back  # the frame that called this function
    # For a function, getsourcelines returns its lines and the file line number of
    # its first line; for module-level code it returns the whole file and 0
    lines, first_line = inspect.getsourcelines(pframe)
    # 0-based index into `lines` of the line that made the call
    lineno = pframe.f_lineno - first_line if first_line else pframe.f_lineno - 1
    if n > 1:
        # clamp at 0: a negative start index would wrap around to the end of `lines`
        return lines[max(lineno - n, 0):lineno]
    # lines[-1] would be the last line of the source, not the previous one
    return lines[lineno - 1].strip() if lineno > 0 else ''
[docs]
def get_variable_name(*variables) -> str:
    """
    returns the name of the variable passed to the function as a string

    The name is recovered from the caller's source code, so each argument is
    returned as written, with all whitespace removed (e.g. ``Aircraft.APU.MASS``
    or ``max(a,b)``).

    # NOTE: You cannot call this function multiple times on one line, and only the
    # line containing ``get_variable_name(`` is read, so the call must fit on one line

    Parameters
    ----------
    variables : any
        The variable(s) of interest

    Returns
    -------
    str
        A string that contains the name of variable passed to this function
        (or list of strings, if multiple arguments are passed)
    """
    pframe = inspect.currentframe().f_back  # the frame that called this function
    lines, first_line = inspect.getsourcelines(pframe)
    # 0-based index into `lines` of the line that made the call
    lineno = pframe.f_lineno - first_line if first_line else pframe.f_lineno - 1
    # everything after "get_variable_name(" on that line, with all whitespace removed
    arg_text = ''.join(lines[lineno].split()).split('get_variable_name(', 1)[1]

    # Walk the text while tracking bracket depth until the ')' that closes this call.
    # Only commas at depth 0 separate arguments, so `f(x, y)` stays one argument and
    # arbitrarily nested calls such as `f(g(x))` are kept whole.
    args = []
    current = []
    depth = 0
    for char in arg_text:
        if char in '([{':
            depth += 1
        elif char in ')]}':
            if depth:
                depth -= 1
            elif char == ')':
                break  # the parenthesis that closes get_variable_name(...)
        elif char == ',' and depth == 0:
            args.append(''.join(current))
            current = []
            continue
        current.append(char)
    args.append(''.join(current))
    # NOTE: on Python 3.11+, inspect.getframeinfo(pframe).positions could locate the
    # exact call and lift the one-call-per-line restriction
    return args if len(args) > 1 else args[0]
[docs]
def check_value(val1, val2, error_type=ValueError):
    """
    Compares two values and raises `error_type` if they are not equal.

    For strings, integers, floats, lists, tuples, dictionaries, sets, numpy arrays
    and dict key views, equality (``!=``) is used. For any other type of `val1`,
    identity (``is``) is used. When either value is a numpy array, the elementwise
    comparison is reduced: the values differ if any element differs, and arrays
    whose shapes cannot be broadcast together are considered different.

    Parameters
    ----------
    val1 : any
        The first value to be compared.
    val2 : any
        The second value to be compared.
    error_type : Exception, optional
        The exception to raise (default is ValueError)

    Raises
    ------
    error_type
        If the values are not equal (or not the same object for non-primitive types).
    """
    if isinstance(val1, (str, int, float, list, tuple, dict, set, np.ndarray, type({}.keys()))):
        if isinstance(val1, np.ndarray) or isinstance(val2, np.ndarray):
            # `!=` on arrays is elementwise, and using a multi-element result in an
            # `if` raises numpy's "truth value is ambiguous" error, so reduce it
            try:
                unequal = bool(np.any(val1 != val2))
            except ValueError:
                # numpy raises when the shapes cannot be broadcast together
                unequal = True
        else:
            unequal = val1 != val2
        if unequal:
            raise error_type(f"{val1} is not equal to {val2}")
    elif val1 is not val2:
        raise error_type(f"{val1} is not {val2}")
[docs]
def check_contains(expected_values, actual_values, error_string="{var} not in {actual_values}", error_type=RuntimeError):
    """
    Checks that all of the expected_values exist in actual_values
    (It does not check for missing values)

    Parameters
    ----------
    expected_values : any iterable
        This can also be a single value, in which case it will be wrapped into a list.
        Strings and bytes are always treated as a single value, as is anything
        that is not iterable.
    actual_values : any iterable
        Anything supporting the ``in`` operator.
    error_string : str, optional
        The string to display as the error message,
        kwarg substitutions will be made using .format() for "var" and "actual_values"
    error_type : Exception, optional
        The exception to raise (default is RuntimeError)

    Raises
    ------
    error_type
        If a value in expected_values is not present in actual_values
    """
    # Decide "single value vs collection" by iterability rather than by generic-alias
    # support, so dict views, ranges and generators are iterated as collections
    if isinstance(expected_values, (str, bytes, bytearray)) or not hasattr(expected_values, '__iter__'):
        expected_values = [expected_values]
    for var in expected_values:
        if var not in actual_values:
            raise error_type(error_string.format(var=var, actual_values=actual_values))
[docs]
def check_args(func, expected_args: 'list | dict | str', args_to_ignore: 'list | tuple' = ('self',), exact=True, error_type=ValueError):
    """
    Checks that the expected arguments are valid for a given function.

    This method verifies that the provided `expected_args` match the actual arguments
    of the given function `func`. If `exact` is True, the method checks for an exact
    match. If `exact` is False, it only checks that the provided `expected_args` are
    included in the actual arguments (it won't fail if the function has additional arguments).

    Parameters
    ----------
    func : function
        The function whose arguments are being checked.
    expected_args : list, dict, or str
        The expected arguments. If a dict, the values will be compared to the default values
        (use ``inspect.Parameter.empty`` for arguments without a default).
        If a string, it will be treated as a single argument of interest. (exact will be set to False)
    args_to_ignore : list or tuple, optional
        Arguments to ignore during the check (default is ('self',)).
    exact : bool, optional
        Whether to check for an exact match of arguments (default is True).
    error_type : Exception, optional
        The exception to raise (default is ValueError)

    Raises
    ------
    error_type
        If the expected arguments do not match the actual arguments of the function.
    """
    if isinstance(expected_args, str):
        # a single argument name can only be checked for presence
        expected_args = [expected_args]
        exact = False
    params = inspect.signature(func).parameters
    # argument name -> default value (inspect.Parameter.empty when there is none)
    available_args = {
        arg: params[arg].default for arg in params if arg not in args_to_ignore}
    if exact:
        if isinstance(expected_args, dict):
            actual, expected = available_args, expected_args
        else:
            # order of the names doesn't matter for a list/tuple of names
            actual, expected = sorted(available_args), sorted(expected_args)
        if actual != expected:
            # same message check_value produces, but honoring error_type
            raise error_type(f"{actual} is not equal to {expected}")
    else:
        for arg in expected_args:
            if arg not in available_args:
                raise error_type(f'{arg} is not a valid argument for {func.__name__}')
            elif isinstance(expected_args, dict) and expected_args[arg] != available_args[arg]:
                raise error_type(
                    f"the default value of {arg} is {available_args[arg]}, not {expected_args[arg]}")
[docs]
def run_command_no_file_error(command: str, verbose=False):
    """
    Executes a CLI command and handles FileNotFoundError separately.

    The command is split on whitespace (it is not run through a shell, so quoting is
    not supported) and executed in a fresh temporary directory with its output captured.
    If the command returns a non-zero exit code, the last non-empty line of its stderr
    is inspected. If that line starts with "FileNotFoundError:" (the final line of a
    Python traceback), a notice is printed and nothing is raised. For any other failure
    the full stderr is printed and CalledProcessError is raised.

    Parameters
    ----------
    command : str
        The CLI command to be executed.
    verbose : bool
        Whether or not to include the error message if FileNotFoundError is raised

    Raises
    ------
    CalledProcessError
        If the command returns a non-zero exit code (except for FileNotFoundError).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        rc = subprocess.run(command.split(), cwd=tempdir, capture_output=True, text=True)
        if rc.returncode:
            # Take the last non-empty line so an empty stderr, a missing trailing
            # newline, or a final line without ':' can't break the parsing
            stderr_lines = rc.stderr.strip().splitlines()
            last_line = stderr_lines[-1] if stderr_lines else ''
            err, _, info = last_line.partition(':')
            if err == 'FileNotFoundError':
                if verbose:
                    print(info)
                print(
                    f"A file required by {command} couldn't be found, continuing anyway")
            else:
                print(rc.stderr)
                rc.check_returncode()
[docs]
def get_attribute_name(object: object, attribute, error_type=AttributeError) -> str:
    """
    Gets the name of an object's attribute based on its value

    This is intended for use with Enums and other objects that have unique values.
    This method will return the name of the first attribute (in ``__dict__`` order)
    whose value compares equal to the value provided.

    Parameters
    ----------
    object : any
        The object (class or instance) whose ``__dict__`` will be searched
    attribute : any
        The value of interest
    error_type : Exception, optional
        The exception to raise (default is AttributeError)

    Returns
    -------
    name : str
        The name of the attribute

    Raises
    ------
    error_type
        If the object has no attributes with the provided value.
    """
    for name, val in object.__dict__.items():
        if val == attribute:
            return name
    # Classes have a __name__ but plain instances do not; fall back to the class name
    # so building this message can't itself raise AttributeError and mask error_type
    label = getattr(object, '__name__', type(object).__name__)
    raise error_type(
        f"`{label}` object has no attribute with a value of `{attribute}`")
[docs]
def get_all_keys(dict_of_dicts: dict, track_layers=False, all_keys=None) -> list:
    """
    Recursively get all of the keys from a dict of dicts

    This can also be used to recursively get all of the attributes from a complex object,
    like the Aircraft hierarchy: any value that is not a dict is read through its
    ``__dict__``. Dunder names (``__x__``) are skipped.

    Note: this will not add duplicates of keys, but will
    continue deeper even if a key is duplicated. A value that is one of its own
    ancestors (a reference cycle) is not descended into again.

    Parameters
    ----------
    dict_of_dicts : dict
        The dictionary (or object) whose keys will be gathered
    track_layers : bool or str
        Whether or not to track where keys inside the dict of dicts
        came from. This will get every key, by ensuring that all keys
        have a unique name by tracking the path it took to get there.
        False gives bare keys; True gives dotted paths from the top level
        (``outer.inner``); a non-empty string is used as a prefix for every
        path (``prefix.outer.inner``).
    all_keys : list
        A list of the keys that have been found so far; it is extended in place

    Returns
    -------
    all_keys : list
        A list of all the keys in the dict_of_dicts, in the order they were found
    """
    if all_keys is None:
        all_keys = []
    _gather_keys(dict_of_dicts, track_layers, all_keys, set(all_keys), set())
    return all_keys


def _gather_keys(container, track_layers, all_keys, seen, ancestors):
    """Append the (optionally path-prefixed) keys of `container` and of its nested values to `all_keys`."""
    if id(container) in ancestors:
        return  # reference cycle: descending again would never terminate
    mapping = container if isinstance(container, dict) else container.__dict__
    # True means "track paths, with no prefix at this level"; a string is the prefix itself
    prefix = '' if track_layers is True else track_layers
    ancestors.add(id(container))
    try:
        for key, val in mapping.items():
            if isinstance(key, str) and key.startswith('__') and key.endswith('__'):
                continue
            if track_layers and prefix:
                key = f'{prefix}.{key}'
            if key not in seen:  # set lookup keeps de-duplication O(1) per key
                seen.add(key)
                all_keys.append(key)
            if isinstance(val, dict) or hasattr(val, '__dict__'):
                # nested keys are prefixed with this key's path only when tracking
                _gather_keys(val, key if track_layers else False, all_keys, seen, ancestors)
    finally:
        ancestors.discard(id(container))
[docs]
def get_value(dict_of_dicts: dict, comlpete_key: str = None, *, complete_key: str = None):
    """
    Recursively get a value from a dict of dicts

    At each step, a value that is not a dict is read through its ``__dict__``,
    so nested objects (e.g. classes within classes) can be traversed too.

    Parameters
    ----------
    dict_of_dicts : dict
        The dictionary (or object) to search
    complete_key : str
        A string that contains the full path through the dict_of_dicts
        (i.e. dictkey1.dictkey2.keyofinterest). It may be passed positionally,
        or by keyword as ``complete_key`` (or the legacy misspelling ``comlpete_key``).

    Returns
    -------
    val : any
        The value found

    Raises
    ------
    TypeError
        If neither or both of ``complete_key``/``comlpete_key`` are given.
    KeyError
        If a key along the path does not exist.
    """
    if (comlpete_key is None) == (complete_key is None):
        raise TypeError(
            "get_value() requires exactly one of 'complete_key' or 'comlpete_key'")
    path = complete_key if complete_key is not None else comlpete_key
    value = dict_of_dicts
    for key in path.split('.'):
        if not isinstance(value, dict):
            value = value.__dict__
        value = value[key]
    return value
[docs]
def glue_variable(name: str, val=None, md_code=False, display=True):
    """
    Glue a variable for later use in markdown cells of notebooks

    Note:
    glue_variable(get_variable_name(Aircraft.APU.MASS))
    can be used to glue the name of the variable (Aircraft.APU.MASS)
    not the value of the variable ('aircraft:apu:mass')

    Parameters
    ----------
    name : str
        The name the value will be glued to
    val : any
        The value to be displayed in the markdown cell (default is the value of name).
        Non-string values are converted with str().
    md_code : Bool
        Whether to wrap the value in markdown code formatting (e.g. `code`)
    display : Bool
        Passed through to myst_nb.glue as its display argument (default is True)
    """
    # local import so myst isn't required unless glue is being used
    from myst_nb import glue
    from IPython.display import Markdown
    from IPython.utils import io
    if val is None:
        val = name
    # Markdown only accepts text, and '`' + val needs a str, so convert numbers etc. first
    text = str(val)
    val = Markdown('`' + text + '`' if md_code else text)
    with io.capture_output() as captured:
        glue(name, val, display)
    # whatever was captured while gluing is always re-emitted, regardless of `display`
    captured.show()
[docs]
def glue_keys(dict_of_dicts: dict, display=True) -> list:
    """
    Recursively glue all of the keys from a dict of dicts

    Each key is glued to itself, formatted as markdown code. When a non-dict
    (such as a class) is given, keys are tracked as dotted paths prefixed with
    the object's ``__name__``.

    Parameters
    ----------
    dict_of_dicts : dict
        The dictionary (or named object) whose keys will be glued
    display : Bool
        Forwarded to glue_variable for every key (default is True)

    Returns
    -------
    all_keys : list
        A list of all the keys that were glued
    """
    track_layers = False if isinstance(dict_of_dicts, dict) else dict_of_dicts.__name__
    glued = get_all_keys(dict_of_dicts, track_layers)
    for glued_key in glued:
        glue_variable(glued_key, md_code=True, display=display)
    return glued