Ejemplo n.º 1
0
def re_compile_model(model, data_module, batch_type: str = "TensorBatch"):
    """Re-create ``model``'s class from its source with ``batch`` params annotated.

    The source file of the module that defines ``type(model)`` is located
    under ``ROOT_DIR`` by turning the dotted module path into a file path
    (``a.b.c`` -> ``a/b/c.py``). Every ``def`` header whose parameter list
    ends with a bare ``batch`` parameter is rewritten to ``batch:<batch_type>``.
    ``from examples.core.typing import *`` is prepended; presumably that
    module exports ``batch_type`` (confirm). The patched source is then
    handed to ``class_from_module_repr`` together with the class name.

    Args:
        model: Instance whose class is re-created from source.
        data_module: Not read by this function; kept so existing callers
            keep working.
        batch_type: Type name written into the ``batch`` annotations.
            Defaults to ``"TensorBatch"``, the previously hard-coded value.

    Returns:
        Whatever ``class_from_module_repr`` returns for ``model``'s class
        name. No new instance is created here.
    """
    import re

    module_path = model.__class__.__module__.replace(".", "/")
    with open(join(ROOT_DIR, f"{module_path}.py"), encoding="utf-8") as f:
        code = f.read()

    # Only rewrite ``def`` headers. A plain ``str.replace("batch)", ...)``
    # also hits call sites such as ``self.step(batch)``. That yields
    # ``self.step(batch:TensorBatch)``, a SyntaxError. It also hits names
    # that merely end in "batch" (``minibatch)``).
    def _annotate(match):
        return f"{match.group(1)}:{batch_type})"

    code = re.sub(r"(\bdef\s+\w+\s*\([^)]*\bbatch)\)", _annotate, code)
    code = "from examples.core.typing import *\n" + code
    return class_from_module_repr(model.__class__.__name__, code)
Ejemplo n.º 2
0
def Sequential(
    input_args: str,
    modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.
    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` expects both global input
    arguments, and function header definitions of individual operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    where :obj:`'x, edge_index'` defines the input arguments of :obj:`model`,
    and :obj:`'x, edge_index -> x'` defines the function header, *i.e.* input
    arguments *and* return types, of :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows creating more sophisticated models,
    such as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The input arguments of the model.
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions). Alternatively, an
            :obj:`OrderedDict` of modules (and function header definitions) can
            be passed. The first entry must always carry a function header.

    Returns:
        torch.nn.Module: An instance of a freshly generated class, with every
        entry's callable set as an attribute under its entry name.
    """
    try:
        from jinja2 import Template
    except ImportError as err:
        raise ModuleNotFoundError(
            "No module named 'jinja2' found on this machine. "
            "Run 'pip install jinja2' to install the library.") from err

    input_args = [x.strip() for x in input_args.split(',')]

    # A plain list is turned into a mapping with auto-generated names. These
    # names become the attribute names of the submodules on the result.
    if not isinstance(modules, dict):
        modules = {f'module_{i}': module for i, module in enumerate(modules)}

    # The first entry must carry a function header. Entries without one
    # inherit the previous entry's outputs, and nothing precedes the first.
    # (A 1-element tuple used to pass this check and then fail with an
    # IndexError on `calls[-1]`.)
    assert len(modules) > 0, "'modules' must not be empty"
    first_module = next(iter(modules.values()))
    assert isinstance(first_module, (tuple, list)) and len(first_module) >= 2, (
        "The first module must be given as a '(module, function_header)' pair")

    # One entry per module: (name, callable, input names, output names).
    calls: List[Tuple[str, Callable, List[str], List[str]]] = []

    for name, module in modules.items():
        if isinstance(module, (tuple, list)) and len(module) >= 2:
            # Extra tuple elements beyond the header are ignored.
            module, desc = module[:2]
            in_desc, out_desc = parse_desc(desc)
        else:
            if isinstance(module, (tuple, list)):
                module = module[0]
            # No header given: consume and produce the previous outputs.
            in_desc = out_desc = calls[-1][-1]

        calls.append((name, module, in_desc, out_desc))

    root = osp.dirname(osp.realpath(__file__))
    with open(osp.join(root, 'sequential.jinja'), 'r', encoding='utf-8') as f:
        template = Template(f.read())

    cls_name = f'Sequential_{uuid1().hex[:6]}'
    module_repr = template.render(
        cls_name=cls_name,
        input_args=input_args,
        calls=calls,
    )

    # Instantiate a class from the rendered module representation.
    module = class_from_module_repr(cls_name, module_repr)()
    module._names = list(modules.keys())
    for name, submodule, _, _ in calls:
        setattr(module, name, submodule)
    return module
Ejemplo n.º 3
0
    def jittable(self, typing: Optional[str] = None):
        r"""Analyzes the :class:`MessagePassing` instance and produces a new
        jittable module.

        The ``message_passing.jinja`` template (next to this file) is rendered
        into source for a new class named ``<ClassName>Jittable_<uid>``; the
        original class name is passed to the template as ``parent_cls_name``.
        The returned object is created with ``cls.__new__`` (``__init__`` is
        not run) and receives a shallow copy of this instance's ``__dict__``.
        As a result it shares the attribute values with ``self``.

        Args:
            typing (string, optional): If given, will generate a concrete
                instance with :meth:`forward` types based on :obj:`typing`,
                *e.g.*: :obj:`"(Tensor, Optional[Tensor]) -> Tensor"`.

        Returns:
            An instance of the generated class. Its ``jittable`` attribute is
            set to ``None``, so it cannot be converted a second time.

        Raises:
            TypeError: If the types passed to ``propagate()`` are declared
                neither via a ``propagate_type`` attribute nor via a
                ``# propagate_type: (...)`` comment in the class source.
        """
        # Find and parse `propagate()` types to format `{arg1: type1, ...}`.
        if hasattr(self, 'propagate_type'):
            prop_types = {
                k: sanitize(str(v))
                for k, v in self.propagate_type.items()
            }
        else:
            source = inspect.getsource(self.__class__)
            match = re.search(r'#\s*propagate_type:\s*\((.*)\)', source)
            if match is None:
                raise TypeError(
                    'TorchScript support requires the definition of the types '
                    'passed to `propagate()`. Please specify them via\n\n'
                    'propagate_type = {"arg1": type1, "arg2": type2, ... }\n\n'
                    'or via\n\n'
                    '# propagate_type: (arg1: type1, arg2: type2, ...)\n\n'
                    'inside the `MessagePassing` module.')
            # Each item is `name: type`. Split on the first colon only, so a
            # colon inside the type string cannot yield a 3-element entry.
            prop_types = dict(
                re.split(r'\s*:\s*', t, maxsplit=1)
                for t in split_types_repr(match.group(1)))

        # Parse `__collect__()` types to format `{arg1: type1, ...}`.
        collect_types = self.inspector.types(
            ['message', 'aggregate', 'update'])

        # Collect `forward()` header, body and @overload types.
        forward_types = parse_types(self.forward)
        forward_types = [resolve_types(*types) for types in forward_types]
        forward_types = list(chain.from_iterable(forward_types))

        # With fewer than two resolved signatures the original annotations are
        # kept on the header/body. Otherwise they are stripped, presumably so
        # the template can emit the resolved signatures instead (confirm).
        keep_annotation = len(forward_types) < 2
        forward_header = func_header_repr(self.forward, keep_annotation)
        forward_body = func_body_repr(self.forward, keep_annotation)

        if keep_annotation:
            forward_types = []
        elif typing is not None:
            # Pin a single signature via a `# type:` comment prepended to the
            # body. The 8-space indent presumably matches the body's indent
            # in the rendered method (confirm against the template).
            forward_types = []
            forward_body = 8 * ' ' + f'# type: {typing}\n{forward_body}'

        root = osp.dirname(osp.realpath(__file__))
        with open(osp.join(root, 'message_passing.jinja'), 'r',
                  encoding='utf-8') as f:
            template = Template(f.read())

        uid = uuid1().hex[:6]
        cls_name = f'{self.__class__.__name__}Jittable_{uid}'
        jit_module_repr = template.render(
            uid=uid,
            module=str(self.__class__.__module__),
            cls_name=cls_name,
            parent_cls_name=self.__class__.__name__,
            prop_types=prop_types,
            collect_types=collect_types,
            user_args=self.__user_args__,
            forward_header=forward_header,
            forward_types=forward_types,
            forward_body=forward_body,
            msg_args=self.inspector.keys(['message']),
            aggr_args=self.inspector.keys(['aggregate']),
            msg_and_aggr_args=self.inspector.keys(['message_and_aggregate']),
            update_args=self.inspector.keys(['update']),
            # `[:-1]` drops the last character (the trailing newline that
            # `getsource` returns) before embedding the method source.
            check_input=inspect.getsource(self.__check_input__)[:-1],
            lift=inspect.getsource(self.__lift__)[:-1],
        )

        # Instantiate a class from the rendered JIT module representation.
        cls = class_from_module_repr(cls_name, jit_module_repr)
        module = cls.__new__(cls)
        # Shallow copy: the dict is new, its values are shared with `self`.
        module.__dict__ = self.__dict__.copy()
        # Shadow the method so the result cannot be made jittable again.
        module.jittable = None

        return module