    def test_train_model_eval_model(self):
        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1, 2)
                self.dropout = nn.Dropout()
                self.seq = nn.Sequential(nn.ReLU(), nn.Conv2d(1, 2, 3),
                                         nn.BatchNorm2d(1, 2))

        test_model = TestModel()
        for train in [True, False]:
            test_model.train(train)

            # flip some of the modes
            test_model.dropout.train(not train)
            test_model.seq[1].train(not train)

            orig_model = copy.deepcopy(test_model)

            with util.train_model(test_model):
                self._check_model_train_mode(test_model, True)
                # the modes should be different inside the context manager
                self.assertFalse(
                    self._compare_model_train_mode(orig_model, test_model))
            self.assertTrue(
                self._compare_model_train_mode(orig_model, test_model))

            with util.eval_model(test_model):
                self._check_model_train_mode(test_model, False)
                # the modes should be different inside the context manager
                self.assertFalse(
                    self._compare_model_train_mode(orig_model, test_model))
            self.assertTrue(
                self._compare_model_train_mode(orig_model, test_model))
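This test only passes if util.train_model and util.eval_model restore each submodule's original mode on exit. A minimal sketch of such a context manager, assuming the usual save-and-restore pattern (not the library's actual implementation):

import contextlib

import torch.nn as nn


@contextlib.contextmanager
def train_model(model: nn.Module):
    # Remember each submodule's current mode, force train mode for the
    # duration of the block, then restore every saved mode on exit.
    saved_modes = [m.training for m in model.modules()]
    try:
        model.train()
        yield model
    finally:
        for module, mode in zip(model.modules(), saved_modes):
            module.train(mode)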
Example #2
def compute_complexity(model, compute_fn, input_shape, input_key=None):
    """
    Compute the complexity of a forward pass.
    """
    # assertions, input, and upvalue in which we will perform the count:
    assert isinstance(model, nn.Module)

    if not isinstance(input_shape, abc.Sequence) and not isinstance(
            input_shape, dict):
        return None
    else:
        input = get_model_dummy_input(model, input_shape, input_key)

    compute_list = []

    # measure FLOPs:
    modify_forward(model, compute_list, compute_fn)
    try:
        # compute complexity in eval mode
        with eval_model(model), torch.no_grad():
            model.forward(input)
    except NotImplementedError as err:
        raise err
    finally:
        restore_forward(model)

    return sum(compute_list)
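A usage sketch for the helper above. The count_params function is hypothetical, and the assumption that modify_forward invokes compute_fn with each wrapped module (and its input) is not confirmed by this snippet:

# Hypothetical compute_fn: charge each module its own parameter count.
def count_params(module, module_input):
    return sum(p.numel() for p in module.parameters(recurse=False))

# total_params = compute_complexity(my_model, count_params, input_shape=(3, 224, 224))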
Example #3
    def torchscript_using_trace(self, model):
        input_shape = model.input_shape if hasattr(model, "input_shape") else None
        if not input_shape:
            logging.warning("This model doesn't implement input_shape. "
                            "Cannot save torchscripted model.")
            return
        input_data = get_model_dummy_input(
            model,
            input_shape,
            input_key=model.input_key if hasattr(model, "input_key") else None,
        )
        # trace in eval mode with autograd disabled
        with eval_model(model), torch.no_grad():
            torchscript = torch.jit.trace(model, input_data)
        return torchscript
Example #4
def compute_complexity(
    model,
    compute_fn,
    input_shape,
    input_key=None,
    patch_attr=None,
    compute_unique=False,
):
    """
    Compute the complexity of a forward pass.

    Args:
        compute_unique: If True, the complexity for a given module is only calculated
            once. Otherwise, it is counted every time the module is called.

    TODO(@mannatsingh): We have some assumptions about only modules which are leaves
        or have patch_attr defined. This should be fixed and generalized if possible.
    """
    # assertions, input, and upvalue in which we will perform the count:
    assert isinstance(model, nn.Module)

    if not isinstance(input_shape, abc.Sequence) and not isinstance(
            input_shape, dict):
        return None
    else:
        input = get_model_dummy_input(model, input_shape, input_key)

    complexity_computer = ComplexityComputer(compute_fn, compute_unique)

    # measure FLOPs:
    modify_forward(model, complexity_computer, patch_attr=patch_attr)
    try:
        # compute complexity in eval mode
        with eval_model(model), torch.no_grad():
            model.forward(input)
    except NotImplementedError as err:
        raise err
    finally:
        restore_forward(model, patch_attr=patch_attr)

    return complexity_computer.count
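The ComplexityComputer class is not shown here. A plausible minimal version, assuming modify_forward calls a compute method per module and that compute_unique deduplicates by module identity (a sketch, not the library's implementation):

class ComplexityComputer:
    def __init__(self, compute_fn, compute_unique=False):
        self.compute_fn = compute_fn
        self.compute_unique = compute_unique
        self.count = 0
        self._counted = set()

    def compute(self, module, module_input):
        # when compute_unique is set, charge each module at most once
        if self.compute_unique and id(module) in self._counted:
            return
        self._counted.add(id(module))
        self.count += self.compute_fn(module, module_input)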
Example #5
    def save_torchscript(self, task) -> None:
        model = task.base_model
        input_shape = (model.input_shape
                       if hasattr(model, "input_shape") else None)
        if not input_shape:
            logging.warning("This model doesn't implement input_shape. "
                            "Cannot save torchscripted model.")
            return
        input_data = get_model_dummy_input(
            model,
            input_shape,
            input_key=model.input_key if hasattr(model, "input_key") else None,
        )
        with eval_model(model), torch.no_grad():
            torchscript = torch.jit.trace(model, input_data)

        # save torchscript:
        logging.info("Saving torchscript to '{}'...".format(
            self.torchscript_folder))
        torchscript_name = f"{self.torchscript_folder}/{TORCHSCRIPT_FILE}"
        with PathManager.open(torchscript_name, "wb") as f:
            torch.jit.save(torchscript, f)
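The saved archive can later be reloaded without the original Python class definition, for example:

# reload the torchscript archive written above
with PathManager.open(torchscript_name, "rb") as f:
    loaded_model = torch.jit.load(f)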
Example #6
def profile(
    model: nn.Module,
    batchsize_per_replica: int = 32,
    input_shape: Tuple[int] = (3, 224, 224),
    use_nvprof: bool = False,
    input_key: Optional[Union[str, List[str]]] = None,
):
    """
    Performs CPU or GPU profiling of the specified model on the specified input.
    """
    # assertions:
    if use_nvprof:
        raise ClassyProfilerError("Profiling not supported with nvprof")
        # FIXME (mannatsingh): in case of use_nvprof, exit() is called at the end
        # and we do not return a profile.
        assert is_on_gpu(model), "can only nvprof model that lives on GPU"
        logging.info("CUDA profiling: Make sure you are running under nvprof!")

    # input for model:
    input = get_model_dummy_input(
        model,
        input_shape,
        input_key,
        batchsize=batchsize_per_replica,
        non_blocking=False,
    )
    # perform profiling in eval mode
    with eval_model(model), torch.no_grad():
        model(input)  # warm up CUDA memory allocator and profiler
        if use_nvprof:  # nvprof profiling (TODO: Can we infer this?)
            cudart().cudaProfilerStart()
            model(input)
            cudart().cudaProfilerStop()
            exit()  # exit gracefully
        else:  # regular profiling
            with torch.autograd.profiler.profile(use_cuda=True) as profiler:
                model(input)
                return profiler
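The returned torch.autograd.profiler.profile object supports the standard summary API; a usage sketch, where my_model is any nn.Module that accepts the default (3, 224, 224) input:

profiler = profile(my_model, batchsize_per_replica=8)
print(profiler.key_averages().table(sort_by="cpu_time_total"))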
Example #7
    def torchscript_using_script(self, model):
        # script the model in eval mode with autograd disabled
        with eval_model(model), torch.no_grad():
            torchscript = torch.jit.script(model)
        return torchscript
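Unlike tracing, torch.jit.script compiles the model's Python source, so data-dependent control flow is preserved. A hedged usage sketch, where exporter stands for whatever object defines this method:

scripted = exporter.torchscript_using_script(model)
scripted.save("model_scripted.pt")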