from pathlib import Path

import torch

from catalyst.dl import utils
from catalyst.dl.utils.trace import trace_model_from_checkpoint


def main(args, _):
    """Main method for ``catalyst-dl trace``."""
    logdir: Path = args.logdir
    method_name: str = args.method
    checkpoint_name: str = args.checkpoint
    mode: str = args.mode
    requires_grad: bool = args.with_grad
    opt_level: str = args.opt_level

    # Apex FP16 tracing requires CUDA; otherwise trace on CPU
    device = "cuda" if opt_level is not None else "cpu"

    traced = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if args.out_model is None:
        file_name = utils.get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )

        output: Path = args.out_dir
        if output is None:
            output = logdir / "trace"
        output.mkdir(exist_ok=True, parents=True)

        out_model = str(output / file_name)
    else:
        out_model = str(args.out_model)

    torch.jit.save(traced, out_model)
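# Usage sketch (hypothetical values): ``main`` is normally invoked through the
# ``catalyst-dl trace`` entrypoint, but it can be exercised directly with an
# ``argparse.Namespace`` carrying the attributes the function reads. The field
# values below are illustrative assumptions, not CLI defaults.
from argparse import Namespace
from pathlib import Path


def _trace_example():
    args = Namespace(
        logdir=Path("./logs"),
        method="forward",
        checkpoint="best",
        mode="eval",
        with_grad=False,
        opt_level=None,  # e.g. "O1" to trace with Apex FP16 on CUDA
        stage=None,
        loader=None,
        out_model=None,  # None -> name built by utils.get_trace_name
        out_dir=None,  # None -> saved under <logdir>/trace/
    )
    main(args, None)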
import collections
import shutil
from pathlib import Path

from torch import nn
from torch.optim import Adam

from catalyst.dl import (
    CriterionCallback,
    OptimizerCallback,
    SupervisedRunner,
    TracerCallback,
)
from catalyst.dl.utils import get_trace_name

# ``_get_loaders``, ``_TracedNet`` and ``_OnStageEndCheckModelTracedCallback``
# are helpers defined elsewhere in this test module.


def test_tracer_callback():
    """Tests a feature of ``TracerCallback`` for model tracing during training."""
    logdir = "./logs"
    dataset_root = "./dataset"
    loaders = _get_loaders(root=dataset_root, batch_size=4, num_workers=1)
    images, targets = next(iter(loaders["train"]))
    _, c, h, w = images.shape
    input_shape = (c, h, w)

    model = _TracedNet(input_shape)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    method_name = "forward"
    mode = "eval"
    requires_grad = False
    checkpoint_name = "best"
    opt_level = None

    trace_name = get_trace_name(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        additional_string=checkpoint_name,
    )
    tracing_path = Path(logdir) / "trace" / trace_name

    criterion_callback = CriterionCallback()
    optimizer_callback = OptimizerCallback()
    tracer_callback = TracerCallback(
        metric="loss",
        minimize=False,
        trace_mode=mode,
        mode=checkpoint_name,
        do_once=True,
        method_name=method_name,
        requires_grad=requires_grad,
        opt_level=opt_level,
    )
    test_callback = _OnStageEndCheckModelTracedCallback(
        path=tracing_path,
        inputs=images,
    )

    callbacks = collections.OrderedDict(
        loss=criterion_callback,
        optimizer=optimizer_callback,
        tracer_callback=tracer_callback,
        test_callback=test_callback,
    )

    runner = SupervisedRunner(input_key="x")
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir=logdir,
        callbacks=callbacks,
        check=True,
        verbose=True,
    )

    shutil.rmtree(logdir)
    shutil.rmtree(dataset_root)
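# Sketch (assumed behavior): the traced artifact written by ``TracerCallback``
# is a regular TorchScript module, so a check like the one performed by
# ``_OnStageEndCheckModelTracedCallback`` can reload it with ``torch.jit.load``
# and run it without the original Python class. The helper below is
# hypothetical, not part of the test suite.
import torch


def _check_traced_model(tracing_path, images):
    traced_model = torch.jit.load(str(tracing_path))
    traced_model.eval()
    with torch.no_grad():
        outputs = traced_model(images)
    assert outputs.shape[0] == images.shape[0]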
# NOTE: ``trace`` is a method of the Runner class; it relies on module-level
# imports in the runner's source file (``Path``, ``torch``, ``utils``, and the
# ``Model``/``Device``/``ScriptModule`` type aliases; import paths may vary by
# Catalyst version).
def trace(
    self,
    *,
    model: Model = None,
    batch: Any = None,
    logdir: str = None,
    loader: DataLoader = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    fp16: Union[Dict, bool] = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using Torch Jit.

    Args:
        model (Model): model to trace
        batch: batch to forward through the model to trace
        logdir (str, optional): if specified, the result will be written
            to the directory
        loader (DataLoader, optional): if batch is not specified, the batch
            will be ``next(iter(loader))``
        method_name (str): model's method name that will be traced
        mode (str): ``train`` or ``eval``
        requires_grad (bool): flag to trace with gradients
        fp16 (Union[Dict, bool]): if not None, sets tracing params to FP16
        device (Device): Torch device or a string
        predict_params (dict): additional parameters for model forward
    """
    if batch is None:
        if loader is None:
            raise ValueError(
                "If batch is not provided, the loader must be specified"
            )
        batch = next(iter(loader))

    if model is not None:
        self.model = model
    assert self.model is not None

    # Normalize the ``fp16`` argument into an Apex ``opt_level`` string
    if isinstance(fp16, bool) and fp16:
        opt_level = "O1"
    elif isinstance(fp16, bool) and not fp16:
        opt_level = None
    elif isinstance(fp16, dict):
        opt_level = fp16["opt_level"]
    else:
        opt_level = fp16

    # FP16 tracing runs on CUDA; otherwise fall back to the runner's device
    if opt_level is not None:
        device = "cuda"
    elif device is None:
        if self.device is None:
            self.device = utils.get_device()
        device = self.device

    result = utils.trace_model(
        model=self.model,
        runner=self,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
        predict_params=predict_params,
    )

    if logdir is not None:
        filename = utils.get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
        )

        logdir = Path(logdir)
        output: Path = logdir / "trace"
        output.mkdir(exist_ok=True, parents=True)

        out_model = str(output / filename)
        torch.jit.save(result, out_model)

    return result
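# Usage sketch (illustrative model and batch): tracing a model through
# ``Runner.trace`` and saving the TorchScript artifact under
# ``<logdir>/trace/``. The dict batch keyed by ``input_key`` is an assumption
# about how ``SupervisedRunner`` consumes batches.
import torch
from torch import nn

from catalyst.dl import SupervisedRunner

runner = SupervisedRunner(input_key="x")
model = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
batch = {"x": torch.randn(4, 28 * 28)}

traced = runner.trace(
    model=model,
    batch=batch,
    logdir="./logs",  # saved to ./logs/trace/<generated name>
    method_name="forward",
    mode="eval",
    requires_grad=False,
    fp16=False,  # True or {"opt_level": "O1"} enables Apex FP16 on CUDA
    device="cpu",
)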