예제 #1
0
 def __init__(self,
              lightning_module: LightningModule,
              trainer: Trainer,
              train_dataloaders: Optional[Any] = None,
              val_dataloaders: Optional[Any] = None,
              train_dataloader: Optional[Any] = None):
     """Validate and store a Lightning module, its trainer and its dataloaders.

     Args:
         lightning_module: Must be an instance of ``LightningModule``.
         trainer: A ``pl.Trainer`` that is traceable, or (when the CGO trainer
             import succeeded) a ``cgo_trainer.Trainer``.
         train_dataloaders: Training dataloader(s). Untraced DataLoaders only
             trigger a ``RuntimeWarning``.
         val_dataloaders: Validation dataloader(s). Untraced DataLoaders only
             trigger a ``RuntimeWarning``.
         train_dataloader: Deprecated alias of ``train_dataloaders``. When it is
             not None it replaces ``train_dataloaders``.

     Raises:
         AssertionError: if the module or trainer has an unexpected type
             (checked with ``assert``, so skipped under ``python -O``).
     """
     assert isinstance(
         lightning_module, LightningModule
     ), f'Lightning module must be an instance of {__name__}.LightningModule.'
     if train_dataloader is not None:
         # stacklevel=2 points the warning at the code constructing this
         # object instead of at this line.
         warnings.warn(
             '`train_dataloader` is deprecated and replaced with `train_dataloaders`.',
             DeprecationWarning, stacklevel=2)
         train_dataloaders = train_dataloader
     if cgo_import_failed:
         assert isinstance(trainer, pl.Trainer) and is_traceable(
             trainer), f'Trainer must be imported from {__name__}'
     else:
         # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
         assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
             f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
     # Untraced dataloaders are tolerated: warn (train first, then val) but do not fail.
     for dataloaders in (train_dataloaders, val_dataloaders):
         if not _check_dataloader(dataloaders):
             warnings.warn(
                 f'Please try to wrap PyTorch DataLoader with nni.trace or '
                 f'import DataLoader from {__name__}: {dataloaders}',
                 RuntimeWarning, stacklevel=2)
     self.module = lightning_module
     self.trainer = trainer
     self.train_dataloaders = train_dataloaders
     self.val_dataloaders = val_dataloaders
예제 #2
0
 def __init__(self,
              lightning_module: LightningModule,
              trainer: Trainer,
              train_dataloader: Optional[DataLoader] = None,
              val_dataloaders: Union[DataLoader, List[DataLoader],
                                     None] = None):
     """Check and keep the Lightning module, trainer and dataloaders.

     Every check is an ``assert`` (and is therefore skipped under ``python -O``):
     the module must be a ``LightningModule``; the trainer must be a traceable
     ``pl.Trainer`` (or a ``cgo_trainer.Trainer`` when that import succeeded);
     each dataloader argument must pass ``_check_dataloader``.
     """
     assert isinstance(lightning_module, LightningModule), \
         f'Lightning module must be an instance of {__name__}.LightningModule.'

     if not cgo_import_failed:
         # Deliberately not isinstance(trainer, Trainer): a different trace call
         # can produce a different class object.
         assert ((isinstance(trainer, pl.Trainer) and is_traceable(trainer))
                 or isinstance(trainer, cgo_trainer.Trainer)), \
             f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
     else:
         assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), \
             f'Trainer must be imported from {__name__}'

     # Same message for both loaders; the training loader is checked first.
     for loader in (train_dataloader, val_dataloaders):
         assert _check_dataloader(loader), \
             f'Wrong dataloader type. Try import DataLoader from {__name__}.'

     self.module, self.trainer = lightning_module, trainer
     self.train_dataloader, self.val_dataloaders = train_dataloader, val_dataloaders
예제 #3
0
    def _mutate_traceable_object(
            self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
        """Resolve value choices in a traced object's keyword arguments.

        Args:
            obj: Candidate object. It is returned untouched unless
                ``_is_traceable_object(obj)`` holds.
            value_choice_decisions: Maps a leaf value-choice label to the
                chosen value.

        Returns:
            ``obj`` itself when none of its keyword arguments changed.
            Otherwise it returns a fresh instance: it is built from
            ``obj.trace_copy()`` with the resolved keyword arguments, then
            instantiated with ``.get()``.
        """
        if not _is_traceable_object(obj):
            return obj

        resolved: Dict[str, Any] = {}
        for arg_name, arg_value in obj.trace_kwargs.items():
            if isinstance(arg_value, ValueChoiceX):
                # A possibly composite choice: feed it the decision of every
                # leaf choice it is built from, then evaluate it.
                leaf_values = [value_choice_decisions[leaf.label]
                               for leaf in arg_value.inner_choices()]
                resolved[arg_name] = arg_value.evaluate(leaf_values)
                continue
            if not is_traceable(arg_value):
                continue
            # Nested traced argument: recurse, and keep it only if it changed.
            nested = self._mutate_traceable_object(arg_value, value_choice_decisions)
            if nested is not arg_value:
                resolved[arg_name] = nested

        if not resolved:
            return obj

        replica = obj.trace_copy()
        replica.trace_kwargs.update(resolved)
        return replica.get()
예제 #4
0
def _check_dataloader(dataloader):
    if dataloader is None:
        return True
    if isinstance(dataloader, list):
        return all([_check_dataloader(d) for d in dataloader])
    return isinstance(dataloader,
                      torch_data.DataLoader) and is_traceable(dataloader)
예제 #5
0
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Find value choice in module's arguments and replace the whole module.

        Args:
            module: Candidate module. It is handled only if it is an instance of
                ``cls.bound_type``, is traceable, and has a ``ValueChoiceX``
                among its traced positional or keyword arguments.
            name: Unused here.
            memo: Passed through to the sampling policy constructor.
            mutate_kwargs: Must contain ``'mixed_op_sampling'``, the sampling
                policy class to instantiate.

        Returns:
            The new mixed operation (with ``sampling_policy`` set), or ``None``
            when ``module`` is not handled by this class.

        Raises:
            ValueError: if a value choice is present and either the module was
                traced with positional arguments, or ``mutate_kwargs`` lacks
                ``'mixed_op_sampling'``.
        """
        if not (isinstance(module, cls.bound_type) and is_traceable(module)):
            return None

        has_valuechoice = any(
            isinstance(arg, ValueChoiceX)
            for arg in itertools.chain(cast(list, module.trace_args),
                                       cast(dict, module.trace_kwargs).values()))
        if not has_valuechoice:
            return None

        if module.trace_args:
            raise ValueError(
                'ValueChoice on class arguments cannot appear together with ``trace_args``. '
                'Please enable ``kw_only`` on nni.trace.')

        # Validate before building the mixed op, so nothing is constructed just
        # to be thrown away.
        if 'mixed_op_sampling' not in mutate_kwargs:
            raise ValueError(
                'Need to sampling policy of mixed op, but not found in `mutate_kwargs`.'
            )

        # save type and kwargs
        mixed_op = cls(cast(dict, module.trace_kwargs))

        policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs[
            'mixed_op_sampling']
        # initialize policy class
        # this is put in mutate because we need to access memo
        mixed_op.sampling_policy = policy_cls(mixed_op, memo, mutate_kwargs)

        return mixed_op
예제 #6
0
def _test_multiprocessing_dataset_worker(dataset):
    """Check, in a worker process, the dataset object it received.

    On Linux the dataset must still be traceable. On other platforms the loaded
    object becomes non-traceable (an implementation limitation). There, only
    check that it is still a torch ``Dataset``.
    """
    if sys.platform != 'linux':
        from torch.utils.data import Dataset
        assert isinstance(dataset, Dataset)
        return
    assert is_traceable(dataset)
예제 #7
0
 def from_trace(lr_scheduler_trace: Traceable):
     """Build an ``LRSchedulerConstructHelper`` from a traced LR scheduler instance.

     The helper is given the traced scheduler's ``trace_symbol``, its
     positional ``trace_args`` and its ``trace_kwargs``.

     Raises:
         AssertionError: if ``lr_scheduler_trace`` is not traceable, or is not an
             ``_LRScheduler`` instance (``assert``-based, skipped under ``-O``).
     """
     assert is_traceable(lr_scheduler_trace), \
         'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
     # _LRScheduler lives in torch.optim.lr_scheduler; torch.nn has no lr_scheduler module.
     assert isinstance(lr_scheduler_trace, _LRScheduler), \
         'It is not an instance of torch.optim.lr_scheduler._LRScheduler.'
     return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol,
                                       *lr_scheduler_trace.trace_args,
                                       **lr_scheduler_trace.trace_kwargs)
예제 #8
0
 def from_trace(model: Module, optimizer_trace: Traceable):
     """Build an ``OptimizerConstructHelper`` for ``model`` from a traced optimizer.

     The helper receives ``model`` plus the traced optimizer's ``trace_symbol``,
     its positional ``trace_args`` and its ``trace_kwargs``.

     Raises:
         AssertionError: if ``optimizer_trace`` is not traceable, or is not an
             ``Optimizer`` instance (``assert``-based, skipped under ``-O``).
     """
     assert is_traceable(optimizer_trace), \
         'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
     # Optimizer is defined in torch.optim; there is no torch.nn.Optimizer.
     assert isinstance(optimizer_trace, Optimizer), \
         'It is not an instance of torch.optim.Optimizer.'
     return OptimizerConstructHelper(model, optimizer_trace.trace_symbol,
                                     *optimizer_trace.trace_args,
                                     **optimizer_trace.trace_kwargs)
예제 #9
0
def get_init_parameters_or_fail(obj: Any):
    """Return the keyword arguments that ``obj`` was traced with.

    Raises:
        ValueError: if ``obj`` is not traceable. The message suggests how to
            make it traceable.
    """
    if not is_traceable(obj):
        hint = (f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
                'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                'If it is a customized module, please to decorate it with @basic_unit. '
                'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                'try to use @nni.trace.')
        raise ValueError(hint)
    return obj.trace_kwargs
예제 #10
0
def _check_dataloader(dataloader):
    # Check the type of dataloader recursively.
    if isinstance(dataloader, list):
        return all([_check_dataloader(d) for d in dataloader])
    if isinstance(dataloader, dict):
        return all([_check_dataloader(v) for v in dataloader.values()])
    if isinstance(dataloader, torch_data.DataLoader):
        return is_traceable(dataloader)
    return True
예제 #11
0
def test_function():
    """Trace a builtin function and a local factory, then round-trip them through nni.dump/nni.load."""
    # With kw_only=False the positional argument is recorded in trace_args.
    sqrt_three = nni.trace(math.sqrt, kw_only=False)(3)
    assert 1 < sqrt_three < 2
    assert sqrt_three.trace_symbol == math.sqrt
    assert sqrt_three.trace_args == [3]
    restored = nni.load(nni.dump(sqrt_three))
    assert 1 < restored < 2
    # Known limitation: the trace itself is not recovered after loading.
    assert not is_traceable(restored)

    def simple_class_factory(bb=3.):
        return SimpleClass(1, bb)

    # By default the positional argument is recorded as the keyword ``bb``.
    dumped = nni.dump(nni.trace(simple_class_factory)(4))
    assert '__kwargs__' in dumped
    instance = nni.load(dumped)
    assert instance._a == 1
    assert is_traceable(instance)
    instance = instance.trace_copy()
    assert is_traceable(instance)
    assert instance.trace_symbol(10)._b == 10
    assert instance.trace_kwargs['bb'] == 4
    assert is_traceable(instance.trace_copy())
예제 #12
0
def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any],
                    mutate_kwargs: dict[str, Any]) -> bool:
    """Add this hook at the end of your hook list to raise error for unsupported mutation primitives.

    Raises ``TypeError`` in three cases:
    - ``module`` is one of the unsupported primitives;
    - it is a ``Cell`` whose ``merge_op`` is not ``'all'``;
    - it is traceable and a traced argument is a ``ValueChoiceX``.

    Otherwise returns True, which suppresses all other hooks.
    """
    # Primitives whose forward IS NOT a supernet forward.
    unsupported_primitives = (
        nas_nn.LayerChoice,
        nas_nn.InputChoice,
        nas_nn.Repeat,
        nas_nn.NasBench101Cell,
        # nas_nn.ValueChoice,       # could be false positive
        # nas_nn.Cell,              # later
        # nas_nn.NasBench201Cell,   # forward = supernet
    )
    if isinstance(module, unsupported_primitives):
        raise TypeError(f'{type(module).__name__} is not supported')

    # Any other merge_op needs output_node_indices, which depends on the super-net.
    if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
        raise TypeError(
            f'Cell with merge_op `{module.merge_op}` is not supported')

    if is_traceable(module):
        traced_arguments = chain(cast(list, module.trace_args),
                                 cast(dict, module.trace_kwargs).values())
        if any(isinstance(arg, ValueChoiceX) for arg in traced_arguments):
            raise TypeError(
                f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
                'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.'
            )

    return True  # suppress all other hooks
예제 #13
0
def _is_traceable_object(obj: Any) -> bool:
    """Return whether ``obj`` is a traceable *object*, i.e. not a class wrapped with trace."""
    if not is_traceable(obj):
        return False
    return not is_wrapped_with_trace(obj)