示例#1
0
def main(args, unknown_args):
    """Build an RL trainer from command-line arguments and a config, then run it.

    Steps:
      1. Parse args/config, seed, and configure cuDNN.
      2. Optionally create ``args.logdir`` and import ``args.expdir``.
      3. Build the environment.
      4. Select the off-policy or on-policy algorithm registry and trainer
         class from ``config["algorithm"]["algorithm"]``.
      5. Build the DB server, prepare the algorithm (optionally restoring
         a checkpoint from ``args.resume``), and run the trainer.

    Args:
        args: command-line namespace. Together with ``unknown_args`` it is
            passed to ``parse_args_uargs``. The returned args supply
            ``seed``, ``deterministic``, ``benchmark``, ``logdir``,
            ``configs``, ``expdir`` and ``resume``.
        unknown_args: extra command-line arguments, forwarded to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if the algorithm name is in neither
            ``OFFPOLICY_ALGORITHMS_NAMES`` nor ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_environment(config, args.logdir, args.configs)

    if args.expdir is not None:
        # Called only for its side effects; the returned module is not used
        # here (presumably importing it registers user-defined components).
        import_module(expdir=args.expdir)
        if args.logdir is not None:
            dump_code(args.expdir, args.logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # pop (not get): the name is removed from config["algorithm"] before
    # config is handed to prepare_for_trainer below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        algorithms = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        algorithms = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            f"Unsupported algorithm {algorithm_name!r}: it is neither an "
            "off-policy nor an on-policy algorithm"
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = algorithms.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        checkpoint = utils.load_checkpoint(filepath=args.resume)
        checkpoint = utils.any2device(checkpoint, utils.get_device())
        # with_optimizer=False: presumably restores model state only,
        # leaving optimizer state fresh -- confirm against unpack_checkpoint.
        algorithm.unpack_checkpoint(
            checkpoint=checkpoint,
            with_optimizer=False
        )

    monitoring_params = config.get("monitoring_params")

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        monitoring_params=monitoring_params,
        **config["trainer"],
    )

    trainer.run()
示例#2
0
    def work(self):
        """Run this task's experiment and report which stages it covered.

        Steps:
          1. Seed RNGs and load the experiment and runner classes from
             ``args.expdir``.
          2. Wrap ``experiment.get_callbacks`` so that worker callbacks are
             merged in.
          3. For distributed workers, record resume info on the task and
             keep only the first stage's config.
          4. Run the experiment. On the master, when ``self.trace`` is set,
             save a TorchScript-traced model to that path.

        Returns:
            dict: ``'stage'`` is the last entry of ``experiment.stages``
            after the run. ``'stages'`` is the copy of ``experiment.stages``
            taken before any distributed filtering.

        Raises:
            The exception stored in ``runner.state.exception``, if any.
        """
        args, config = self.parse_args_uargs()
        set_global_seed(args.seed)

        experiment_cls, runner_cls = import_experiment_and_runner(
            Path(args.expdir)
        )

        # pop (not get): 'runner_params' is removed from config before the
        # config is handed to the experiment class.
        runner_kwargs = config.pop('runner_params', {})

        experiment = experiment_cls(config)
        runner: Runner = runner_cls(**runner_kwargs)

        self.experiment = experiment
        self.runner = runner

        # Snapshot taken before the distributed branch may narrow
        # experiment.stages_config.
        stages = experiment.stages[:]

        if self.task.parent:
            self.parent = self.task_provider.by_id(self.task.parent)

        if self.master:
            parent_task = self.get_parent_task()
            parent_task.steps = len(stages)
            self.task_provider.commit()

        self._checkpoint_fix_config(experiment)
        self._fix_memory(experiment)

        experiment_get_callbacks = experiment.get_callbacks

        def get_callbacks(stage):
            # Start from this worker's callbacks. For the same key, the
            # experiment's callback for the stage replaces the worker's.
            merged = self.callbacks()
            for key, callback in experiment_get_callbacks(stage).items():
                merged[key] = callback

            self._checkpoint_fix_callback(merged)
            return merged

        experiment.get_callbacks = get_callbacks

        if experiment.logdir is not None:
            dump_environment(config, experiment.logdir, args.configs)

        if self.distr_info:
            info = yaml_load(self.task.additional_info)
            # master_task_id = own id minus rank; presumably the rank-0
            # task id, assuming a distributed group has consecutive ids.
            resume = dict(
                master_computer=self.distr_info['master_computer'],
                master_task_id=self.task.id - self.distr_info['rank'],
                load_best=True,
            )
            info['resume'] = resume
            self.task.additional_info = yaml_dump(info)
            self.task_provider.commit()

            # Keep only the config entry of the first stage (presumably so
            # this worker runs just that stage).
            experiment.stages_config = {
                name: stage_config
                for name, stage_config in experiment.stages_config.items()
                if name == experiment.stages[0]
            }

        runner.run_experiment(experiment)
        if runner.state.exception:
            raise runner.state.exception

        if self.master and self.trace:
            traced_model = trace_model_from_checkpoint(
                self.experiment.logdir, self
            )
            torch.jit.save(traced_model, self.trace)

        last_stage = experiment.stages[-1]
        return {'stage': last_stage, 'stages': stages}