def train_smac(ctxt=None, args_dict=vars(args)):
    args = SimpleNamespace(**args_dict)

    env = SMACWrapper(
        centralized=True,  # important: uses the centralized sampler
        map_name=args.map,
        difficulty=args.difficulty,
        # seed=args.seed
    )
    env = GarageEnv(env)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env.pickleable
    )

    hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
        else torch.tanh

    policy = DecCategoricalLSTMPolicy(
        env.spec,
        n_agents=env.n_agents,
        encoder_hidden_sizes=args.encoder_hidden_sizes,
        embedding_dim=args.embedding_dim,  # encoder output size
        lstm_hidden_size=args.lstm_hidden_size,
        state_include_actions=args.state_include_actions,
        name='dec_categorical_lstm_policy')

    baseline = GaussianMLPBaseline(env_spec=env.spec,
                                   hidden_sizes=(64, 64, 64))

    # Set max_path_length <= max_steps.
    # If max_path_length > max_steps, the algo pads observations so that
    # obs.shape == torch.Size([n_paths, algo.max_path_length, feat_dim]).
    algo = CentralizedMAPPO(
        env_spec=env.spec,
        policy=policy,
        baseline=baseline,
        max_path_length=env.episode_limit,  # see note above
        discount=args.discount,
        center_adv=bool(args.center_adv),
        positive_adv=bool(args.positive_adv),
        gae_lambda=args.gae_lambda,
        policy_ent_coeff=args.ent,
        entropy_method=args.entropy_method,
        stop_entropy_gradient=(args.entropy_method == 'max'),
        clip_grad_norm=args.clip_grad_norm,
        optimization_n_minibatches=args.opt_n_minibatches,
        optimization_mini_epochs=args.opt_mini_epochs,
    )

    runner.setup(algo, env,
                 sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                 sampler_args={'n_envs': args.n_envs})

    runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
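# --- Launch sketch (assumption, not part of the original runners) ---
# train_smac is written as a garage experiment function: a module-level
# `args` namespace from argparse supplies its defaults, and garage's
# run_experiment creates the ExperimentContext that arrives here as `ctxt`.
# The call below assumes the pre-2020 garage `run_experiment` API; the
# experiment name is a hypothetical placeholder.
def _launch_smac_example():
    from garage.experiment import run_experiment
    run_experiment(
        train_smac,
        exp_name='smac_dec_lstm_ppo',  # hypothetical name
        snapshot_mode='last',
        seed=args.seed,
    )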
def train_predatorprey(ctxt=None, args_dict=vars(args)):
    args = SimpleNamespace(**args_dict)
    set_seed(args.seed)

    env = PredatorPreyWrapper(
        centralized=True,  # centralized training
        grid_shape=(args.grid_size, args.grid_size),
        n_agents=args.n_agents,
        n_preys=args.n_preys,
        max_steps=args.max_env_steps,
        step_cost=args.step_cost,
        prey_capture_reward=args.capture_reward,
        penalty=args.penalty,
        other_agent_visible=args.agent_visible)
    env = GarageEnv(env)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env.pickleable
    )

    hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
        else torch.tanh

    policy = DecCategoricalMLPPolicy(
        env.spec,
        env.n_agents,
        hidden_nonlinearity=hidden_nonlinearity,
        hidden_sizes=args.hidden_sizes,
        name='dec_categorical_mlp_policy')

    baseline = GaussianMLPBaseline(env_spec=env.spec,
                                   hidden_sizes=(64, 64, 64))

    # Set max_path_length <= max_steps.
    # If max_path_length > max_steps, the algo pads observations so that
    # obs.shape == torch.Size([n_paths, algo.max_path_length, feat_dim]).
    algo = CentralizedMAPPO(
        env_spec=env.spec,
        policy=policy,
        baseline=baseline,
        max_path_length=args.max_env_steps,  # see note above
        discount=args.discount,
        center_adv=bool(args.center_adv),
        positive_adv=bool(args.positive_adv),
        gae_lambda=args.gae_lambda,
        policy_ent_coeff=args.ent,
        entropy_method=args.entropy_method,
        stop_entropy_gradient=(args.entropy_method == 'max'),
        optimization_n_minibatches=args.opt_n_minibatches,
        optimization_mini_epochs=args.opt_mini_epochs,
    )

    runner.setup(algo, env,
                 sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                 sampler_args={'n_envs': args.n_envs})

    runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
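# --- Argument sketch (assumption, not part of the original runners) ---
# Both runners above read a module-level `args` produced by argparse
# (note the `args_dict=vars(args)` default). The flag names below mirror a
# subset of the attributes the functions reference; the defaults are
# placeholders, not the repository's values.
import argparse

_example_parser = argparse.ArgumentParser()  # hypothetical parser
_example_parser.add_argument('--seed', type=int, default=1)
_example_parser.add_argument('--grid_size', type=int, default=10)
_example_parser.add_argument('--n_agents', type=int, default=8)
_example_parser.add_argument('--n_preys', type=int, default=8)
_example_parser.add_argument('--max_env_steps', type=int, default=200)
_example_parser.add_argument('--n_envs', type=int, default=1)
_example_parser.add_argument('--n_epochs', type=int, default=500)
_example_parser.add_argument('--bs', type=int, default=60000)  # batch size in env steps
# ... remaining flags (step_cost, capture_reward, penalty, gae_lambda, ent,
# entropy_method, opt_n_minibatches, opt_mini_epochs, ...) follow the same pattern.
# args = _example_parser.parse_args()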
def restore_training(log_dir, exp_name, args, env_saved=True, env=None):
    tabular_log_file = os.path.join(
        log_dir, 'progress_restored.{}.{}.csv'.format(
            str(time.time())[:10], socket.gethostname()))
    text_log_file = os.path.join(
        log_dir, 'debug_restored.{}.{}.log'.format(
            str(time.time())[:10], socket.gethostname()))
    logger.add_output(dowel.TextOutput(text_log_file))
    logger.add_output(dowel.CsvOutput(tabular_log_file))
    logger.add_output(dowel.TensorBoardOutput(log_dir))
    logger.add_output(dowel.StdOutput())
    logger.push_prefix('[%s] ' % exp_name)

    ctxt = ExperimentContext(snapshot_dir=log_dir,
                             snapshot_mode='last',
                             snapshot_gap=1)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env_saved
    )
    saved = runner._snapshotter.load(log_dir, 'last')
    runner._setup_args = saved['setup_args']
    runner._train_args = saved['train_args']
    runner._stats = saved['stats']

    set_seed(runner._setup_args.seed)

    algo = saved['algo']
    # Compatibility patch for checkpoints saved before clip_grad_norm existed
    if not hasattr(algo, '_clip_grad_norm'):
        setattr(algo, '_clip_grad_norm', args.clip_grad_norm)

    if env_saved:
        env = saved['env']

    runner.setup(env=env,
                 algo=algo,
                 sampler_cls=runner._setup_args.sampler_cls,
                 sampler_args=runner._setup_args.sampler_args)

    runner._train_args.start_epoch = runner._stats.total_epoch + 1
    runner._train_args.n_epochs = runner._train_args.start_epoch + args.n_epochs

    print('\nRestored checkpoint from epoch #{}...'.format(
        runner._train_args.start_epoch))
    print('To be trained for {} additional epochs...'.format(args.n_epochs))
    print('Will finish at epoch #{}...\n'.format(
        runner._train_args.n_epochs))

    return runner._algo.train(runner)
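# --- Restore sketch (assumption, not part of the original runners) ---
# restore_training resumes from the snapshot that garage's Snapshotter wrote
# to log_dir with snapshot_mode='last'. The directory and experiment name
# below are hypothetical placeholders.
def _restore_example(args):
    log_dir = './data/local/experiment/smac_dec_lstm_ppo'  # hypothetical path
    return restore_training(log_dir,
                            exp_name='smac_dec_lstm_ppo',  # hypothetical name
                            args=args,
                            env_saved=True)  # reuse the pickled env from the snapshot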
def train_trafficjunction(ctxt=None, args_dict=vars(args)):
    args = SimpleNamespace(**args_dict)
    set_seed(args.seed)

    if args.curriculum:
        curr_start = int(0.125 * args.n_epochs)
        curr_end = int(0.625 * args.n_epochs)
    else:
        curr_start = 0
        curr_end = 0
        args.add_rate_min = args.add_rate_max

    env = TrafficJunctionWrapper(
        centralized=True,
        dim=args.dim,
        vision=1,
        add_rate_min=args.add_rate_min,
        add_rate_max=args.add_rate_max,
        curr_start=curr_start,
        curr_end=curr_end,
        difficulty=args.difficulty,
        n_agents=args.n_agents,
        max_steps=args.max_env_steps
    )
    env = GarageEnv(env)

    runner = LocalRunnerWrapper(
        ctxt,
        eval=args.eval_during_training,
        n_eval_episodes=args.n_eval_episodes,
        eval_greedy=args.eval_greedy,
        eval_epoch_freq=args.eval_epoch_freq,
        save_env=env.pickleable
    )

    hidden_nonlinearity = F.relu if args.hidden_nonlinearity == 'relu' \
        else torch.tanh

    policy = DecCategoricalMLPPolicy(
        env.spec,
        env.n_agents,
        hidden_nonlinearity=hidden_nonlinearity,
        hidden_sizes=args.policy_hidden_sizes,
        name='dec_categorical_mlp_policy')

    baseline = DICGCritic(
        env.spec,
        env.n_agents,
        encoder_hidden_sizes=args.encoder_hidden_sizes,
        embedding_dim=args.embedding_dim,
        attention_type=args.attention_type,
        n_gcn_layers=args.n_gcn_layers,
        residual=args.residual,
        gcn_bias=args.gcn_bias,
        name='dicg_critic')

    # Set max_path_length <= max_steps.
    # If max_path_length > max_steps, the algo pads observations so that
    # obs.shape == torch.Size([n_paths, algo.max_path_length, feat_dim]).
    algo = CentralizedMAPPO(
        env_spec=env.spec,
        policy=policy,
        baseline=baseline,
        max_path_length=args.max_env_steps,  # see note above
        discount=args.discount,
        center_adv=bool(args.center_adv),
        positive_adv=bool(args.positive_adv),
        gae_lambda=args.gae_lambda,
        policy_ent_coeff=args.ent,
        entropy_method=args.entropy_method,
        stop_entropy_gradient=(args.entropy_method == 'max'),
        clip_grad_norm=args.clip_grad_norm,
        optimization_n_minibatches=args.opt_n_minibatches,
        optimization_mini_epochs=args.opt_mini_epochs,
    )

    runner.setup(algo, env,
                 sampler_cls=CentralizedMAOnPolicyVectorizedSampler,
                 sampler_args={'n_envs': args.n_envs})

    runner.train(n_epochs=args.n_epochs, batch_size=args.bs)
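# --- Curriculum window example (arithmetic from the fractions above) ---
# With args.curriculum enabled and args.n_epochs = 1000:
#   curr_start = int(0.125 * 1000) = 125
#   curr_end   = int(0.625 * 1000) = 625
# so TrafficJunctionWrapper is given epochs 125-625 to move the car add rate
# between add_rate_min and add_rate_max (the exact schedule is internal to
# the wrapper). Without the curriculum the rate is pinned to add_rate_max.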