def main():
    """Parse the baseline config, prepare the log directory and start training.

    Side effects visible here:
      * the directory returned by ``get_dir`` (applied to
        ``'../log/baseline_' + params.task`` joined with ``params.output_dir``)
        is created if it does not exist yet;
      * when ``params.write_log`` is set, a file handler is attached to
        ``logger`` for that directory;
      * all parsed arguments are written to ``args.json`` in that directory.

    The policy class passed to ``train`` depends on ``params.separate_train``:
    ``sparse_ppo_policy.SparsePPOPolicy`` when true, otherwise
    ``consolidated_ppo_policy.ConsolidatedPPOPolicy``.
    """
    import json

    parser = base_config.get_base_config()
    params = base_config.make_parser(parser)

    # Named ``log_dir`` rather than ``dir`` so the builtin is not shadowed.
    log_dir = osp.join('../log/baseline_' + params.task, params.output_dir)
    log_dir = get_dir(log_dir)
    # Avoid the exists()/makedirs() race: another process may create the
    # directory in between.  Only propagate the error if, afterwards, the
    # path still is not a usable directory.
    try:
        os.makedirs(log_dir)
    except OSError:
        if not osp.isdir(log_dir):
            raise

    if params.write_log:
        logger.set_file_handler(log_dir, time_str=params.exp_id)

    # Persist the full argument namespace next to the logs for reproducibility.
    with open(osp.join(log_dir, 'args.json'), 'w') as f:
        json.dump(vars(params), f)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))

    # The two training modes differ only in the policy class handed to train().
    if params.separate_train:
        policy_cls = sparse_ppo_policy.SparsePPOPolicy
    else:
        policy_cls = consolidated_ppo_policy.ConsolidatedPPOPolicy
    train(trainer.Trainer, ppo_runner, base_worker, policy_cls,
          ppo_policy.PPOPolicy, params)
def main():
    """Deprecated DQN-transfer entry point.

    Builds the base, ECCO and DQN-transfer argument parsers, optionally
    attaches a log file handler, prints a deprecation notice, and then calls
    ``train`` with three models:
      * ``'final'``: the ECCO pretrain model;
      * ``'transfer'``: the ECCO transfer model;
      * ``'base'``: the model of the policy named by ``base_policy``.
    ``base_policy`` must be ``'dqn'`` or ``'a2c'``; any other value raises
    ``KeyError``. No pretrained weights are supplied (``None``).
    """
    arg_parser = base_config.get_base_config()
    for extend in (ecco_config.get_ecco_config,
                   dqn_transfer_config.get_dqn_transfer_config):
        arg_parser = extend(arg_parser)
    opts = base_config.make_parser(arg_parser)

    if opts.write_log:
        logger.set_file_handler(path=opts.output_dir,
                                prefix='ecco_ecco' + opts.task,
                                time_str=opts.exp_id)
    print('DQN_TRANSFER_MAIN.PY is Deprecated, do not use')
    print('Training starts at {}'.format(init_path.get_abs_base_dir()))

    # Project modules are imported only after the arguments are parsed.
    from trainer import dqn_transfer_trainer
    from runners import dqn_transfer_task_sampler
    from runners.workers import dqn_transfer_worker
    from policy import ecco_pretrain, dqn_base, a2c_base, ecco_transfer

    base_policies = {'dqn': dqn_base, 'a2c': a2c_base}
    base_module = base_policies[opts.base_policy]  # KeyError if unknown
    model_fns = dict(final=ecco_pretrain.model,
                     transfer=ecco_transfer.model,
                     base=base_module.model)

    no_pretrain_weights = None
    train(dqn_transfer_trainer.trainer, dqn_transfer_task_sampler,
          dqn_transfer_worker, model_fns, opts, no_pretrain_weights)
def main():
    """Launch ECCO pretraining.

    Builds the base + ECCO argument parser, optionally attaches a log file
    handler, reports the base directory, and then calls ``train`` with the
    ECCO trainer, the generic task sampler and worker, and the
    ``ecco_pretrain`` model.
    """
    opts = base_config.make_parser(
        ecco_config.get_ecco_config(base_config.get_base_config()))

    if opts.write_log:
        logger.set_file_handler(path=opts.output_dir,
                                prefix='ecco_ecco' + opts.task,
                                time_str=opts.exp_id)
    base_dir = init_path.get_abs_base_dir()
    print('Training starts at {}'.format(base_dir))

    # Project modules are imported only after the arguments are parsed.
    from trainer import ecco_trainer
    from runners import task_sampler
    from runners.workers import base_worker
    from policy import ecco_pretrain

    train(ecco_trainer.trainer, task_sampler, base_worker,
          ecco_pretrain.model, opts)
def main():
    """Run DQN-transfer training, optionally with a pre-loaded environment cache.

    Builds the base, ECCO and DQN-transfer argument parsers and optionally
    attaches a log file handler. It then assembles three models:
      * ``'final'``: the ECCO pretrain model;
      * ``'transfer'``: the ECCO transfer model;
      * ``'base'``: the model of the ``'dqn'`` or ``'a2c'`` policy named by
        ``base_policy`` (any other value raises ``KeyError``).

    When ``load_environments`` is not ``None``, environments are loaded via
    ``env.env_utils.load_environments``; otherwise the cache is ``None``.
    ``train`` receives the pretrain hooks as a dict with ``'pretrain_fnc'``
    and ``'pretrain_thread'`` keys.

    NOTE(review): ``pretrain`` is not defined in this function; it is
    presumably a module-level callable elsewhere in this file. Confirm.
    """
    arg_parser = base_config.get_base_config()
    for extend in (ecco_config.get_ecco_config,
                   dqn_transfer_config.get_dqn_transfer_config):
        arg_parser = extend(arg_parser)
    opts = base_config.make_parser(arg_parser)

    if opts.write_log:
        logger.set_file_handler(path=opts.output_dir,
                                prefix='ecco_ecco' + opts.task,
                                time_str=opts.exp_id)

    # Project modules are imported only after the arguments are parsed.
    from trainer import dqn_transfer_trainer, dqn_transfer_jwt
    from runners import dqn_transfer_task_sampler
    from runners.workers import dqn_transfer_worker
    from policy import ecco_pretrain, dqn_base, a2c_base, ecco_transfer

    base_module = {'dqn': dqn_base, 'a2c': a2c_base}[opts.base_policy]
    model_fns = dict(final=ecco_pretrain.model,
                     transfer=ecco_transfer.model,
                     base=base_module.model)

    from env.env_utils import load_environments
    environments_cache = (
        None if opts.load_environments is None
        else load_environments(opts.load_environments, opts.num_cache,
                               opts.task, opts.episode_length, opts.seed))

    pretrain_hooks = {'pretrain_fnc': pretrain,
                      'pretrain_thread': dqn_transfer_jwt}
    train(dqn_transfer_trainer.trainer, dqn_transfer_task_sampler,
          dqn_transfer_worker, model_fns, opts, pretrain_hooks,
          environments_cache)