Example #1
def build_and_train(run_id=0, greedy_eval=False, test=True, test_date=None):
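    # Sampler: gathers batch_T time steps from batch_B environment instances per iteration.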
    sampler = BatchedEpisodicSampler(
        EnvCls=MyEnv,
        env_kwargs=dict(),
        batch_T=500,
        batch_B=64,
    )
    log_dir = "data/rl_example_3/"
    init_agent = None
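    # In test mode, restore the agent parameters saved by an earlier training run.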
    if test:
        data = load_params(log_dir, run_id, test_date)
        init_agent = data['agent_state_dict']
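    # The runner couples the PPO algorithm, the discrete policy-gradient agent, and the
    # sampler; n_steps is the total environment-step budget.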
    runner = MinibatchRl(
        algo=PPO(entropy_loss_coeff=0., learning_rate=3e-4),
        agent=AgentPgDiscrete(
            greedy_eval,
            model_kwargs={
                'policy_hidden_sizes': [64, 64],
                'value_hidden_sizes': [64, 64],
            },
            initial_model_state_dict=init_agent,
        ),
        sampler=sampler,
        n_steps=int(400 * sampler.batch_size),
        log_interval_steps=int(10 * sampler.batch_size),
    )
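    # Test mode: collect one evaluation batch and plot the observations;
    # otherwise train under a time-stamped logger context.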
    if test:
        runner.startup()
        sampler.obtain_samples(0, 'eval')
        obs = sampler.samples_np.env.observation
        plot_obs(obs)
    else:
        run_dir = "{}{}".format(
            log_dir, datetime.datetime.today().strftime("%Y%m%d_%H%M"))
        with logger_context(run_dir, run_id, 'Reacher2D',
                            snapshot_mode="last",
                            use_summary_writer=True,
                            override_prefix=True):
            runner.train()
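
A minimal command-line entry point for build_and_train might look like the sketch below; the argparse wrapper is not part of the original snippet, and the flag names simply mirror the function's keyword arguments.

if __name__ == "__main__":
    import argparse

    # Hypothetical CLI wrapper (an assumption, not taken from the source project);
    # each flag corresponds to one keyword argument of build_and_train().
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_id', type=int, default=0)
    parser.add_argument('--greedy_eval', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--test_date', type=str, default=None)
    build_and_train(**vars(parser.parse_args()))
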
Example #2
        action_start_time=3.,
        action_end_time=7.,
        open_gripper_on_leave=action_class != ActionClasses.PICK_UP,
        close_gripper_on_leave=action_class == ActionClasses.PICK_UP,
    ),
    batch_T=horizon, batch_B=1, max_decorrelation_steps=0
)

algo = args.get_ppo_from_options(options)
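# Agent: continuous policy-gradient agent with 3x128 tanh policy and value networks;
# when options.without_object_obs is set, only the first 8 observation entries are fed
# to the policy network.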
agent = AgentPgContinuous(
    options.greedy_eval,
    ModelCls=ModelPgNNContinuousSelective,
    initial_model_state_dict=args.load_initial_model_state(options),
    model_kwargs=dict(
        policy_hidden_sizes=[128, 128, 128], policy_hidden_nonlinearity=torch.nn.Tanh,
        value_hidden_sizes=[128, 128, 128], value_hidden_nonlinearity=torch.nn.Tanh,
        policy_inputs_indices=list(range(8)) if options.without_object_obs else None,
    )
)

runner = MinibatchRl(
    algo=algo, agent=agent, sampler=sampler, log_traj_window=1, seed=options.seed, n_steps=1,
    log_interval_steps=int(1 * horizon), affinity=args.get_affinity(options)
)
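# Instead of calling runner.train(), sample each benchmark case manually and export
# the observed trajectory to a CSV file.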
runner.startup()
for i in tqdm(range(bench_data.shape[0])):
    benchmark_sample.bid = i + 1
    sampler.obtain_samples(i)
    GripperCylinderEnv.df_from_observations(sampler.samples_np.env.observation[:, 0, :]).to_csv(
        '{}/trajectory_{}.csv'.format(output_dir, benchmark_sample.bid - 1))
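
The loop above writes one trajectory_<i>.csv file per benchmark case. A small follow-up sketch, assuming pandas and the same output_dir, that gathers those files back into a single DataFrame for analysis; this post-processing step is not part of the original script.

import glob
import os

import pandas as pd

# Collect every trajectory CSV written by the loop above and tag each row with its
# benchmark case index (hypothetical post-processing).
frames = []
for path in sorted(glob.glob(os.path.join(output_dir, 'trajectory_*.csv'))):
    case_id = int(os.path.splitext(os.path.basename(path))[0].split('_')[-1])
    df = pd.read_csv(path, index_col=0)
    df['case'] = case_id
    frames.append(df)
all_trajectories = pd.concat(frames, ignore_index=True)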