def test_train_model_eval_model(self):
    """
    Check that ``util.train_model`` / ``util.eval_model`` force a single mode
    on every submodule while active and restore the original (mixed) modes
    once the context manager exits.
    """

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.dropout = nn.Dropout()
            self.seq = nn.Sequential(
                nn.ReLU(), nn.Conv2d(1, 2, 3), nn.BatchNorm2d(1, 2)
            )

    test_model = TestModel()
    # (context manager under test, mode it must force inside the block)
    scenarios = ((util.train_model, True), (util.eval_model, False))

    for train in (True, False):
        test_model.train(train)
        # flip some of the modes so the model is in a mixed train/eval state
        for flipped in (test_model.dropout, test_model.seq[1]):
            flipped.train(not train)
        snapshot = copy.deepcopy(test_model)

        for mode_context, forced_mode in scenarios:
            with mode_context(test_model):
                self._check_model_train_mode(test_model, forced_mode)
                # the modes should be different inside the context manager
                modes_match = self._compare_model_train_mode(snapshot, test_model)
                self.assertFalse(modes_match)
            # the mixed modes must be back once the context manager has exited
            modes_match = self._compare_model_train_mode(snapshot, test_model)
            self.assertTrue(modes_match)
def compute_complexity(model, compute_fn, input_shape, input_key=None):
    """
    Compute the complexity of a forward pass.

    Args:
        model: The ``nn.Module`` to measure.
        compute_fn: Forwarded to ``modify_forward`` together with the list
            that collects the measurements (presumably the per-module cost
            function -- confirm against ``modify_forward``).
        input_shape: A sequence or dict handed to ``get_model_dummy_input``
            to build the dummy input.
        input_key: Forwarded to ``get_model_dummy_input``.

    Returns:
        The sum of the values collected during one forward pass, or ``None``
        when ``input_shape`` is neither a sequence nor a dict.

    Raises:
        AssertionError: If ``model`` is not an ``nn.Module``.
    """
    assert isinstance(model, nn.Module)
    if not isinstance(input_shape, (abc.Sequence, dict)):
        return None

    dummy_input = get_model_dummy_input(model, input_shape, input_key)
    compute_list = []

    # measure FLOPs:
    modify_forward(model, compute_list, compute_fn)
    try:
        # compute complexity in eval mode
        with eval_model(model), torch.no_grad():
            model.forward(dummy_input)
    finally:
        # always undo the forward modification; any exception (including
        # NotImplementedError) propagates to the caller unchanged
        restore_forward(model)

    return sum(compute_list)
def torchscript_using_trace(self, model):
    """
    Trace ``model`` into TorchScript using a dummy input.

    The dummy input is built by ``get_model_dummy_input`` from the model's
    ``input_shape`` (and ``input_key``, when the model defines one). Tracing
    runs with the model in eval mode and with gradients disabled.

    Returns:
        The traced module, or ``None`` (after logging a warning) when the
        model has no ``input_shape`` or it is empty.
    """
    input_shape = getattr(model, "input_shape", None)
    if not input_shape:
        logging.warning(
            "This model doesn't implement input_shape. "
            "Cannot save torchscripted model."
        )
        return None

    input_data = get_model_dummy_input(
        model,
        input_shape,
        input_key=getattr(model, "input_key", None),
    )
    # Both context managers must be listed with a comma: `a and b` evaluates
    # to `b` alone, so eval_model() would never be entered.
    with eval_model(model), torch.no_grad():
        torchscript = torch.jit.trace(model, input_data)
    return torchscript
def compute_complexity(
    model,
    compute_fn,
    input_shape,
    input_key=None,
    patch_attr=None,
    compute_unique=False,
):
    """
    Compute the complexity of a forward pass.

    Args:
        model: The ``nn.Module`` to measure.
        compute_fn: Passed to ``ComplexityComputer`` (presumably the
            per-module cost function -- confirm against that class).
        input_shape: A sequence or dict handed to ``get_model_dummy_input``
            to build the dummy input.
        input_key: Forwarded to ``get_model_dummy_input``.
        patch_attr: Forwarded to ``modify_forward`` and ``restore_forward``.
        compute_unique: If True, the complexity for a given module is only
            calculated once. Otherwise, it is counted every time the module
            is called.

    Returns:
        ``complexity_computer.count`` after one forward pass, or ``None``
        when ``input_shape`` is neither a sequence nor a dict.

    Raises:
        AssertionError: If ``model`` is not an ``nn.Module``.

    TODO(@mannatsingh): We have some assumptions about only modules which are
    leaves or have patch_attr defined. This should be fixed and generalized
    if possible.
    """
    assert isinstance(model, nn.Module)
    if not isinstance(input_shape, (abc.Sequence, dict)):
        return None

    dummy_input = get_model_dummy_input(model, input_shape, input_key)
    complexity_computer = ComplexityComputer(compute_fn, compute_unique)

    # measure FLOPs:
    modify_forward(model, complexity_computer, patch_attr=patch_attr)
    try:
        # compute complexity in eval mode
        with eval_model(model), torch.no_grad():
            model.forward(dummy_input)
    finally:
        # always undo the forward modification; any exception (including
        # NotImplementedError) propagates to the caller unchanged
        restore_forward(model, patch_attr=patch_attr)

    return complexity_computer.count
def save_torchscript(self, task) -> None:
    """
    Trace ``task.base_model`` and save the result as a TorchScript file.

    The dummy input for tracing is built by ``get_model_dummy_input`` from
    the model's ``input_shape`` (and ``input_key``, when defined). Tracing
    runs with the model in eval mode and with gradients disabled. The traced
    module is written to ``{self.torchscript_folder}/{TORCHSCRIPT_FILE}``.

    If the model has no ``input_shape`` (or it is empty), a warning is
    logged and nothing is saved.
    """
    model = task.base_model
    input_shape = getattr(model, "input_shape", None)
    if not input_shape:
        logging.warning(
            "This model doesn't implement input_shape. "
            "Cannot save torchscripted model."
        )
        return

    input_data = get_model_dummy_input(
        model,
        input_shape,
        input_key=getattr(model, "input_key", None),
    )
    # Both context managers must be listed with a comma: `a and b` evaluates
    # to `b` alone, so eval_model() would never be entered.
    with eval_model(model), torch.no_grad():
        torchscript = torch.jit.trace(model, input_data)

    # save torchscript:
    logging.info("Saving torchscript to '{}'...".format(self.torchscript_folder))
    torchscript_name = f"{self.torchscript_folder}/{TORCHSCRIPT_FILE}"
    with PathManager.open(torchscript_name, "wb") as f:
        torch.jit.save(torchscript, f)
def profile(
    model: nn.Module,
    batchsize_per_replica: int = 32,
    input_shape: Tuple[int, ...] = (3, 224, 224),
    use_nvprof: bool = False,
    input_key: Optional[Union[str, List[str]]] = None,
):
    """
    Performs CPU or GPU profiling of the specified model on the specified input.

    Args:
        model: The module to profile.
        batchsize_per_replica: Passed to ``get_model_dummy_input`` as
            ``batchsize``.
        input_shape: Shape handed to ``get_model_dummy_input``.
        use_nvprof: Must be False; nvprof profiling is currently rejected.
        input_key: Forwarded to ``get_model_dummy_input``.

    Returns:
        The ``torch.autograd.profiler.profile`` object recorded (with
        ``use_cuda=True``) around one forward pass. A warm-up forward pass
        is run first and is not part of the profile.

    Raises:
        ClassyProfilerError: If ``use_nvprof`` is True.
    """
    # FIXME (mannatsingh): nvprof profiling is disabled because that code
    # path called exit() at the end and never returned a profile. The
    # unreachable nvprof statements that used to follow this raise (GPU
    # assertion, cudaProfilerStart/Stop, exit()) have been dropped.
    if use_nvprof:
        raise ClassyProfilerError("Profiling not supported with nvprof")

    # input for model:
    dummy_input = get_model_dummy_input(
        model,
        input_shape,
        input_key,
        batchsize=batchsize_per_replica,
        non_blocking=False,
    )

    # perform profiling in eval mode
    with eval_model(model), torch.no_grad():
        model(dummy_input)  # warm up CUDA memory allocator and profiler
        with torch.autograd.profiler.profile(use_cuda=True) as profiler:
            model(dummy_input)
    return profiler
def torchscript_using_script(self, model):
    """
    Compile ``model`` to TorchScript with ``torch.jit.script``.

    Scripting runs with the model in eval mode and with gradients disabled.

    Returns:
        The scripted module returned by ``torch.jit.script``.
    """
    # Both context managers must be listed with a comma: `a and b` evaluates
    # to `b` alone, so eval_model() would never be entered.
    with eval_model(model), torch.no_grad():
        torchscript = torch.jit.script(model)
    return torchscript