Example #1
0
    def _strip_traced_tensors(self, args: Tuple,
                              kwargs: Dict) -> Tuple[Tuple, Dict]:
        """
        Replace every TracedTensor found inside `args`/`kwargs` with a plain
        torch.Tensor. Required to guard against new forward calls on tensors
        that have already passed through NNCF's forward once and got turned
        into TracedTensors by reference access.

        :param args: Positional arguments that may contain TracedTensors.
        :param kwargs: Keyword arguments that may contain TracedTensors.
        :return: The (args, kwargs) pair with TracedTensors stripped in place
            of the originals.
        """
        # Named `def` instead of a lambda bound to a name (PEP 8 E731).
        def is_traced_tensor_predicate(x) -> bool:
            return isinstance(x, TracedTensor)

        def strip_fn(tensor: TracedTensor) -> torch.Tensor:
            # as_subclass is a zero-copy view downcast, available since torch 1.7.0
            if hasattr(torch.Tensor, 'as_subclass'):
                return torch.Tensor.as_subclass(tensor, torch.Tensor)
            # Torch < 1.7.0 fallback: copies the data, preserving device and grad flag
            return torch.tensor(tensor,
                                device=tensor.device,
                                requires_grad=tensor.requires_grad)

        args = objwalk(args, is_traced_tensor_predicate, strip_fn)
        kwargs = objwalk(kwargs, is_traced_tensor_predicate, strip_fn)
        return args, kwargs
Example #2
0
def test_objwalk(objwalk_objects):
    """Applying objwalk to the fixture input must yield the fixture reference."""
    input_obj, expected_obj = objwalk_objects[0], objwalk_objects[1]

    def matches_target(candidate):
        return isinstance(candidate, ObjwalkTestClass)

    apply_fn = partial(ObjwalkTestClass.member_fn, val=OBJWALK_REF_VAL)
    walked = objwalk(input_obj, matches_target, apply_fn)
    assert walked == expected_obj
Example #3
0
def test_objwalk_retains_named_tuple():
    """objwalk must preserve the named-tuple structure of nested inputs."""
    inner = NamedTuple(field1=ObjwalkTestClass(OBJWALK_INIT_VAL),
                       field2=-8)
    nested = NamedTuple(field1=ObjwalkTestClass(OBJWALK_INIT_VAL),
                        field2=inner)

    def matches_target(candidate):
        return isinstance(candidate, ObjwalkTestClass)

    apply_fn = partial(ObjwalkTestClass.member_fn, val=OBJWALK_REF_VAL)
    walked = objwalk(nested, matches_target, apply_fn)
    assert_named_tuples_are_equal(nested, walked)
    assert_named_tuples_are_equal(nested.field2, walked.field2)
Example #4
0
def replicate_same_tensors(obj: Any) -> Any:
    """
    Clone repeated tensor references so that every occurrence inside `obj` is
    a distinct tensor object. If tensor replication is not done, then at
    runtime one and the same tensor could be wrapped by input/output wrappers
    twice, which will disrupt the traced graph structure and possibly hook
    calls.
    """
    seen_tensor_ids = set()  # type: Set[int]

    def clone_if_repeated(tensor: torch.Tensor) -> torch.Tensor:
        tensor_id = id(tensor)
        if tensor_id not in seen_tensor_ids:
            seen_tensor_ids.add(tensor_id)
            return tensor
        # Repeated reference to an already-seen tensor: clone it under the
        # NNCF tracing context so the copy participates in the traced graph.
        with forward_nncf_trace():
            return tensor.clone()

    return objwalk(obj, is_tensor, clone_if_repeated)
Example #5
0
    def __next__(self):
        """Fetch the next batch, run forward/backward, and return the gradients."""
        if self.num_iter >= self._num_data_iter:
            raise StopIteration
        self.num_iter += 1

        batch = next(self.data_loader_iter)

        # Move every tensor in the batch to the device the model lives on.
        model_device = next(self._model.parameters()).device
        batch = objwalk(batch, is_tensor, partial(torch.Tensor.to, device=model_device))

        args, kwargs = self._data_loader.get_inputs(batch)
        target = self._data_loader.get_target(batch)

        self._model.zero_grad()
        outputs = self._model(*args, **kwargs)
        loss = self._criterion_fn(outputs, target, self._criterion)

        # create_graph=True keeps the backward graph so higher-order
        # derivatives can be taken from the collected gradients.
        loss.backward(create_graph=True)
        grads = self._parameter_handler.get_gradients()
        self._model.zero_grad()
        return grads
Example #6
0
def wrap_nncf_model_outputs_with_objwalk(model_outputs):
    """Apply nncf_model_output to every traced tensor found in the outputs."""
    return objwalk(model_outputs, is_traced_tensor, nncf_model_output)
Example #7
0
def wrap_nncf_model_inputs_with_objwalk(model_args, model_kwargs):
    """Apply nncf_model_input to every tensor found in the args and kwargs."""
    wrapped_args = objwalk(model_args, is_tensor, nncf_model_input)
    wrapped_kwargs = objwalk(model_kwargs, is_tensor, nncf_model_input)
    return wrapped_args, wrapped_kwargs
Example #8
0
 def _infer_batch(self, args_kwargs_tuple, device):
     """Move the batch tensors to `device` and run a single forward pass."""
     move_to_device = partial(torch.Tensor.to, device=device)
     args, kwargs = objwalk(args_kwargs_tuple, is_tensor, move_to_device)
     self.model(*args, **kwargs)