Source code for pinspect.traverse

"""
The main function is `find()`.
"""


import contextlib
import inspect
import logging
import re
import uuid
from pprint import pformat

import networkx as nx
from tqdm import tqdm

from pinspect.logger import init_logger
from pinspect.utils import get_module_root, IgnoreFunc, REGEX_NEVER_MATCH, NON_EXECUTABLE, to_pyvis, to_string, \
    check_edge

# Runs at import time; presumably sets up pinspect's logging handlers/levels
# (see `pinspect.logger.init_logger`) — TODO confirm.
init_logger()


class DiGraphAcyclic(nx.DiGraph):
    """
    Directed Acyclic Graph.

    Nodes are keyed by ``id(obj)`` of the objects they represent (exceptions get a random hex
    id instead), and `add_edge` refuses to insert an edge that would close a cycle.
    """

    def add_edge(self, u, v_obj, label=None, **attr):
        """
        Adds `id(v_obj)` node in the graph, if not present, and then adds an edge from `u` to
        `id(v_obj)`.

        Parameters
        ----------
        u : int or str
            A node from. It must already be in the graph and have a 'level' attribute.
        v_obj : object
            A node object to. The node of the object is `id(v_obj)`.
        label : str, optional
            Edge label. Labels ending with '()' (method calls) are colored red.
        **attr
            Extra edge attributes forwarded to `nx.DiGraph.add_edge`.

        Returns
        -------
        bool
            If the edge has been successfully added or not. If adding the edge from `u` to
            `v_obj` closes a cycle, returns False. Otherwise, returns True.
        """
        v_of_edge = self.add_node(v_obj, level=self.nodes[u]['level'] + 1)
        if nx.has_path(self, v_of_edge, u):
            # makes cycle (this also rejects a self-loop, v_of_edge == u)
            return False
        # `label` defaults to None, so guard before calling str methods on it
        if label is not None and label.endswith('()'):
            color = 'red'
        elif self.nodes[v_of_edge]['level'] < self.nodes[u]['level']:
            # level up: the target node already existed at a shallower level
            color = 'magenta'
        else:
            color = None
        super().add_edge(u, v_of_edge, label=label, color=color, **attr)
        return True

    def add_node(self, obj, **attr):
        """
        Adds `obj` in the graph, if not present.

        Parameters
        ----------
        obj : object
            An object to add in the graph.
        **attr
            Extra node attributes (e.g. ``level``), used only when the node is new.

        Returns
        -------
        obj_id : int or str
            Node id: ``id(obj)`` or, for exceptions, a random ``uuid4`` hex string.

        Notes
        -----
        Two objects with non-overlapping lifetimes might have the same identifier (address in
        memory). If a node with the same id is already present, it is reused as is and its
        attributes are not updated, so a new object may be mapped onto the node of an old one.
        Exceptions are exempt: each exception always gets a fresh node.
        """
        if isinstance(obj, Exception):
            # every raised error becomes its own (red) node
            obj_id = uuid.uuid4().hex
            color = 'red'
        else:
            obj_id = id(obj)
            color = None
        if obj_id in self.nodes:
            return obj_id
        label = obj.__class__.__name__
        if isinstance(obj, (set, list, tuple, dict)):
            label = f"{label} of size {len(obj)}"
        try:
            title = pformat(obj, depth=1, compact=True)
        except Exception:
            # inspected objects are arbitrary; a broken __repr__ must not abort the traversal
            title = object.__repr__(obj)
        lines = title.splitlines() or ['']
        title_short = lines[0]
        if len(lines) > 1:
            title_short = f"{title_short} ..., {lines[-1]}"
        title_short = title_short.strip('<>')
        super().add_node(obj_id, label=label, title=title_short, color=color, **attr)
        return obj_id
class GraphBuilder:
    """
    Builds a `DiGraphAcyclic` of everything reachable from an object by recursively accessing
    its attributes and calling its callable attributes with no arguments, and strips the graph
    down to the paths that match a key.
    """

    def __init__(self, obj, key, ignore_key='', ignore_class=(), max_depth=10):
        """
        Parameters
        ----------
        obj : object
            An object to inspect for `key`.
        key : str
            A key (case-insensitive regex) to look for. An empty key matches nothing.
        ignore_key : str or list, optional
            A string or a list of strings to ignore `obj` attributes and methods from being
            accessed and executed. Apart from user-provided strings, all methods that contain
            one of the following key-words will be ignored:
            'save', 'write', 'remove', 'delete', 'duplicate'
            For the total list of ignored key-words, see `NON_EXECUTABLE` in `utils.py`.
        ignore_class : list, optional
            A list of class types to ignore. Apart from user-provided class types, all numpy
            functions will not be executed.
        max_depth : int, optional
            The max recursion depth. Default is 10.

        Raises
        ------
        ValueError
            If the `key` is a part of `ignore_key`.
        """
        if key == '':
            key = REGEX_NEVER_MATCH
        self.obj = obj
        self.obj_saved = []  # prevent being collected by GC, so node ids stay unique
        self.key = re.compile(key, flags=re.IGNORECASE)
        self.graph = DiGraphAcyclic()
        self.module = get_module_root(obj)
        if not isinstance(ignore_key, str):
            ignore_key = '|'.join(ignore_key)
        ignore_key = f"{NON_EXECUTABLE}|{ignore_key}".rstrip('|')
        if re.search(ignore_key, key):
            raise ValueError(f"The key='{key}' cannot be a part of ignore_key='{ignore_key}'")
        self.ignore_attribute = IgnoreFunc(key=ignore_key, obj_class=ignore_class)
        self.tried_functions = set()
        self.tried_classes = set()
        self.max_depth = max_depth
        self.graph.add_node(obj, level=0)

    def traverse(self, obj, parent_edge=None, level=0):
        """
        Recursively visit `obj`, adding nodes and edges to `self.graph`.

        Parameters
        ----------
        obj : object
            The object to visit.
        parent_edge : tuple or None, optional
            ``(parent_object, edge_label)`` describing how `obj` was reached; None for the root.
        level : int, optional
            The current recursion depth. Nothing is done when ``level >= self.max_depth``.
        """
        if level >= self.max_depth:
            return
        if parent_edge is not None:
            parent, edge_name = parent_edge
            if not self.graph.add_edge(id(parent), obj, label=edge_name):
                # makes a cycle
                return
        if isinstance(obj, (bool, int, str, float, type)):
            # ignore builtin types
            return
        if isinstance(obj, dict):
            for key, value in obj.items():
                self.traverse(value, parent_edge=(obj, f"['{key}']"), level=level + 1)
            return
        if isinstance(obj, (set, list, tuple)):
            if len(obj) > 0:
                # only the first element of a container is inspected
                element = next(iter(obj))
                if parent_edge is not None:
                    # collapse the container node: link its parent directly to the element
                    parent, edge_name = parent_edge
                    self.graph.remove_node(id(obj))
                    self.traverse(element, parent_edge=(parent, f"{edge_name}[0]"),
                                  level=level + 1)
                else:
                    self.traverse(element, parent_edge=(obj, "[0]"), level=level + 1)
            return
        if get_module_root(obj) != self.module:
            # we're interested only in functions of the given module
            return
        if obj.__class__ in self.tried_classes:
            return
        self.tried_classes.add(obj.__class__)
        class_name = obj.__class__.__name__
        # lazy %-args: str(obj) is evaluated only when DEBUG logging is enabled, and a failing
        # __str__ is then handled by logging instead of aborting the traversal
        logging.debug("%sInspecting %s (level=%d): %s", ' ' * level, class_name, level, obj)
        for attr_name in tqdm(dir(obj), desc=f"Inspecting '{class_name}'", disable=level > 0):
            if attr_name.startswith('__'):
                continue
            if self.ignore_attribute(obj, attr_name):
                continue
            try:
                attr = getattr(obj, attr_name)
            except Exception as err:
                # properties can raise anything; skip them like any other inaccessible attribute
                logging.debug("%sCannot access %s.%s: %r", ' ' * (level + 1), class_name,
                              attr_name, err)
                continue
            full_name = f"{class_name}.{attr_name}"
            if callable(attr) and full_name not in self.tried_functions:
                self.tried_functions.add(full_name)
                try:
                    logging.debug("%sExecuting %s.%s()", ' ' * (level + 1), class_name, attr_name)
                    with contextlib.redirect_stdout(None):
                        res = attr()
                except Exception as err:
                    # DiGraphAcyclic gives every exception a unique (uuid) node, so the original
                    # exception is added as is; re-instantiating it via `err.__class__(str(err))`
                    # fails for exception classes with multi-argument constructors.
                    self.graph.add_edge(id(obj), err, label=f"{attr_name}()")
                else:
                    self.obj_saved.append(res)
                    self.traverse(res, parent_edge=(obj, f"{attr_name}()"), level=level + 1)
            elif not inspect.ismethod(attr):
                # properties may return fresh objects: keep them alive so their ids are not
                # recycled by other objects while the graph is being built
                self.obj_saved.append(attr)
                self.traverse(attr, parent_edge=(obj, attr_name), level=level + 1)

    def strip(self, with_methods=True):
        """
        Extract the subgraph of paths leading from the root object to the matches of `self.key`.

        A node matches when its label matches the key (it is colored green). With
        `with_methods`, an edge matches when its label (attribute/method name) matches the key.
        All ancestors of a match are kept.

        Parameters
        ----------
        with_methods : bool, optional
            Also match edge labels, not only node labels. Default is True.

        Returns
        -------
        nx.DiGraph
            The stripped graph, edges pointing from parent to child. The root object node, if
            present, is colored blue. Empty if nothing matched.
        """
        # the reversed view shares node attribute dicts with self.graph
        graph = self.graph.reverse(copy=False)
        graph_stripped = nx.DiGraph()
        # nodes whose own parents must be kept; tracked locally so self.graph is not polluted
        # and repeated calls (e.g. with a different `with_methods`) are independent
        included = set()
        for node_from in nx.topological_sort(graph):
            if self.key.search(graph.nodes[node_from].get('label', '')):
                graph.nodes[node_from]['color'] = 'green'
                include_parents = True
            else:
                include_parents = node_from in included
            for node_to, edge_attr in graph.adj[node_from].items():
                method_hit = with_methods and self.key.search(edge_attr.get('label') or '')
                if include_parents or method_hit:
                    included.add(node_to)
                    graph_stripped.add_node(node_to, **graph.nodes[node_to])
                    graph_stripped.add_node(node_from, **graph.nodes[node_from])
                    graph_stripped.add_edge(node_to, node_from, **edge_attr)
        root = id(self.obj)
        if root in graph_stripped:
            # absent when nothing matched
            graph_stripped.nodes[root]['color'] = 'blue'
        return graph_stripped
def find(obj, key, ignore_key='', ignore_class=(), verbose=True, visualize=True, max_depth=10):
    """
    Traverse the object `obj` and find methods and attributes that match the `key`.

    Parameters
    ----------
    obj : object
        An object to inspect for `key`.
    key : str
        A key to look for.
    ignore_key : str or list, optional
        A string or a list of strings to ignore `obj` attributes and methods from being
        accessed and executed. Apart from user-provided strings, all methods that contain one
        of the following key-words will be ignored:
        'save', 'write', 'remove', 'delete', 'duplicate'
        For the total list of ignored key-words, see `NON_EXECUTABLE` in `utils.py`.
    ignore_class : list, optional
        A list of class types to ignore. Apart from user-provided class types, all numpy
        functions will not be executed.
    verbose : bool, optional
        If set to True, prints found matches in console. Default is True.
    visualize : bool, optional
        If set to True, renders a graph in a web browser, using `pyvis` package. The page is
        named ``<class name of obj>.html``. Nothing is rendered if there are no matches.
    max_depth : int, optional
        The max recursion depth of the traversal. Default is 10.

    Returns
    -------
    graph : nx.DiGraph
        Stripped graph with edges and nodes that match the `key`.

    Raises
    ------
    ValueError
        If the `key` is a part of `ignore_key`.
    """
    builder = GraphBuilder(obj, key=key, ignore_key=ignore_key, ignore_class=ignore_class,
                           max_depth=max_depth)
    builder.traverse(obj)
    # the graph is built: release the objects kept alive only to keep node ids unique
    builder.obj_saved.clear()
    graph = builder.strip(with_methods=True)
    logging.info("Stripped graph length: %d -> %d", len(builder.graph), len(graph))
    if verbose:
        if len(graph) == 0:
            print("No match")
        else:
            matches = to_string(graph, source=id(obj), prefix=obj.__class__.__name__)
            print('\n'.join(matches))
    if visualize and len(graph) > 0:
        network_pyvis = to_pyvis(graph)
        network_pyvis.show(name=f"{obj.__class__.__name__}.html")
    return graph