Source code for liberator.core

"""
Extracts relevant parts of the source code

Note:
    If the source code changes while the run is executing then this may not
    work correctly.

TODO:

- [x] Maintain a parse tree instead of raw lines

- [x] Keep a mapping from "definition names" to the top-level nodes in the parse tree that define them.

- [X] For each extracted node in the parse tree keep track of

    - [X] where it came from

    - [ ] what modifications were made to it

- [ ] Handle expanding imports nested within functions

- [ ] Maintain docstring formatting after using the node transformer


Issues:

    - [ ] We currently (0.0.1) get a KeyError in the case where, a module is imported like `import mod.submod` and all usage is of the form `mod.submod.attr`, then
"""
import ast
import astunparse
import copy
import inspect
import io
import sys
import ubelt as ub
import warnings
from os.path import isdir
from os.path import join
from os.path import basename
from os.path import abspath
from collections import OrderedDict

__all__ = ['Liberator', 'Closer']


__todo__ = """

# FIXME:
# The following case has misformatted docstrings and needs to handle
# the duplicate function name issue.

import ubelt as ub
import liberator
lib = liberator.Liberator()

import ubelt
import xdoctest
lib.add_dynamic(ubelt.util_import.modpath_to_modname)
lib.add_dynamic(ubelt.util_import.modname_to_modpath)
lib.add_dynamic(xdoctest.static_analysis.package_modpaths)
lib.expand(['ubelt', 'xdoctest'])
text = lib.current_sourcecode()
print(text)


"""


class LocalLogger:
    """
    A non-global logger used for specific code paths or class instances.
    """
    def __init__(self, tag='', verbose=0):
        self.verbose = verbose
        self.logs = []
        self.tag = tag
        self.indent = ''
        if verbose >= 2:
            self.debug('init new logger')

    def warn(self, msg):
        line = '[WARN.{}] '.format(self.tag) + self.indent + msg
        self.logs.append(line)
        if self.verbose >= 0:
            print(line)

    def error(self, msg):
        line = '[ERROR.{}] '.format(self.tag) + self.indent + msg
        self.logs.append(line)
        if self.verbose >= 0:
            print(line)

    def info(self, msg):
        line = '[INFO.{}] '.format(self.tag) + self.indent + msg
        self.logs.append(line)
        if self.verbose >= 1:
            print(line)

    def debug(self, msg):
        line = '[DEBUG.{}] '.format(self.tag) + self.indent + msg
        self.logs.append(line)
        if self.verbose >= 2:
            print(line)

    def _print_logs(self):
        print('\n'.join(self.logs))

    @classmethod
    def coerce(cls, item, tag='', verbose=0):
        """
        Create logger from another logger
        """
        if isinstance(item, int):
            verbose = item

        self = cls(tag=tag, verbose=verbose)

        if isinstance(item, cls):
            # Make a sublogger, TODO: be more eloquent
            self.logs = item.logs
            self.verbose = item.verbose

        return self


[docs]class Liberator(ub.NiceRepr): r""" Maintains the current state of the source code There are 3 major steps: (a) extract the code to that defines a function or class from a module, (b) go back to the module and extract extra code required to define any names that were undefined in the extracted code, and (c) replace import statements to specified "expand" modules with the actual code used to define the variables accessed via the imports. This results in a standalone file that has absolutely no dependency on the original module or the specified "expand" modules (the expand module is usually the module that is doing the training for a network. This means that you can deploy a model independant of the training framework). Note: This is not designed to work for cases where the code depends on logic executed in a global scope (e.g. dynamically registering properties) . I think its actually impossible to statically account for this case in general. Args: tag (str): logging tag logger (Callable): logging function verbose (int): verbosity, 0 is nothing, 1 is info, 2 is debug, etc.. Example: >>> import ubelt as ub >>> from liberator.core import Liberator >>> lib = Liberator(logger=print) >>> lib.add_dynamic(ub.find_exe, eager=False) >>> lib.expand(['ubelt']) >>> print(lib.current_sourcecode()) >>> lib = Liberator() >>> lib.add_dynamic(ub.find_exe, eager=True) >>> print(lib.current_sourcecode()) >>> lib = Liberator(logger=3, tag='mytest') >>> lib.add_dynamic(ub.Cacher, eager=True) >>> visitor = ub.peek(lib.visitors.values()) >>> print('visitor.definitions = {}'.format(ub.urepr(ub.map_keys(str, visitor.definitions), nl=1))) >>> print('visitor.nested_definitions = {}'.format(ub.urepr(ub.map_keys(str, visitor.nested_definitions), nl=1))) >>> lib._print_logs() >>> lib.expand(['ubelt']) Ignore: from ubelt import _win32_links lib = Liberator() lib.add_dynamic(_win32_links._win32_symlink, eager=True) print(lib.current_sourcecode()) definitions = list(visitor.definitions.values()) import_defs = [d for d in definitions if 'Import' in d.type] print('import_defs = {}'.format(ub.urepr(import_defs, nl=1))) Example: >>> # xdoctest: +REQUIRES(module:fastai) >>> from liberator.core import * >>> import fastai.vision >>> obj = fastai.vision.models.WideResNet >>> expand_names = ['fastai'] >>> lib = Liberator() >>> lib.add_dynamic(obj) >>> lib.expand(expand_names) >>> #print(ub.urepr(lib.body_defs, si=1)) >>> print(lib.current_sourcecode()) Example: >>> # xdoctest: +REQUIRES(module:fastai) >>> from liberator.core import Liberator >>> from fastai.vision.models import unet >>> lib = Liberator() >>> lib.add_dynamic(unet.DynamicUnet) >>> lib.expand(['fastai']) >>> print(lib.current_sourcecode()) Example: >>> # xdoctest: +REQUIRES(module:netharn) >>> from liberator.core import * >>> import netharn as nh >>> from netharn.models.yolo2 import yolo2 >>> obj = yolo2.Yolo2 >>> expand_names = ['netharn'] >>> lib = Liberator() >>> lib.add_static(obj.__name__, sys.modules[obj.__module__].__file__) >>> lib.expand(expand_names) >>> #print(ub.urepr(lib.body_defs, si=1)) >>> print(lib.current_sourcecode()) """ def __init__(lib, tag='root', logger=None, verbose=0): lib.header_defs = ub.odict() lib.body_defs = ub.odict() lib.visitors = {} lib.logger = LocalLogger.coerce(logger, tag=tag, verbose=verbose) lib._lazy_visitors = []
[docs] def error(lib, msg): lib.logger.error(msg)
[docs] def info(lib, msg): lib.logger.info(msg)
[docs] def debug(lib, msg): lib.logger.debug(msg)
[docs] def warn(lib, msg): lib.logger.warn(msg)
[docs] def _print_logs(lib): lib.logger._print_logs()
def __nice__(self): return self.logger.tag
[docs] def _add_definition(lib, d): lib.debug('_add_definition = {!r}'.format(d)) d = copy.deepcopy(d) # print('ADD DEFINITION d = {!r}'.format(d)) if 'Import' in d.type: if d.absname in lib.header_defs: del lib.header_defs[d.absname] lib.header_defs[d.absname] = d else: if d.absname in lib.body_defs: del lib.body_defs[d.absname] lib.body_defs[d.absname] = d
[docs] def current_sourcecode(self): header_lines = [d.code for d in self.header_defs.values()] body_lines = [d.code for d in self.body_defs.values()][::-1] current_sourcecode = '\n'.join(header_lines) current_sourcecode += '\n\n\n' current_sourcecode += '\n\n\n'.join(body_lines) # current_sourcecode += '\n'.join(body_lines) return current_sourcecode
[docs] def _ensure_visitor(lib, modpath=None, module=None): """ Return an existing visitor for a module or create one if it doesnt exist """ if modpath is None and module is not None: modpath = module.__file__ if modpath not in lib.visitors: visitor = DefinitionVisitor.parse( module=module, modpath=modpath, logger=lib.logger) lib.visitors[modpath] = visitor visitor = lib.visitors[modpath] return visitor
[docs] def add_dynamic(lib, obj, eager=True): """ Add the source to define a live python object Args: obj (object): a reference to a class or function eager (bool): experimental Example: >>> from liberator import core >>> import liberator >>> obj = core.unparse >>> eager = True >>> lib = liberator.Liberator() >>> lib.add_dynamic(obj, eager=eager) >>> print(lib.current_sourcecode()) """ lib.info('\n\n') lib.info('====\n\n') lib.info('lib.add_dynamic(obj={!r})'.format(obj)) name = obj.__name__ modname = obj.__module__ module = sys.modules[modname] visitor = lib._ensure_visitor(module=module) d = visitor.extract_definition(name) lib._add_definition(d) if eager: lib.close(visitor) else: # Experimental lib._lazy_visitors.append(visitor)
[docs] def add_static(lib, name, modpath): """ Statically extract a definition from a module file Args: name (str): the name of the member of the module to define modpath (PathLike): The path to the module Example: >>> from liberator import core >>> import liberator >>> modpath = core.__file__ >>> name = core.unparse.__name__ >>> lib = liberator.Liberator() >>> lib.add_static(name, modpath) >>> print(lib.current_sourcecode()) """ # print('ADD_STATIC name = {} from {}'.format(name, modpath)) lib.info('lib.add_static(name={!r}, modpath={!r})'.format(name, modpath)) visitor = lib._ensure_visitor(modpath=modpath) d = visitor.extract_definition(name) lib._add_definition(d) lib.close(visitor)
[docs] def _lazy_close(lib): # Experimental lib.close2(ub.oset(lib._lazy_visitors)) lib._lazy_visitors = []
[docs] def close2(lib, visitors): """ Experimental Populate all undefined names using the context from a module """ # Parse the parent module to find only the relevant global varaibles and # include those in the extracted source code. lib.debug('closing') # Loop until all undefined names are defined names = True while names: # Determine if there are any variables needed from the parent scope current_sourcecode = lib.current_sourcecode() # Make sure we process names in the same order for hashability prev_names = names names = sorted(undefined_names(current_sourcecode)) lib.debug(' * undefined_names = {}'.format(names)) if names == prev_names: for visitor in visitors: lib.debug('visitor.definitions = {}'.format(ub.urepr( ub.map_keys(str, visitor.definitions), si=1, nl=1))) if 0: warnings.warn('We were unable do do anything about undefined names') return else: # current_sourcecode = lib.current_sourcecode() lib.error('--- <ERROR[4]> ---') lib.error('Unable to define names') lib.error(' * names = {!r}'.format(names)) #lib.error('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) lib.error('--- </ERROR[4]> ---') raise AssertionError('unable to define names: {}'.format(names)) for name in names: try: # Greedilly choose the visitor that has the name we are # looking for. for visitor in visitors: if name in visitor.definitions: break try: lib.debug(' * try visitor.extract_definition({})'.format(name)) d = visitor.extract_definition(name) except KeyError as ex: lib.warn(' * encountered issue: {!r}'.format(ex)) # There is a corner case where we have the definition, # we just need to move it to the top. flag = False for d_ in lib.body_defs.values(): if name == d_.name: lib.warn(' * corner case: move definition to top') lib._add_definition(d_) flag = True break if not flag: raise else: lib.debug(' * add extracted def {}'.format(name)) lib._add_definition(d) # type_, text = visitor.extract_definition(name) except Exception as ex: lib.warn(' * unable to extracted def {} due to {!r}'.format(name, ex)) # current_sourcecode = lib.current_sourcecode() lib.error('--- <ERROR[3]> ---') lib.error('Error computing source code extract_definition') lib.error(' * failed to close name = {!r}'.format(name)) # lib.error('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) lib.error('--- </ERROR[3]> ---')
[docs] def close(lib, visitor): """ Populate all undefined names using the context from a module """ # Parse the parent module to find only the relevant global varaibles and # include those in the extracted source code. lib.info('closing - i.e. populating, crawling') # Loop until all undefined names are defined names = True while names: # Determine if there are any variables needed from the parent scope current_sourcecode = lib.current_sourcecode() # Make sure we process names in the same order for hashability prev_names = names names = sorted(undefined_names(current_sourcecode)) lib.debug(' * undefined_names = {}'.format(names)) if names == prev_names: lib.debug('visitor.definitions = {}'.format(ub.urepr( ub.map_keys(str, visitor.definitions), si=1, nl=1))) if 0: warnings.warn('We were unable do do anything about undefined names') return else: # current_sourcecode = lib.current_sourcecode() lib.error('--- <ERROR[1]> ---') lib.error('Unable to define names') lib.error(' * names = {!r}'.format(names)) #lib.error('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) lib.error('--- </ERROR[1]> ---') raise AssertionError('unable to define names: {}'.format(names)) for name in names: try: try: # pass lib.debug(' * try visitor.extract_definition({})'.format(name)) d = visitor.extract_definition(name) except KeyError as ex: lib.debug(' * encountered issue: {!r}'.format(ex)) # There is a corner case where we have the definition, # we just need to move it to the top. flag = False for d_ in lib.body_defs.values(): if name == d_.name: lib.debug(' * corner case: move definition to top') lib._add_definition(d_) flag = True break # There is another corner case where we only have a # prefix of the definition. Note, we could be more # clever and look at the attribute usage in the current # sourcefile instead of blindly taking everything with # the given prefix. if visitor.definitions.has_subtrie(name): flag = True for k, d in visitor.definitions.items(name): lib.debug(' * add extracted prefix def {} for {}'.format(k, name)) lib._add_definition(d) if not flag: raise else: lib.debug(' * add extracted def {}'.format(name)) lib._add_definition(d) # type_, text = visitor.extract_definition(name) except Exception as ex: lib.warn(' * unable to extracted def {} due to {!r}'.format(name, ex)) # current_sourcecode = lib.current_sourcecode() lib.error('--- <ERROR[2]> ---') lib.error('Error computing source code extract_definition') lib.error(' * failed to close name = {!r}'.format(name)) # lib.error('<<< CURRENT_SOURCE >>>\n{}\n<<<>>>'.format(ub.highlight_code(current_sourcecode))) lib.error('--- </ERROR[2]> ---')
[docs] def expand(lib, expand_names): """ Remove all references to specific modules by directly copying in the referenced source code. If the code is referenced from a module, then the references will need to change as well. Args: expand_name (List[str]): list of module names. For each module we expand any reference to that module in the closed source code by directly copying the referenced code into that file. This doesn't work in all cases, but it usually does. Reasons why this wouldn't work include trying to expand import from C-extension modules and expanding modules with complicated global-level logic. TODO: - [ ] Add special unique (mangled) suffixes to all expanded names to avoid name conflicts. Example: >>> # Test a heavier duty class >>> # xdoctest: +REQUIRES(module:netharn) >>> from liberator.core import * >>> import netharn as nh >>> obj = nh.device.MountedModel >>> #obj = nh.layers.ConvNormNd >>> #obj = nh.data.CocoDataset >>> #expand_names = ['ubelt', 'progiter'] >>> expand_names = ['netharn'] >>> lib = Liberator() >>> lib.add_dynamic(obj) >>> lib.expand(expand_names) >>> #print('header_defs = ' + ub.urepr(lib.header_defs, si=1)) >>> #print('body_defs = ' + ub.urepr(lib.body_defs, si=1)) >>> print('SOURCE:') >>> text = lib.current_sourcecode() >>> print(text) """ lib.debug('\n\n') lib.debug('====\n\n') lib.debug("!!! EXPANDING") # Expand references to internal modules flag = True while flag: # Associate all top-level modules with any possible expand_name # that might trigger them to be expanded. Note this does not # account for nested imports. expandable_definitions = ub.ddict(list) for d in lib.header_defs.values(): parts = d.native_modname.split('.') for i in range(1, len(parts) + 1): root = '.'.join(parts[:i]) expandable_definitions[root].append(d) lib.debug('expandable_definitions = {!r}'.format( list(expandable_definitions.keys()))) flag = False # current_sourcecode = lib.current_sourcecode() # closed_visitor = DefinitionVisitor.parse(source=current_sourcecode) for root in expand_names: needs_expansion = expandable_definitions.get(root, []) lib.debug('root = {!r}'.format(root)) lib.debug('needs_expansion = {}'.format(ub.urepr(needs_expansion, nl=1))) d: Definition for d in needs_expansion: if d._expanded: continue flag = True # if d.absname == d.native_modname: if ub.modname_to_modpath(d.absname): lib.info('TODO: NEED TO CLOSE module = {}'.format(d)) # import warnings # warnings.warn('Closing module {} may not be implemented'.format(d)) # definition is a module, need to expand its attributes lib.expand_module_attributes(d) d._expanded = True else: lib.info('TODO: NEED TO CLOSE attribute varname = {}'.format(d)) # warnings.warn('Closing attribute {} may not be implemented'.format(d)) # definition is a non-module, directly copy in its code # We can directly replace this import statement by # copy-pasting the relevant code from the other module # (ASSUMING THERE ARE NO NAME CONFLICTS) assert d.type == 'ImportFrom' try: native_modpath = ub.modname_to_modpath(d.native_modname) if native_modpath is None: raise Exception('Cannot find the module path for modname={!r}. ' 'Are you missing an __init__.py?'.format(d.native_modname)) sub_lib = Liberator(lib.logger.tag + '.sub.' + d.name, logger=lib.logger) sub_lib.add_static(d.name, native_modpath) print(f'native_modpath={native_modpath}') # sub_visitor = sub_lib.visitors[d.native_modname] sub_lib.expand(expand_names) # sub_lib.close(sub_visitor) except NotAPythonFile as ex: warnings.warn('CANNOT EXPAND d = {!r}, REASON: {}'.format(d, repr(ex))) d._expanded = True raise continue except Exception as ex: warnings.warn('CANNOT EXPAND d = {!r}, REASON: {}'.format(d, repr(ex))) d._expanded = True raise continue else: # Hack: remove the imported definition and add the # explicit definition # TODO: FIXME: more robust modification and replacement d._code = '# ' + d.code d._expanded = True for d_ in sub_lib.header_defs.values(): lib._add_definition(d_) for d_ in sub_lib.body_defs.values(): lib._add_definition(d_) # print('sub_visitor = {!r}'.format(sub_visitor)) # lib.close(sub_visitor) lib.debug('CLOSED attribute d = {}'.format(d))
[docs] def expand_module_attributes(lib, d): """ Args: d (Definition): the definition to expand """ # current_sourcecode = lib.current_sourcecode() # closed_visitor = DefinitionVisitor.parse(source=current_sourcecode) assert 'Import' in d.type varname = d.name varmodpath = ub.modname_to_modpath(d.absname) modname = d.absname def _exhaust(varname, modname, modpath): lib.debug('REWRITE ACCESSOR varname={!r}, modname={}, modpath={}'.format(varname, modname, modpath)) # Modify the current node definitions and recompute code # TODO: make more robust rewriter = RewriteModuleAccess(varname) for d_ in lib.body_defs.values(): rewriter.visit(d_.node) d_._code = unparse(d_.node) lib.debug('rewriter.accessed_attrs = {!r}'.format(rewriter.accessed_attrs)) # For each modified attribute, copy in the appropriate source. for subname in rewriter.accessed_attrs: submodname = modname + '.' + subname submodpath = ub.modname_to_modpath(submodname) if submodpath is not None: # if the accessor is to another module, exhaust until # we reach a non-module lib.debug('EXAUSTING: {}, {}, {}'.format(subname, submodname, submodpath)) _exhaust(subname, submodname, submodpath) else: # Otherwise we can directly add the referenced attribute lib.debug('FINALIZE: {} from {}'.format(subname, modpath)) lib.add_static(subname, modpath) _exhaust(varname, modname, varmodpath) d._code = '# ' + d.code
class UnparserVariant(astunparse.Unparser): """ wraps astunparse to fix 2/3 compatibility minor issues Notes: x = np.random.rand(3, 3) # In python3 this works, but it fails in python2 x[(..., 2)] # However, this works in both x[(Ellipsis, 2)] # Interestingly, this also works, but is not how astunparse generates code x[..., 2] """ def _Ellipsis(self, t): # be compatible with python2 if possible self.write("Ellipsis") def _Constant(self, node): """ Args: node (ast.Constant): a constant node, if it is a triplequote string we will try to make it look nice. Example: >>> from liberator.core import * # NOQA >>> from liberator.core import UnparserVariant >>> tq = '"' * 3 >>> code = ub.codeblock( >>> fr''' >>> def foobar(): >>> {tq} >>> A docstring >>> {tq} >>> ''') >>> import ast >>> tree = ast.parse(code) >>> node = tree.body[0].body[0].value >>> v = io.StringIO() >>> self = UnparserVariant(node, file=v) >>> print(v.getvalue()) """ # Better support for multiline strings if not isinstance(node.value, str): return super()._Constant(node) if node.lineno != node.end_lineno: # heuristic for tripple quote strings nl = '\n' tsq = "'''" tdq = '"""' # indent = ' ' * node.col_offset candidates = [ # '"""\n' + ub.indent(node.s + '\n"""', ' ' * node.col_offset), # 'r"""\n' + ub.indent(node.s + '\n"""', ' ' * node.col_offset), # "'''\n" + ub.indent(node.s + "\n'''", ' ' * node.col_offset), # "r'''\n" + ub.indent(node.s + "\n'''", ' ' * node.col_offset), # 'r' + tdq + node.s + tdq, # 'r' + tsq + node.s + tsq, tdq + node.s + tdq, 'r' + tdq + node.s + tdq, tsq + node.s + nl + tsq, 'r' + tsq + node.s + tsq, ] if 0: for cand in candidates: print(cand) found = None for cand in candidates: try: reparsed = ast.literal_eval(cand) if node.s != reparsed: raise ValueError('not equivalent') except ValueError: pass else: found = cand break if found: self.write(found) else: self.write(repr(node.s)) else: self.write(repr(node.s)) def unparse(tree): r""" wraps astunparse to fix 2/3 compatibility minor issues Args: tree (ast.AST): abstract syntax tree to unparse FIXME: [ ] This needs to format docstrings better Example: >>> from liberator.core import * # NOQA >>> tq = '"' * 3 >>> code = ub.codeblock( >>> fr''' >>> def foobar(): >>> {tq} >>> A docstring >>> {tq} >>> ''') >>> import ast >>> tree = ast.parse(code) >>> print(unparse(tree)) """ v = io.StringIO() # astunparse.Unparser(tree, file=v) UnparserVariant(tree, file=v) return v.getvalue() def source_closure(obj, expand_names=[]): """ Pulls the minimum amount of code needed to define `obj`. Uses a combination of dynamic and static introspection. Args: obj (type): the class whose definition will be exported. expand_names (List[str]): EXPERIMENTAL. List of modules that should be expanded into raw source code. Returns: str: closed_sourcecode: text defining a new python module. CommandLine: xdoctest -m liberator.core source_closure Example: >>> # xdoctest: +REQUIRES(module:torchvision) >>> import torchvision >>> from torchvision import models >>> got = {} >>> model_class = models.AlexNet >>> text = source_closure(model_class) >>> assert not undefined_names(text) >>> got['alexnet'] = ub.hash_data(text) >>> model_class = models.DenseNet >>> text = source_closure(model_class) >>> assert not undefined_names(text) >>> got['densenet'] = ub.hash_data(text) >>> model_class = models.resnet50 >>> text = source_closure(model_class) >>> assert not undefined_names(text) >>> got['resnet50'] = ub.hash_data(text) >>> model_class = models.Inception3 >>> text = source_closure(model_class) >>> assert not undefined_names(text) >>> got['inception3'] = ub.hash_data(text) >>> # The hashes will depend on torchvision itself >>> if torchvision.__version__ == '0.2.1': >>> # Note: the hashes may change if the exporter changes formats >>> want = { >>> 'alexnet': '4b2ab9c8e27b34602bdff99cbc', >>> 'densenet': 'fef4788586d2b93587ec52dd9', >>> 'resnet50': '343e6a73e754557fcce3fdb6', >>> 'inception3': '2e43a58133d0817753383', >>> } >>> failed = [] >>> for k in want: >>> if not got[k].startswith(want[k]): >>> item = (k, got[k], want[k]) >>> print('failed item = {!r}'.format(item)) >>> failed.append(item) >>> assert not failed, str(failed) >>> else: >>> warnings.warn('Unsupported version of torchvision') Example: >>> # Test a heavier duty class >>> # xdoctest: +REQUIRES(module:netharn) >>> from liberator.core import * >>> import netharn as nh >>> obj = nh.layers.ConvNormNd >>> expand_names = ['netharn'] >>> text = source_closure(obj, expand_names) >>> print(text) """ lib = Liberator() # First try to add statically (which tends to be slightly nicer) try: try: name = obj.__name__ modpath = sys.modules[obj.__module__].__file__ except Exception: # Otherwise add dynamically lib.add_dynamic(obj) else: lib.add_static(name, modpath) if expand_names: lib.expand(expand_names) closed_sourcecode = lib.current_sourcecode() except Exception: print('ERROR IN CLOSING') print('[[[ START CLOSE LOGS ]]]') print('lib.logs =\n{}'.format('\n'.join(lib.logger.logs))) print('[[[ END CLOSE LOGS ]]]') raise return closed_sourcecode def _parse_static_node_value(node): """ Extract a constant value from a node if possible """ if isinstance(node, ast.Num): value = node.n elif isinstance(node, ast.Str): value = node.s elif isinstance(node, ast.List): value = list(map(_parse_static_node_value, node.elts)) elif isinstance(node, ast.Tuple): value = tuple(map(_parse_static_node_value, node.elts)) elif isinstance(node, (ast.Dict)): keys = map(_parse_static_node_value, node.keys) values = map(_parse_static_node_value, node.values) value = OrderedDict(zip(keys, values)) # value = dict(zip(keys, values)) elif isinstance(node, (ast.NameConstant)): value = node.value else: msg = ('Cannot parse a static value from non-static node ' 'of type: {!r}'.format(type(node))) # print('node.__dict__ = {!r}'.format(node.__dict__)) # print('msg = {!r}'.format(msg)) raise TypeError(msg) return value def undefined_names(sourcecode): """ Parses source code for undefined names Args: sourcecode (str): code to check for unused names Returns: Set[str]: the unused variable names Example: >>> # xdoctest: +REQUIRES(module:pyflakes) >>> print(ub.urepr(undefined_names('x = y'), nl=0)) {'y'} """ import pyflakes.api import pyflakes.reporter class CaptureReporter(pyflakes.reporter.Reporter): def __init__(reporter, warningStream, errorStream): reporter.syntax_errors = [] reporter.messages = [] reporter.unexpected = [] def unexpectedError(reporter, filename, msg): reporter.unexpected.append(msg) def syntaxError(reporter, filename, msg, lineno, offset, text): reporter.syntax_errors.append(msg) def flake(reporter, message): reporter.messages.append(message) names = set() reporter = CaptureReporter(None, None) pyflakes.api.check(sourcecode, '_.py', reporter) for msg in reporter.messages: if msg.__class__.__name__.endswith('UndefinedName'): assert len(msg.message_args) == 1 names.add(msg.message_args[0]) return names class RewriteModuleAccess(ast.NodeTransformer): """ Refactors attribute accesses into top-level references. In other words, instances of <varname>.<attr> change to <attr>. Any attributes that were modified are stored in `accessed_attrs`. Example: >>> from liberator.core import * >>> source = ub.codeblock( ... ''' ... foo.bar = 3 ... foo.baz.bar = 3 ... biz.foo.baz.bar = 3 ... ''') >>> pt = ast.parse(source) >>> visitor = RewriteModuleAccess('foo') >>> orig = unparse(pt) >>> print(orig) foo.bar = 3 foo.baz.bar = 3 biz.foo.baz.bar = 3 >>> visitor.visit(pt) >>> modified = unparse(pt) >>> print(modified) bar = 3 baz.bar = 3 biz.foo.baz.bar = 3 >>> visitor.accessed_attrs ['bar', 'baz'] """ def __init__(self, modname): self.modname = modname self.level = 0 self.accessed_attrs = [] def visit_Import(self, node): # if self.level == 0: # return None return node def visit_ImportFrom(self, node): # if self.level == 0: # return None return node def visit_FunctionDef(self, node): self.level += 1 self.generic_visit(node) self.level -= 1 return node def visit_ClassDef(self, node): self.level += 1 self.generic_visit(node) self.level -= 1 return node def visit_Attribute(self, node): # print('VISIT ATTR: node = {!r}'.format(node.__dict__)) self.generic_visit(node) if isinstance(node.value, ast.Name): if node.value.id == self.modname: self.accessed_attrs.append(node.attr) new_node = ast.Name(node.attr, node.ctx) old_node = node return ast.copy_location(new_node, old_node) return node class Definition(ub.NiceRepr): def __init__(self, name, node, type=None, code=None, absname=None, modpath=None, modname=None, native_modname=None): self.name = name self.node = node self.type = type self._code = code self.absname = absname self.modpath = modpath self.modname = modname self.native_modname = native_modname self._expanded = False @property def code(self): if self._code is None: # NOTE: the unparse variant captures decorators whereas the dynamic # inspect variant does not seem to do that. # # In general the inspect.getsource seems to return the same # formatting as the original module, but the unparse # is more accurate. try: if self._expanded or self.type == 'Assign': # always use astunparse if we have expanded raise Exception # Attempt to dynamically extract the source code because it # keeps formatting better. module = ub.import_module_from_name(self.modname) obj = getattr(module, self.name) self._code = inspect.getsource(obj).strip('\n') except Exception: # Fallback on static sourcecode extraction # (NOTE: it should be possible to keep formatting with a bit of # work) self._code = unparse(self.node).strip('\n') return self._code.strip() def __nice__(self): parts = [] parts.append('name={}'.format(self.name)) parts.append('type={}'.format(self.type)) if self.absname is not None: parts.append('absname={}'.format(self.absname)) if self.native_modname is not None: parts.append('native_modname={}'.format(self.native_modname)) return ', '.join(parts) class NotAPythonFile(ValueError): pass class AttributeAccessVisitor(ast.NodeVisitor): """ Constructs a list of all fully-specified attributes names accessed in a parse tree TODO: could use this to parse out all used attributes in current sourcecode Ignore: from liberator.core import AttributeAccessVisitor # NOQA fpath = ub.expandpath('~/code/dvc/dvc/lock.py') sourcecode = ub.readfrom(fpath) pt = ast.parse(sourcecode) self = AttributeAccessVisitor() self.visit(pt) self.dotted_trie WIP: TRY TO FIX ISSUE WITH IMPORTINING SUBPACKAGES EXPLICITLY >>> sourcecode = ub.codeblock( ''' class MyClass(foo.bar.baz): pass class MyClass3(foo.bar.baz): pass def blah(): return foo.bar.BAZ() ''') >>> print(ub.urepr(undefined_names(sourcecode), nl=0)) from liberator.core import DefinitionVisitor # NOQA pt = ast.parse(sourcecode) node1 = pt.body[0].bases[0] visitor = AttributeAccessVisitor() visitor.visit(pt) visitor.dotted_names """ def __init__(self): import pygtrie self.dotted_trie = pygtrie.StringTrie(separator='.') def visit_Attribute(self, node): curr = node attr_chain = [] while isinstance(curr, ast.Attribute): attr_chain.append(curr.attr) curr = curr.value if isinstance(curr, ast.Name): attr_chain.append(curr.id) dotted_name = '.'.join(attr_chain[::-1]) self.dotted_trie.setdefault(dotted_name, 0) self.dotted_trie[dotted_name] += 1 # self.generic_visit(node) class DefinitionVisitor(ast.NodeVisitor, ub.NiceRepr): """ Used to search for dependencies in the original module References: https://greentreesnakes.readthedocs.io/en/latest/nodes.html Example: >>> from liberator.core import * >>> from liberator.core import DefinitionVisitor >>> from liberator import core >>> modpath = core.__file__ >>> sourcecode = ub.codeblock( ... ''' ... from ubelt.util_const import * ... import a ... import b ... import c.d ... import e.f as g ... from . import h ... from .i import j ... from . import k, l, m ... from n import o, p, q ... r = 3 ... ''') >>> visitor = DefinitionVisitor.parse(source=sourcecode, modpath=modpath) >>> print(ub.urepr(visitor.definitions, si=1)) Example: >>> from liberator.core import * >>> from liberator import core >>> modpath = core.__file__ >>> sourcecode = ub.codeblock( ''' def decor(func): return func @decor def foo(): return 'bar' ... ''') >>> visitor = DefinitionVisitor.parse(source=sourcecode, modpath=modpath) >>> print(ub.urepr(visitor.definitions, si=1)) Example: >>> from liberator.core import * >>> from liberator.core import DefinitionVisitor >>> from liberator import core >>> modpath = core.__file__ >>> sourcecode = ub.codeblock( ''' import kwarray def global_import(func): kwarray.ensure_rng(1) def nested_import(): import ubelt as ub return ub.Cacher ... ''') >>> visitor = DefinitionVisitor.parse(source=sourcecode, modpath=modpath) >>> print(ub.urepr(list(visitor.definitions), si=1)) >>> print(ub.urepr(list(visitor.nested_definitions), si=1)) Example: >>> # xdoctest: +REQUIRES(module:mmdet) >>> import mmdet >>> import mmdet.models >>> import liberator >>> lib = liberator.core.Liberator() >>> lib.add_dynamic(mmdet.models.backbones.HRNet) >>> print(lib.current_sourcecode()) >>> visitor = ub.peek(lib.visitors.values()) >>> print(ub.urepr(visitor.definitions, si=1)) >>> d = visitor.definitions['HRNet'] >>> print(d.code[0:1000]) Example: >>> from liberator.core import * >>> from liberator.core import DefinitionVisitor >>> from liberator import core >>> modpath = core.__file__ >>> # Test case global variables with type annots >>> sourcecode = ub.codeblock( ''' GLOBAL_VAR: list = [] ... ''') >>> visitor = DefinitionVisitor.parse(source=sourcecode, modpath=modpath) >>> print(ub.urepr(list(visitor.definitions), si=1)) >>> print(ub.urepr(list(visitor.nested_definitions), si=1)) """ def __init__(visitor, modpath=None, modname=None, module=None, pt=None, logger=None): super(DefinitionVisitor, visitor).__init__() visitor.pt = pt visitor.modpath = modpath visitor.modname = modname visitor.module = module visitor.logger = logger import pygtrie visitor.definitions = pygtrie.StringTrie(separator='.') visitor.nested_definitions = pygtrie.StringTrie(separator='.') visitor.level = 0 def __nice__(self): if self.modname is not None: return self.modname else: return "<sourcecode>" @classmethod def parse(DefinitionVisitor, source=None, modpath=None, modname=None, module=None, logger=None): if module is not None: if source is None: source = inspect.getsource(module) if modpath is None: modname = module.__file__ if modname is None: modname = module.__name__ if modpath is not None: if modpath.endswith('.pyc'): modpath = modpath.replace('.pyc', '.py') # python 2 hack if isdir(modpath): modpath = join(modpath, '__init__.py') if modname is None: modname = ub.modpath_to_modname(modpath) if modpath is not None: if source is None: if not modpath.endswith(('.py', '>')): raise NotAPythonFile('can only parse python files, not {}'.format(modpath)) source = open(modpath, 'r').read() if source is None: raise ValueError('unable to derive source code') pt = ast.parse(source) visitor = DefinitionVisitor(modpath, modname, module, pt=pt, logger=logger) visitor.visit(pt) # Hack in attribute visiting # attr_visitor = AttributeAccessVisitor() # attr_visitor.visit(pt) # visitor.dotted_trie = attr_visitor.dotted_trie return visitor def extract_definition(visitor, name): """ Given the name of a variable / class / function / moodule, extract the relevant lines of source code that define that structure from the visited module. """ return visitor.definitions[name] def visit_Import(visitor, node): for d in visitor._import_definitions(node): if visitor.level == 0: visitor.definitions[d.name] = d else: visitor.nested_definitions[d.name] = d visitor.generic_visit(node) def visit_ImportFrom(visitor, node): for d in visitor._import_from_definition(node): if visitor.level == 0: visitor.definitions[d.name] = d else: visitor.nested_definitions[d.name] = d visitor.generic_visit(node) def _common_visit_assign(visitor, node, target): key = getattr(target, 'id', None) if key is not None: try: static_val = _parse_static_node_value(node.value) code = '{} = {}'.format(key, ub.urepr(static_val)) except TypeError: #code = unparse(node).strip('\n') code = None if visitor.logger: if key in visitor.definitions: # OVERLOADED visitor.logger.debug('OVERLOADED key = {!r}'.format(key)) definition = Definition( key, node, code=code, type='Assign', modpath=visitor.modpath, modname=visitor.modname, absname=visitor.modname + '.' + key, native_modname=visitor.modname, ) if visitor.level == 0: visitor.definitions[key] = definition # else: # visitor.nested_definitions[key] = definition def visit_AnnAssign(visitor, node): if visitor.level > 0: return visitor._common_visit_assign(node, node.target) def visit_Assign(visitor, node): if visitor.level > 0: return for target in node.targets: visitor._common_visit_assign(node, target) def visit_FunctionDef(visitor, node): defenition = Definition( node.name, node, type='FunctionDef', modpath=visitor.modpath, modname=visitor.modname, absname=visitor.modname + '.' + node.name, native_modname=visitor.modname, ) if visitor.level == 0: visitor.definitions[node.name] = defenition else: # visitor.nested_definitions[node.name] = defenition pass if visitor.level == 0: visitor.level += 1 visitor.generic_visit(node) visitor.level -= 1 else: visitor.generic_visit(node) # ast.NodeVisitor.generic_visit(visitor, node) def visit_ClassDef(visitor, node): defenition = Definition( node.name, node, type='ClassDef', modpath=visitor.modpath, modname=visitor.modname, absname=visitor.modname + '.' + node.name, native_modname=visitor.modname, ) if visitor.level == 0: visitor.definitions[node.name] = defenition else: # visitor.nested_definitions[node.name] = defenition pass if visitor.level == 0: visitor.level += 1 visitor.generic_visit(node) visitor.level -= 1 else: visitor.generic_visit(node) # # Ignore any non-top-level imports # if not visitor.level == 0: # # ast.NodeVisitor.generic_visit(visitor, node) def _import_definitions(visitor, node): for alias in node.names: varname = alias.asname or alias.name if alias.asname: line = 'import {} as {}'.format(alias.name, alias.asname) else: line = 'import {}'.format(alias.name) absname = alias.name yield Definition(varname, node, code=line, absname=absname, native_modname=absname, modpath=visitor.modpath, modname=visitor.modname, type='Import') def _import_from_definition(visitor, node): """ Ignore: from liberator.core import * visitor = DefinitionVisitor.parse(module=module) print('visitor.definitions = {}'.format(ub.urepr(visitor.definitions, sv=1))) """ if node.level: # Handle relative imports if visitor.modpath is not None: try: rel_modpath = ub.split_modpath(abspath(visitor.modpath))[1] except ValueError: warnings.warn('modpath={} does not exist'.format(visitor.modpath)) rel_modpath = basename(abspath(visitor.modpath)) modparts = rel_modpath.replace('\\', '/').split('/') parts = modparts[:-node.level] prefix = '.'.join(parts) if node.module: prefix = prefix + '.' else: warnings.warn('Unable to rectify absolute import') prefix = '.' * node.level else: prefix = '' if node.module is not None: abs_modname = prefix + node.module else: abs_modname = prefix for alias in node.names: varname = alias.asname or alias.name if alias.asname: line = 'from {} import {} as {}'.format(abs_modname, alias.name, alias.asname) else: line = 'from {} import {}'.format(abs_modname, alias.name) absname = abs_modname + '.' + alias.name if varname == '*': # HACK abs_modpath = ub.modname_to_modpath(abs_modname) star_visitor = DefinitionVisitor.parse( modpath=abs_modpath, logger=visitor.logger) for d in star_visitor.definitions.values(): if not d.name.startswith('_'): yield d else: yield Definition(varname, node, code=line, absname=absname, modpath=visitor.modpath, modname=visitor.modname, native_modname=abs_modname, type='ImportFrom') def _closefile(fpath, modnames): """ An api to remove dependencies from code by "closing" them. CommandLine: xdoctest -m ~/code/liberator/core.py _closefile xdoctest -m liberator.core _closefile --fpath=~/code/boltons/tests/test_cmdutils.py --modnames=ubelt, xdoctest -m liberator.core _closefile --fpath=~/code/dvc/dvc/updater.py --modnames=dvc, Example: >>> # SCRIPT >>> # ENTRYPOINT >>> import scriptconfig as scfg >>> config = scfg.quick_cli({ >>> 'fpath': scfg.Path(None), >>> 'modnames': scfg.Value([]), >>> }) >>> #fpath = config['fpath'] = ub.expandpath('~/code/boltons/tests/test_cmdutils.py') >>> #modnames = config['modnames'] = ['ubelt'] >>> _closefile(**config) """ from xdoctest import static_analysis as static modpath = fpath expand_names = modnames source = open(fpath, 'r').read() calldefs = static.parse_calldefs(source, fpath) calldefs.pop('__doc__', None) lib = Liberator() for key in calldefs.keys(): lib.add_static(key, modpath) lib.expand(expand_names) #print(ub.urepr(lib.body_defs, si=1)) print(lib.current_sourcecode())
[docs]class Closer(Liberator): """ Deprecated in favor of :class:`Liberator`. The original name of the Liberator class was called Closer. Exposing this for backwards compatibility. """ pass