def main_worker(args, unknown_args):
    """Run one worker of the ``catalyst-dl run`` script.

    Parses arguments and config, seeds RNGs, configures cuDNN, imports the
    ``Experiment`` / ``Runner`` classes from ``args.expdir`` and runs the
    experiment.

    Args:
        args: parsed command-line arguments (``seed``, ``deterministic``,
            ``benchmark``, ``apex``, ``expdir`` and ``configs`` are read)
        unknown_args: leftover command-line arguments, forwarded to
            ``utils.parse_args_uargs`` together with ``args``
    """
    args, config = utils.parse_args_uargs(args, unknown_args)
    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    # Propagate the apex flag into the distributed settings. A key that is
    # present but null is treated like a missing one instead of crashing.
    distributed_params = config.get("distributed_params") or {}
    distributed_params["apex"] = args.apex
    config["distributed_params"] = distributed_params

    Experiment, Runner = utils.import_experiment_and_runner(Path(args.expdir))

    # ``runner_params`` may be present but null; ``Runner(**None)`` would fail.
    runner_params = config.get("runner_params", {}) or {}

    experiment = Experiment(config)
    runner = Runner(**runner_params)

    # Only dump environment/code when get_rank() <= 0 -- presumably the
    # master process or a non-distributed run (TODO confirm get_rank contract).
    if experiment.logdir is not None and get_rank() <= 0:
        utils.dump_environment(config, experiment.logdir, args.configs)
        utils.dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment)
def main(args, unknown_args):
    """Run the ``catalyst-dl run`` script

    Steps: parse args + config, seed and configure cuDNN, import the
    user-defined experiment and runner classes from ``args.expdir``,
    optionally dump environment and code into the experiment logdir, then
    run the experiment (in check mode when ``config["args"]["check"]`` is
    truthy).

    Args:
        args: parsed command-line arguments
        unknown_args: leftover command-line arguments, forwarded to
            ``utils.parse_args_uargs`` together with ``args``
    """
    args, config = utils.parse_args_uargs(args, unknown_args)

    # Reproducibility / backend setup.
    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    experiment_cls, runner_cls = utils.import_experiment_and_runner(
        Path(args.expdir)
    )

    # runner_params is removed from the config before the experiment is
    # built; a null value is normalised to an empty dict.
    runner_kwargs = config.pop("runner_params", {}) or {}

    experiment = experiment_cls(config)
    runner = runner_cls(**runner_kwargs)

    if experiment.logdir is not None:
        utils.dump_environment(config, experiment.logdir, args.configs)
        utils.dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(
        experiment,
        check=safitty.get(config, "args", "check", default=False),
    )
def main_worker(args, unknown_args):
    """Run one worker of the ``catalyst-dl run`` script.

    Parses arguments and config, seeds RNGs, configures cuDNN and imports
    the experiment / runner from ``args.expdir``. If the expdir provides no
    experiment class, one is looked up in the ``EXPERIMENTS`` registry by the
    name in ``config["experiment_params"]["experiment"]`` (default
    ``"Experiment"``). Then the experiment is run.

    Args:
        args: parsed command-line arguments (``seed``, ``deterministic``,
            ``benchmark``, ``apex``, ``expdir`` and ``configs`` are read)
        unknown_args: leftover command-line arguments, forwarded to
            ``utils.parse_args_uargs`` together with ``args``

    Raises:
        ValueError: if the experiment must come from the registry and the
            requested name resolves to ``None``.
    """
    args, config = utils.parse_args_uargs(args, unknown_args)
    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    # Propagate the apex flag; a present-but-null section is treated as empty.
    distributed_params = config.get("distributed_params") or {}
    distributed_params["apex"] = args.apex
    config["distributed_params"] = distributed_params

    experiment_fn, runner_fn = utils.import_experiment_and_runner(
        Path(args.expdir)
    )
    if experiment_fn is None:
        # expdir did not define an experiment: fall back to the registry.
        experiment_params = config.get("experiment_params", {}) or {}
        experiment_name = experiment_params.get("experiment", "Experiment")
        experiment_fn = EXPERIMENTS.get(experiment_name)
        if experiment_fn is None:
            # Fail with a clear message instead of "NoneType is not callable".
            raise ValueError(
                f"Experiment '{experiment_name}' was not found in the "
                f"EXPERIMENTS registry and expdir '{args.expdir}' does not "
                f"define one"
            )

    # ``runner_params`` may be present but null; ``runner_fn(**None)`` fails.
    runner_params = config.get("runner_params", {}) or {}

    experiment = experiment_fn(config)
    runner = runner_fn(**runner_params)

    # Only dump environment/code when get_rank() <= 0 -- presumably the
    # master process or a non-distributed run (TODO confirm get_rank contract).
    if experiment.logdir is not None and get_rank() <= 0:
        utils.dump_environment(config, experiment.logdir, args.configs)
        utils.dump_code(args.expdir, experiment.logdir)

    runner.run_experiment(experiment)
def trace_model_from_checkpoint(
    logdir: Union[str, Path],
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """Traces model using created experiment and runner.

    The config is read from ``<logdir>/configs/_config.json``. The weights
    come from ``<logdir>/checkpoints/<checkpoint_name>.pth``. The experiment
    and runner code is imported from the copy stored under
    ``<logdir>/code/``.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        method_name (str): Model's method name that will be
            used as entrypoint during tracing
        checkpoint_name (str): Name of model checkpoint to use
        stage (str): experiment's stage name; defaults to the first stage
            of the experiment
        loader (Union[str, int]): experiment's loader name or its index;
            defaults to index ``0``
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    # Accept plain strings as documented: the ``/`` joins below need a Path.
    logdir = Path(logdir)
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"

    print("Load config")
    config: Dict[str, dict] = utils.load_config(config_path)
    runner_params = config.get("runner_params", {}) or {}

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = logdir / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = utils.import_experiment_and_runner(expdir)
    experiment: Experiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = utils.load_checkpoint(checkpoint_path)
    utils.unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType(**runner_params)
    runner.model, runner.device = model, device

    if loader is None:
        loader = 0
    batch = experiment.get_native_batch(stage, loader)

    print("Tracing")
    traced = trace.trace_model(
        model=model,
        runner=runner,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced