def build_and_train(run_id=0, greedy_eval=False, test=True, test_date=None):
    """Assemble a PPO sampler/agent/runner, then either evaluate or train it.

    Args:
        run_id: Run index. Forwarded to ``load_params`` in test mode and to
            ``logger_context`` in training mode.
        greedy_eval: Passed as the first positional argument of
            ``AgentPgDiscrete`` (presumably toggles greedy action selection
            during evaluation -- confirm against the agent class).
        test: When True, load a saved ``'agent_state_dict'`` via
            ``load_params``, call ``runner.startup()``, sample once with
            ``'eval'`` passed to ``sampler.obtain_samples`` and plot the
            resulting observations. When False, build the agent with no
            initial model state dict and run ``runner.train()`` inside a
            ``logger_context`` rooted at a timestamped directory.
        test_date: Forwarded to ``load_params``; only consulted when
            ``test`` is True.
    """
    log_dir = "data/rl_example_3/"

    sampler = BatchedEpisodicSampler(
        EnvCls=MyEnv,
        env_kwargs={},
        batch_T=500,
        batch_B=64,
    )

    # Only test mode restores saved weights; training passes no state dict.
    init_agent = (
        load_params(log_dir, run_id, test_date)['agent_state_dict']
        if test else None
    )

    algo = PPO(entropy_loss_coeff=0., learning_rate=3e-4)
    agent = AgentPgDiscrete(
        greedy_eval,
        model_kwargs=dict(
            policy_hidden_sizes=[64, 64],
            value_hidden_sizes=[64, 64],
        ),
        initial_model_state_dict=init_agent,
    )

    # Step budgets are expressed as multiples of one sampler batch.
    batch_size = sampler.batch_size
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(400 * batch_size),
        log_interval_steps=int(10 * batch_size),
    )

    if test:
        # Initialize the runner's components without training, collect
        # samples with the 'eval' argument, and plot the observations.
        runner.startup()
        sampler.obtain_samples(0, 'eval')
        plot_obs(sampler.samples_np.env.observation)
        return

    # Each training run logs under log_dir suffixed with a YYYYmmdd_HHMM stamp.
    stamp = datetime.datetime.today().strftime("%Y%m%d_%H%M")
    with logger_context(log_dir + stamp, run_id, 'Reacher2D',
                        snapshot_mode="last", use_summary_writer=True,
                        override_prefix=True):
        runner.train()
action_start_time=3., action_end_time=7., open_gripper_on_leave=action_class != ActionClasses.PICK_UP, close_gripper_on_leave=action_class == ActionClasses.PICK_UP, ), batch_T=horizon, batch_B=1, max_decorrelation_steps=0 ) algo = args.get_ppo_from_options(options) agent = AgentPgContinuous( options.greedy_eval, ModelCls=ModelPgNNContinuousSelective, initial_model_state_dict=args.load_initial_model_state(options), model_kwargs=dict( policy_hidden_sizes=[128, 128, 128], policy_hidden_nonlinearity=torch.nn.Tanh, value_hidden_sizes=[128, 128, 128], value_hidden_nonlinearity=torch.nn.Tanh, policy_inputs_indices=list(range(8)) if options.without_object_obs else None, ) ) runner = MinibatchRl( algo=algo, agent=agent, sampler=sampler, log_traj_window=1, seed=options.seed, n_steps=1, log_interval_steps=int(1 * horizon), affinity=args.get_affinity(options) ) runner.startup() for i in tqdm(range(bench_data.shape[0])): benchmark_sample.bid = i + 1 sampler.obtain_samples(i) GripperCylinderEnv.df_from_observations(sampler.samples_np.env.observation[:, 0, :]).to_csv( '{}/trajectory_{}.csv'.format(output_dir, benchmark_sample.bid - 1))