Example #1
0
def main():
    """Train (or, with ``--demo``, only evaluate) a DQN-family agent on an ALE ROM.

    Arguments are read from ``sys.argv``. The agent class (``--agent``),
    the Q-network architecture (``--arch``) and its activation
    (``--activation``) are resolved by ``parse_agent``, ``parse_arch`` and
    ``parse_activation``. ``--seed`` must lie in [0, 2 ** 31); any other
    value makes the parser exit with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('rom', type=str)
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 31)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--demo', action='store_true', default=False)
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--use-sdl', action='store_true', default=False)
    parser.add_argument('--final-exploration-frames', type=int, default=10**6)
    parser.add_argument('--final-epsilon', type=float, default=0.1)
    parser.add_argument('--eval-epsilon', type=float, default=0.05)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--arch',
                        type=str,
                        default='nature',
                        choices=['nature', 'nips', 'dueling'])
    parser.add_argument('--steps', type=int, default=10**7)
    parser.add_argument('--replay-start-size', type=int, default=5 * 10**4)
    parser.add_argument('--target-update-interval', type=int, default=10**4)
    parser.add_argument('--eval-interval', type=int, default=10**5)
    parser.add_argument('--update-interval', type=int, default=4)
    parser.add_argument('--activation', type=str, default='relu')
    parser.add_argument('--eval-n-runs', type=int, default=10)
    parser.add_argument('--no-clip-delta',
                        dest='clip_delta',
                        action='store_false')
    parser.set_defaults(clip_delta=True)
    parser.add_argument('--agent',
                        type=str,
                        default='DQN',
                        choices=['DQN', 'DoubleDQN', 'PAL'])
    parser.add_argument('--logging-level',
                        type=int,
                        default=20,
                        help='Logging level. 10:DEBUG, 20:INFO etc.')
    args = parser.parse_args()

    # The evaluation env is seeded with 2**31 - 1 - seed, so both derived
    # seeds stay inside [0, 2**31) only if the seed itself does. Reject any
    # other value here instead of handing a negative/oversized seed onward.
    if not 0 <= args.seed < 2**31:
        parser.error('--seed must be in [0, 2 ** 31), got {}'.format(
            args.seed))

    import logging
    logging.basicConfig(level=args.logging_level)

    # Set a random seed used in ChainerRL.
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2**31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    # In training, life loss is considered as terminal states
    env = ale.ALE(args.rom, use_sdl=args.use_sdl, seed=train_seed)
    # Rewards seen by the learner are clipped to [-1, 1].
    misc.env_modifiers.make_reward_clipped(env, -1, 1)
    # In testing, an episode is terminated when all lives are lost
    eval_env = ale.ALE(args.rom,
                       use_sdl=args.use_sdl,
                       treat_life_lost_as_terminal=False,
                       seed=test_seed)

    n_actions = env.number_of_actions
    activation = parse_activation(args.activation)
    q_func = parse_arch(args.arch, n_actions, activation)

    # Draw the computational graph and save it in the output directory.
    # A single all-zero observation of shape (1, 4, 84, 84) is fed through
    # the Q-function only to trace the graph.
    chainerrl.misc.draw_computational_graph(
        [q_func(np.zeros((4, 84, 84), dtype=np.float32)[None])],
        os.path.join(args.outdir, 'model'))

    # Use the same hyper parameters as the Nature paper's
    opt = optimizers.RMSpropGraves(lr=2.5e-4,
                                   alpha=0.95,
                                   momentum=0.0,
                                   eps=1e-2)

    opt.setup(q_func)

    rbuf = replay_buffer.ReplayBuffer(10**6)

    # Epsilon-greedy exploration whose epsilon decays linearly from 1.0 to
    # --final-epsilon over --final-exploration-frames steps.
    explorer = explorers.LinearDecayEpsilonGreedy(
        1.0, args.final_epsilon, args.final_exploration_frames,
        lambda: np.random.randint(n_actions))
    Agent = parse_agent(args.agent)
    agent = Agent(q_func,
                  opt,
                  rbuf,
                  gpu=args.gpu,
                  gamma=0.99,
                  explorer=explorer,
                  replay_start_size=args.replay_start_size,
                  target_update_interval=args.target_update_interval,
                  clip_delta=args.clip_delta,
                  update_interval=args.update_interval,
                  batch_accumulator='sum',
                  phi=dqn_phi)

    if args.load:
        agent.load(args.load)

    if args.demo:
        # NOTE(review): unlike the periodic evaluation below, no eval
        # explorer (--eval-epsilon) is passed here, so demo scores may not
        # be comparable with training-time evaluation scores — confirm
        # whether that is intended.
        eval_stats = experiments.eval_performance(env=eval_env,
                                                  agent=agent,
                                                  n_runs=args.eval_n_runs)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        # During evaluation, act randomly with probability --eval-epsilon
        # (0.05 by default) and greedily otherwise.
        eval_explorer = explorers.ConstantEpsilonGreedy(
            args.eval_epsilon, lambda: np.random.randint(n_actions))
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_runs=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            eval_explorer=eval_explorer,
            save_best_so_far_agent=False,
            eval_env=eval_env)
Example #2
0
def main():
    """Train a population of SAC agents that periodically share knowledge.

    One ``chainerrl.agents.SoftActorCritic`` agent is built per entry of the
    comma-separated ``--all-h`` / ``--all-d`` / ``--all-lr`` lists, and all
    agents share a single replay buffer. Training is driven by
    ``custom_train_agent_batch.parallel_train_agent_batch_with_evaluation``
    with ``distill_callback`` as its step callback. When a distillation is
    due, the callback does one of two things:

    - with ``--cpo-test-extra-update``, it runs extra policy updates on
      every agent;
    - otherwise, it transfers one selected agent (the best by default) into
      all the others, by hard parameter copy or by distillation.

    Arguments are read from ``sys.argv``. The parser exits with status 2
    when ``--env`` is missing, ``--num-envs`` < 1, or the seed would give a
    per-process env seed outside [0, 2 ** 32).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    # --env is mandatory, so no default is declared: a default could never
    # take effect alongside required=True. Values starting with "DM" are
    # treated as DM-<domain>-<task> (see build_env below).
    parser.add_argument('--env',
                        type=str,
                        required=True,
                        help='OpenAI Gym MuJoCo env to perform algorithm on.')
    parser.add_argument('--num-envs',
                        type=int,
                        default=1,
                        help='Number of envs run in parallel.')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load',
                        type=str,
                        default='',
                        help='Directory to load agent from.')
    parser.add_argument('--steps',
                        type=int,
                        default=10**6,
                        help='Total number of timesteps to train the agent.')
    parser.add_argument('--eval-n-runs',
                        type=int,
                        default=20,
                        help='Number of episodes run for each evaluation.')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=10000,
                        help='Interval in timesteps between evaluations.')
    parser.add_argument('--replay-start-size',
                        type=int,
                        default=10000,
                        help='Minimum replay buffer size before ' +
                        'performing gradient updates.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        help='Minibatch size')
    parser.add_argument('--render',
                        action='store_true',
                        help='Render env states in a GUI window.')
    parser.add_argument('--demo',
                        action='store_true',
                        help='Just run evaluation, not training.')
    parser.add_argument('--monitor',
                        action='store_true',
                        help='Wrap env with gym.wrappers.Monitor.')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1000,
                        help='Interval in timesteps between outputting log'
                        ' messages during training')
    parser.add_argument('--logger-level',
                        type=int,
                        default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--policy-output-scale',
                        type=float,
                        default=1.,
                        help='Weight initialization scale of policy output.')
    parser.add_argument('--debug', action='store_true', help='Debug mode.')

    # Environment related
    parser.add_argument('--sparse-level', type=int, default=0)
    parser.add_argument('--noise-scale', type=float, default=0.0)

    parser.add_argument('--all-h', type=str, default='256,256,256')
    parser.add_argument('--all-d', type=str, default='2,2,2')
    parser.add_argument('--all-lr', type=str, default='3e-4,3e-4,3e-4')

    parser.add_argument('--cpo-temp', type=float, default=1.0)
    parser.add_argument('--cpo-train-sample',
                        type=str,
                        default='all',
                        choices=['sample', 'all'])
    parser.add_argument('--cpo-num-train-batch', type=int, default=20)
    parser.add_argument('--cpo-select-algo',
                        type=str,
                        default='uniform',
                        choices=['uniform', 'softmax', 'eps-greedy'])
    parser.add_argument('--cpo-select-prob-update',
                        type=str,
                        default='episode',
                        choices=['interval', 'episode'])
    parser.add_argument('--cpo-eps-schedule',
                        type=str,
                        default='linear:500000,1.0,0.1',
                        help='linear:steps,init,final or const:val')
    parser.add_argument('--cpo-recent-returns-interval', type=int, default=5)
    parser.add_argument('--cpo-distill-interval', type=int, default=5000)
    parser.add_argument('--cpo-distill-batch-size', type=int, default=256)
    parser.add_argument('--cpo-distill-epochs', type=int, default=5)
    parser.add_argument('--cpo-distill-lr', type=float, default=1e-3)
    parser.add_argument('--cpo-distill-bc-coef', type=float, default=0.8)
    parser.add_argument('--cpo-distill-type',
                        type=str,
                        default='bc',
                        choices=['bc', 'rlbc'])
    parser.add_argument('--cpo-distill-pi-only',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-distill-q-only',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-random-distill',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-bad-distill',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-distill-reset-returns',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-distill-perf-metric',
                        type=str,
                        choices=['train', 'eval'],
                        default='train')
    parser.add_argument('--cpo-distill-reset-model',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-distill-schedule',
                        type=str,
                        choices=['fix', 'ada'],
                        default='fix')
    parser.add_argument('--cpo-distill-ada-threshold', type=float, default=0.8)
    parser.add_argument('--cpo-test-extra-update',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-eval-before-distill',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-use-hardcopy',
                        action='store_true',
                        default=False)
    parser.add_argument('--cpo-mutual-learning',
                        action='store_true',
                        default=False)
    args = parser.parse_args()

    # Process seeds are seed * num_envs + [0, num_envs) and test seeds are
    # 2**32 - 1 - process_seed, so every process seed must lie in
    # [0, 2**32). Validate here (an assert would be stripped under -O) so
    # bad input fails before any environment or model is created.
    if args.num_envs < 1:
        parser.error('--num-envs must be >= 1, got {}'.format(args.num_envs))
    if args.seed < 0 or (args.seed + 1) * args.num_envs > 2**32:
        parser.error('--seed must satisfy 0 <= seed and '
                     '(seed + 1) * num_envs <= 2 ** 32, got seed={} '
                     'num_envs={}'.format(args.seed, args.num_envs))

    logger = logging.getLogger(__name__)

    if args.debug:
        chainer.set_debug(True)

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)

    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler(os.path.join(args.outdir, 'log.txt'))
    logging.basicConfig(level=args.logger_level,
                        handlers=[console_handler, file_handler])
    print('Output files are saved in {}'.format(args.outdir))

    # --load and --demo are accepted on the command line but nothing below
    # acts on them; say so rather than silently ignoring them.
    if args.load:
        logger.warning('--load is not supported by this script; '
                       'ignoring %s', args.load)
    if args.demo:
        logger.warning('--demo is not supported by this script; '
                       'training anyway')

    misc.set_random_seed(args.seed, gpus=(args.gpu, ))
    process_seeds = np.arange(args.num_envs) + args.seed * args.num_envs

    def build_env(process_idx, test):
        """Create one environment and return ``(env, timestep_limit)``.

        ``--env`` values starting with ``DM`` are split on ``-`` into a
        domain and a task for ``dm_wrapper.make``; anything else goes
        through ``gym.make``. ``timestep_limit`` is the episode step limit
        read from the underlying env; for gym envs it is None when the spec
        has no TimeLimit entry.
        """
        if args.env.startswith('DM'):
            import dm_wrapper
            domain, task = args.env.split('-')[1:]
            env = dm_wrapper.make(domain_name=domain, task_name=task)
            timestep_limit = env.dmcenv._step_limit
        else:
            env = gym.make(args.env)

            # Unwrap the TimeLimit wrapper; the limit is kept in
            # timestep_limit and handed to the training loop instead.
            assert isinstance(env, gym.wrappers.TimeLimit)
            env = env.env
            timestep_limit = env.spec.tags.get(
                'wrapper_config.TimeLimit.max_episode_steps')

            # Goal-based envs expose a dict observation space; flatten it
            # into a single observation with FlattenDictWrapper.
            if isinstance(env, gym.GoalEnv):
                keys = env.observation_space.spaces.keys()
                logger.info('GoalEnv: {}'.format(keys))
                env = gym.wrappers.FlattenDictWrapper(env,
                                                      dict_keys=list(keys))

        # Use different random seeds for train and test envs
        process_seed = int(process_seeds[process_idx])
        env_seed = 2**32 - 1 - process_seed if test else process_seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        # Normalize action space to [-1, 1]^n
        env = chainerrl.wrappers.NormalizeActionSpace(env)
        # Sparsify the reward signal if needed
        if args.sparse_level > 0:
            env = SparseRewardWrapper(env, args.sparse_level, timestep_limit)
        if args.noise_scale > 0:
            from noise_wrapper import NoiseWrapper
            env = NoiseWrapper(env, scale=args.noise_scale)
        if args.monitor:
            env = gym.wrappers.Monitor(env, args.outdir)
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env, timestep_limit

    def make_env(process_idx, test):
        """Return just the environment created by ``build_env``."""
        env, _ = build_env(process_idx, test)
        return env

    def make_batch_env(test):
        """Create a MultiprocessVectorEnv with one env per --num-envs."""
        return chainerrl.envs.MultiprocessVectorEnv([
            functools.partial(make_env, idx, test)
            for idx in range(args.num_envs)
        ])

    # Take the episode limit from the same code path that builds the envs,
    # so DM envs (limit from dmcenv._step_limit) are handled as well as gym
    # envs (limit from the spec's TimeLimit entry).
    sample_env, timestep_limit = build_env(process_idx=0, test=False)
    obs_space = sample_env.observation_space
    obs_size = np.asarray(obs_space.shape).prod()
    action_space = sample_env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    action_size = action_space.low.size

    # One agent per (hidden size, depth, learning rate) triple.
    all_h = [int(h) for h in args.all_h.split(',')]
    all_d = [int(d) for d in args.all_d.split(',')]
    all_lr = [float(lr) for lr in args.all_lr.split(',')]
    assert len(all_h) == len(all_d) and len(all_d) == len(all_lr)

    # A single replay buffer shared by every agent in the population.
    rbuf = replay_buffer.ReplayBuffer(10**6)

    def make_agent(h, d, lr):
        """Build one SoftActorCritic agent with hidden size h, depth d, lr."""
        funcs, opts = make_model(obs_size, action_size, d, h,
                                 args.policy_output_scale, lr)
        policy, policy_optimizer = funcs['pi'], opts['pi']
        q_func1, q_func1_optimizer = funcs['q1'], opts['q1']
        q_func2, q_func2_optimizer = funcs['q2'], opts['q2']

        # Draw the computational graph and save it in the output directory.
        fake_obs = chainer.Variable(
            policy.xp.zeros_like(obs_space.low, dtype=np.float32)[None],
            name='observation')
        fake_action = chainer.Variable(
            policy.xp.zeros_like(action_space.low, dtype=np.float32)[None],
            name='action')
        chainerrl.misc.draw_computational_graph(
            [policy(fake_obs)], os.path.join(args.outdir, 'policy'))
        chainerrl.misc.draw_computational_graph(
            [q_func1(fake_obs, fake_action)],
            os.path.join(args.outdir, 'q_func1'))
        chainerrl.misc.draw_computational_graph(
            [q_func2(fake_obs, fake_action)],
            os.path.join(args.outdir, 'q_func2'))

        def burnin_action_func():
            """Select random actions until model is updated one or more times."""
            return np.random.uniform(action_space.low,
                                     action_space.high).astype(np.float32)

        # Hyperparameters in http://arxiv.org/abs/1802.09477
        agent = chainerrl.agents.SoftActorCritic(
            policy,
            q_func1,
            q_func2,
            policy_optimizer,
            q_func1_optimizer,
            q_func2_optimizer,
            rbuf,
            gamma=0.99,
            replay_start_size=args.replay_start_size,
            gpu=args.gpu,
            minibatch_size=args.batch_size,
            burnin_action_func=burnin_action_func,
            entropy_target=-action_size,
            temperature_optimizer=chainer.optimizers.Adam(3e-4),
            use_mutual_learning=args.cpo_mutual_learning)

        return agent

    def distill_to_agent(teacher_agent, student_agent, history_obses,
                         history_obses_acs, fix_batch_num, num_train_batch):
        """Distill teacher's Q-functions and/or policy into the student.

        Returns ``(qf1_losses, qf2_losses, pi_losses)``; an entry is None
        when that part is skipped (--cpo-distill-pi-only skips the
        Q-functions, --cpo-distill-q-only skips the policy).
        """

        def predict_q(model, obs_acs):
            # history_obses_acs rows are [observation, action]; split them
            # back into the two Q-function inputs.
            return model(obs_acs[:, :obs_size], obs_acs[:, obs_size:])

        if not args.cpo_distill_pi_only:
            qf1_distill_losses = train_bc_batch(
                target_model=teacher_agent.q_func1,
                model=student_agent.q_func1,
                loss_fn=F.mean_squared_error,
                train_inputs=history_obses_acs,
                batch_size=args.cpo_distill_batch_size,
                lr=args.cpo_distill_lr,
                n_epochs=args.cpo_distill_epochs,
                predict_fn=predict_q,
                fix_batch_num=fix_batch_num,
                num_batch=num_train_batch)
            logger.info('Qf1 distill min loss: {}'.format(
                np.min(qf1_distill_losses)))

            qf2_distill_losses = train_bc_batch(
                target_model=teacher_agent.q_func2,
                model=student_agent.q_func2,
                loss_fn=F.mean_squared_error,
                train_inputs=history_obses_acs,
                batch_size=args.cpo_distill_batch_size,
                lr=args.cpo_distill_lr,
                n_epochs=args.cpo_distill_epochs,
                predict_fn=predict_q,
                fix_batch_num=fix_batch_num,
                num_batch=num_train_batch)
            logger.info('Qf2 distill min losses: {}'.format(
                np.min(qf2_distill_losses)))
        else:
            qf1_distill_losses = qf2_distill_losses = None

        if not args.cpo_distill_q_only:

            def rlbc_loss(inputs, pred, targ):
                # Mix of behaviour cloning (KL to the teacher's action
                # distribution) and a SAC-style policy loss evaluated with
                # the teacher's Q-functions, weighted by --cpo-distill-bc-coef.
                bc_loss = F.mean(targ.kl(pred))

                batch_state = inputs

                actions, log_prob = pred.sample_with_log_prob()
                q1 = teacher_agent.q_func1(batch_state, actions)
                q2 = teacher_agent.q_func2(batch_state, actions)
                q = F.minimum(q1, q2)

                entropy_term = student_agent.temperature * log_prob[..., None]
                assert q.shape == entropy_term.shape
                rl_loss = F.mean(entropy_term - q)

                return (args.cpo_distill_bc_coef) * bc_loss + (
                    1.0 - args.cpo_distill_bc_coef) * rl_loss

            if args.cpo_distill_type == 'rlbc':
                logger.info('Use RL+BC')
                pi_distill_losses = train_bc_batch(
                    target_model=teacher_agent.policy,
                    model=student_agent.policy,
                    loss_fn=rlbc_loss,
                    train_inputs=history_obses,
                    batch_size=args.cpo_distill_batch_size,
                    lr=args.cpo_distill_lr,
                    n_epochs=args.cpo_distill_epochs,
                    with_inputs=True,
                    fix_batch_num=fix_batch_num,
                    num_batch=num_train_batch)
            elif args.cpo_distill_type == 'bc':
                logger.info('Use BC')
                pi_distill_losses = train_bc_batch(
                    target_model=teacher_agent.policy,
                    model=student_agent.policy,
                    loss_fn=lambda pred, targ: F.mean(targ.kl(pred)),
                    train_inputs=history_obses,
                    batch_size=args.cpo_distill_batch_size,
                    lr=args.cpo_distill_lr,
                    n_epochs=args.cpo_distill_epochs,
                    fix_batch_num=fix_batch_num,
                    num_batch=num_train_batch)
            logger.info('Pi distill min losses: {}'.format(
                np.min(pi_distill_losses)))
        else:
            pi_distill_losses = None

        return qf1_distill_losses, qf2_distill_losses, pi_distill_losses

    def extra_update(agent, all_experiences):
        """Run extra policy/temperature updates over all given experiences.

        Performs --cpo-distill-epochs shuffled passes in minibatches of
        --cpo-distill-batch-size.
        """
        n_samples = len(all_experiences)
        batch_size = args.cpo_distill_batch_size
        for _ in range(args.cpo_distill_epochs):
            indices = np.arange(n_samples)
            np.random.shuffle(indices)
            for start_idx in range(0, n_samples, batch_size):
                batch_idx = indices[start_idx:start_idx +
                                    batch_size].astype(np.int32)
                experiences = [all_experiences[idx] for idx in batch_idx]
                batch = batch_experiences(experiences, agent.xp, agent.phi,
                                          agent.gamma)
                agent.update_policy_and_temperature(batch)

    def distill_callback(t, all_recent_returns, all_eval_returns, end):
        """Step callback: decide whether to distill at step t and do it.

        Returns the softmax selection weights over agents when a
        distillation (or extra update) was performed, else None.
        ``all_recent_returns`` may be modified in place when
        --cpo-distill-reset-returns is set.
        """
        if t > args.replay_start_size:
            if args.cpo_distill_perf_metric == 'eval':
                all_mean_returns = np.asarray(all_eval_returns)
            else:
                all_mean_returns = np.asarray([
                    np.mean(recent_returns)
                    for recent_returns in all_recent_returns
                ])
            temp = args.cpo_temp
            all_weights = softmax(all_mean_returns / temp)

            if args.cpo_distill_schedule == 'fix':
                require_distill = t % args.cpo_distill_interval == 0
            else:
                if end:
                    logger.info('Prob.: {}'.format(all_weights))
                require_distill = np.max(
                    all_weights) >= args.cpo_distill_ada_threshold

            if require_distill:
                if args.cpo_test_extra_update:
                    all_experiences = [e for e in rbuf.memory]
                    for agent in all_agents:
                        extra_update(agent, all_experiences)
                        logger.info('Did extra update')
                else:
                    history_obses = np.asarray(
                        [e[0]['state'] for e in rbuf.memory])
                    history_acs = np.asarray(
                        [e[0]['action'] for e in rbuf.memory])
                    history_obses_acs = np.concatenate(
                        [history_obses, history_acs], axis=1)

                    if args.cpo_random_distill:
                        best_agent_idx = np.random.randint(
                            len(all_mean_returns))
                        logger.info('Random distill')
                    elif args.cpo_bad_distill:
                        best_agent_idx = np.argmin(all_mean_returns)
                        logger.info('Bad distill')
                    else:
                        logger.info('Best distill')
                        best_agent_idx = np.argmax(all_mean_returns)
                    best_agent = all_agents[best_agent_idx]

                    for idx, other_agent in enumerate(all_agents):
                        if idx == best_agent_idx:
                            continue
                        if args.cpo_use_hardcopy:
                            copy_model_params(best_agent.q_func1,
                                              other_agent.q_func1)
                            copy_model_params(best_agent.q_func2,
                                              other_agent.q_func2)
                            copy_model_params(best_agent.policy,
                                              other_agent.policy)
                            logger.info('Copy done')
                        else:
                            if args.cpo_distill_reset_model:
                                logger.info('Reset model')
                                reset_model_params(other_agent.q_func1)
                                reset_model_params(other_agent.q_func2)
                                reset_model_params(other_agent.policy)
                            logger.info('Distill to {} from {}'.format(
                                idx, best_agent_idx))
                            distill_to_agent(
                                best_agent,
                                other_agent,
                                history_obses,
                                history_obses_acs,
                                fix_batch_num=(
                                    args.cpo_train_sample == 'sample'),
                                num_train_batch=args.cpo_num_train_batch)
                            other_agent.sync_target_network()

                        if args.cpo_distill_reset_returns:
                            logger.info('Reset returns')
                            all_recent_returns[idx] = copy.copy(
                                all_recent_returns[best_agent_idx])
                return all_weights
        return None

    env = make_batch_env(test=False)
    eval_env = make_batch_env(test=True)
    # Ensure the worker processes behind both vector envs are shut down
    # even if agent construction or training raises.
    try:
        all_agents = [
            make_agent(h, d, lr) for h, d, lr in zip(all_h, all_d, all_lr)
        ]
        if args.cpo_mutual_learning:
            for i, agent in enumerate(all_agents):
                agent.set_mutual_learning(all_agents, i)

        custom_train_agent_batch.parallel_train_agent_batch_with_evaluation(
            start_weighted_size=args.replay_start_size,
            all_agents=all_agents,
            env=env,
            eval_env=eval_env,
            outdir=args.outdir,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            log_interval=args.log_interval,
            max_episode_len=timestep_limit,
            save_best_so_far_agent=True,
            schedule_args={
                'select_prob_temp': args.cpo_temp,
                'select_prob_update': args.cpo_select_prob_update,
                'select_algo': args.cpo_select_algo,
                'eps_schedule': parse_eps_schedule(args.cpo_eps_schedule),
                'perf_metric': args.cpo_distill_perf_metric
            },
            step_callback=distill_callback,
            return_window_size=args.cpo_recent_returns_interval,
            eval_before_distill=args.cpo_eval_before_distill)
    finally:
        env.close()
        eval_env.close()
def main():
    """Train (or, with ``--demo``, only evaluate) a Soft Actor-Critic agent.

    Parses command-line arguments, builds a squashed diagonal Gaussian policy
    and two Q-functions, wraps them in ``chainerrl.agents.SoftActorCritic``
    and then either evaluates the agent once (``--demo``) or trains it on
    ``--num-envs`` environments run in subprocesses via
    ``chainerrl.envs.MultiprocessVectorEnv``. All environments created here
    are closed before returning.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--env',
                        type=str,
                        default='RoboschoolAtlasForwardWalk-v1',
                        help='OpenAI Gym env to perform algorithm on.')
    parser.add_argument('--num-envs',
                        type=int,
                        default=4,
                        help='Number of envs run in parallel.')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load',
                        type=str,
                        default='',
                        help='Directory to load agent from.')
    parser.add_argument('--steps',
                        type=int,
                        default=10**7,
                        help='Total number of timesteps to train the agent.')
    parser.add_argument('--eval-n-runs',
                        type=int,
                        default=20,
                        help='Number of episodes run for each evaluation.')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=100000,
                        help='Interval in timesteps between evaluations.')
    parser.add_argument('--replay-start-size',
                        type=int,
                        default=10000,
                        help='Minimum replay buffer size before ' +
                        'performing gradient updates.')
    parser.add_argument('--update-interval',
                        type=int,
                        default=1,
                        help='Interval in timesteps between model updates.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        help='Minibatch size')
    parser.add_argument('--render',
                        action='store_true',
                        help='Render env states in a GUI window.')
    parser.add_argument('--demo',
                        action='store_true',
                        help='Just run evaluation, not training.')
    parser.add_argument('--monitor',
                        action='store_true',
                        help='Wrap env with Monitor to write videos.')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1000,
                        help='Interval in timesteps between outputting log'
                        ' messages during training')
    parser.add_argument('--logger-level',
                        type=int,
                        default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--n-hidden-channels',
                        type=int,
                        default=1024,
                        help='Number of hidden channels of NN models.')
    parser.add_argument('--discount',
                        type=float,
                        default=0.98,
                        help='Discount factor.')
    parser.add_argument('--n-step-return',
                        type=int,
                        default=3,
                        help='N-step return.')
    parser.add_argument('--lr',
                        type=float,
                        default=3e-4,
                        help='Learning rate.')
    parser.add_argument('--adam-eps',
                        type=float,
                        default=1e-1,
                        help='Adam eps.')
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    # Set different random seeds for different subprocesses.
    # If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
    # If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
    process_seeds = np.arange(args.num_envs) + args.seed * args.num_envs
    assert process_seeds.max() < 2**32

    def make_batch_env(test):
        """Create a vector env with one subprocess per env; caller closes it."""
        return chainerrl.envs.MultiprocessVectorEnv([
            functools.partial(make_env, args, process_seeds[idx], test)
            for idx in range(args.num_envs)
        ])

    # A throwaway env used only to read the spec and the spaces.
    sample_env = make_env(args, process_seeds[0], test=False)
    timestep_limit = sample_env.spec.max_episode_steps
    obs_space = sample_env.observation_space
    action_space = sample_env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)
    del sample_env

    action_size = action_space.low.size

    winit = chainer.initializers.GlorotUniform()
    winit_policy_output = chainer.initializers.GlorotUniform()

    def squashed_diagonal_gaussian_head(x):
        """Split ``x`` into mean / log-scale halves and build the distribution."""
        assert x.shape[-1] == action_size * 2
        mean, log_scale = F.split_axis(x, 2, axis=1)
        # Clip the log-scale to keep the variance numerically sane.
        log_scale = F.clip(log_scale, -20., 2.)
        var = F.exp(log_scale * 2)
        return chainerrl.distribution.SquashedGaussianDistribution(mean,
                                                                   var=var)

    policy = chainer.Sequential(
        L.Linear(None, args.n_hidden_channels, initialW=winit),
        F.relu,
        L.Linear(None, args.n_hidden_channels, initialW=winit),
        F.relu,
        L.Linear(None, action_size * 2, initialW=winit_policy_output),
        squashed_diagonal_gaussian_head,
    )
    policy_optimizer = optimizers.Adam(args.lr,
                                       eps=args.adam_eps).setup(policy)

    def make_q_func_with_optimizer():
        """Build one (obs, action) -> scalar critic and its Adam optimizer."""
        q_func = chainer.Sequential(
            concat_obs_and_action,
            L.Linear(None, args.n_hidden_channels, initialW=winit),
            F.relu,
            L.Linear(None, args.n_hidden_channels, initialW=winit),
            F.relu,
            L.Linear(None, 1, initialW=winit),
        )
        q_func_optimizer = optimizers.Adam(args.lr,
                                           eps=args.adam_eps).setup(q_func)
        return q_func, q_func_optimizer

    q_func1, q_func1_optimizer = make_q_func_with_optimizer()
    q_func2, q_func2_optimizer = make_q_func_with_optimizer()

    # Draw the computational graph and save it in the output directory.
    fake_obs = chainer.Variable(policy.xp.zeros_like(obs_space.low,
                                                     dtype=np.float32)[None],
                                name='observation')
    fake_action = chainer.Variable(policy.xp.zeros_like(
        action_space.low, dtype=np.float32)[None],
                                   name='action')
    chainerrl.misc.draw_computational_graph([policy(fake_obs)],
                                            os.path.join(
                                                args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph([q_func1(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func1'))
    chainerrl.misc.draw_computational_graph([q_func2(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func2'))

    rbuf = replay_buffer.ReplayBuffer(10**6, num_steps=args.n_step_return)

    def burnin_action_func():
        """Select random actions until model is updated one or more times."""
        return np.random.uniform(action_space.low,
                                 action_space.high).astype(np.float32)

    # Soft Actor-Critic (https://arxiv.org/abs/1812.05905); the concrete
    # hyperparameters come from the command-line arguments above.
    agent = chainerrl.agents.SoftActorCritic(
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        rbuf,
        gamma=args.discount,
        update_interval=args.update_interval,
        replay_start_size=args.replay_start_size,
        gpu=args.gpu,
        minibatch_size=args.batch_size,
        burnin_action_func=burnin_action_func,
        entropy_target=-action_size,
        temperature_optimizer=chainer.optimizers.Adam(args.lr,
                                                      eps=args.adam_eps),
    )

    if len(args.load) > 0:
        agent.load(args.load)

    if args.demo:
        eval_env = make_env(args, seed=0, test=True)
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit,
        )
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
        eval_env.close()
    else:
        env = make_batch_env(test=False)
        eval_env = make_batch_env(test=True)
        experiments.train_agent_batch_with_evaluation(
            agent=agent,
            env=env,
            eval_env=eval_env,
            outdir=args.outdir,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            log_interval=args.log_interval,
            max_episode_len=timestep_limit,
        )
        # Shut down the env worker subprocesses explicitly.
        env.close()
        eval_env.close()
Example #4
0
def main():
    """Train (or, with ``--demo``, only evaluate) a DDPG agent on a Gym env.

    Parses command-line arguments, builds the Q-function and deterministic
    policy (optionally with batch normalization via ``--use-bn``), wraps them
    in a ``DDPG`` agent with additive Ornstein-Uhlenbeck exploration noise and
    either evaluates it or trains it with periodic evaluation.
    """
    import logging
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--env', type=str, default='Humanoid-v1')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--final-exploration-steps', type=int, default=10**6)
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--load', type=str, default='')
    parser.add_argument('--steps', type=int, default=10**7)
    parser.add_argument('--n-hidden-channels', type=int, default=300)
    parser.add_argument('--n-hidden-layers', type=int, default=3)
    parser.add_argument('--replay-start-size', type=int, default=5000)
    parser.add_argument('--n-update-times', type=int, default=1)
    parser.add_argument('--target-update-interval', type=int, default=1)
    parser.add_argument('--target-update-method',
                        type=str,
                        default='soft',
                        choices=['hard', 'soft'])
    parser.add_argument('--soft-update-tau', type=float, default=1e-2)
    parser.add_argument('--update-interval', type=int, default=4)
    parser.add_argument('--eval-n-runs', type=int, default=100)
    parser.add_argument('--eval-interval', type=int, default=10**5)
    parser.add_argument('--gamma', type=float, default=0.995)
    parser.add_argument('--minibatch-size', type=int, default=200)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--demo', action='store_true')
    parser.add_argument('--use-bn', action='store_true', default=False)
    parser.add_argument('--monitor', action='store_true')
    parser.add_argument('--reward-scale-factor', type=float, default=1e-2)
    args = parser.parse_args()

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    def clip_action_filter(a):
        # `action_space` is bound later in main(); it is looked up when the
        # filter is called, after it has been assigned.
        return np.clip(a, action_space.low, action_space.high)

    def make_env(test):
        """Create one env; ``test`` selects the eval seed and disables scaling."""
        env = gym.make(args.env)
        # Use different random seeds for train and test envs
        env_seed = 2**32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = gym.wrappers.Monitor(env, args.outdir)
        if isinstance(env.action_space, spaces.Box):
            misc.env_modifiers.make_action_filtered(env, clip_action_filter)
        if not test:
            # Scale rewards (and thus returns) to a reasonable range so that
            # training is easier
            env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
        if args.render and not test:
            misc.env_modifiers.make_rendered(env)
        return env

    env = make_env(test=False)
    timestep_limit = env.spec.tags.get(
        'wrapper_config.TimeLimit.max_episode_steps')
    obs_size = np.asarray(env.observation_space.shape).prod()
    action_space = env.action_space

    action_size = np.asarray(action_space.shape).prod()
    if args.use_bn:
        q_func = q_functions.FCBNLateActionSAQFunction(
            obs_size,
            action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            normalize_input=True)
        pi = policy.FCBNDeterministicPolicy(
            obs_size,
            action_size=action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            min_action=action_space.low,
            max_action=action_space.high,
            bound_action=True,
            normalize_input=True)
    else:
        q_func = q_functions.FCSAQFunction(
            obs_size,
            action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers)
        pi = policy.FCDeterministicPolicy(
            obs_size,
            action_size=action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            min_action=action_space.low,
            max_action=action_space.high,
            bound_action=True)
    model = DDPGModel(q_func=q_func, policy=pi)
    opt_a = optimizers.Adam(alpha=args.actor_lr)
    opt_c = optimizers.Adam(alpha=args.critic_lr)
    opt_a.setup(model['policy'])
    opt_c.setup(model['q_function'])
    opt_a.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_a')
    opt_c.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_c')

    rbuf = replay_buffer.ReplayBuffer(5 * 10**5)

    # OU noise scale: 20% of the per-dimension action range.
    ou_sigma = (action_space.high - action_space.low) * 0.2
    explorer = explorers.AdditiveOU(sigma=ou_sigma)
    agent = DDPG(model,
                 opt_a,
                 opt_c,
                 rbuf,
                 gamma=args.gamma,
                 explorer=explorer,
                 replay_start_size=args.replay_start_size,
                 target_update_method=args.target_update_method,
                 target_update_interval=args.target_update_interval,
                 update_interval=args.update_interval,
                 soft_update_tau=args.soft_update_tau,
                 n_times_update=args.n_update_times,
                 gpu=args.gpu,
                 minibatch_size=args.minibatch_size)

    if len(args.load) > 0:
        agent.load(args.load)

    eval_env = make_env(test=True)
    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_runs=args.eval_n_runs,
            max_episode_len=timestep_limit)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_env=eval_env,
            eval_n_runs=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            max_episode_len=timestep_limit)
Example #5
0
    def _test_load_td3(self, gpu):
        """Build a TD3 agent shaped like the pretrained Hopper-v2 model and load it.

        The policy and both critics are 400-300 ReLU MLPs. If the environment
        variable ``CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED`` is set, the
        ``exists`` flag returned by ``download_model`` must be true.
        """
        winit = chainer.initializers.LeCunUniform(3**-0.5)
        action_size = 3

        def concat_obs_and_action(obs, action):
            """Concat observation and action to feed the critic."""
            return F.concat((obs, action), axis=-1)

        def make_q_func_with_optimizer():
            critic = chainer.Sequential(
                concat_obs_and_action,
                L.Linear(None, 400, initialW=winit),
                F.relu,
                L.Linear(None, 300, initialW=winit),
                F.relu,
                L.Linear(None, 1, initialW=winit),
            )
            return critic, optimizers.Adam().setup(critic)

        # Evaluated left to right: first critic, then second critic.
        (q_func1, q_func1_optimizer), (q_func2, q_func2_optimizer) = (
            make_q_func_with_optimizer(),
            make_q_func_with_optimizer(),
        )

        actor = chainer.Sequential(
            L.Linear(None, 400, initialW=winit),
            F.relu,
            L.Linear(None, 300, initialW=winit),
            F.relu,
            L.Linear(None, action_size, initialW=winit),
            F.tanh,
            chainerrl.distribution.ContinuousDeterministicDistribution,
        )
        actor_optimizer = optimizers.Adam().setup(actor)

        buffer = replay_buffer.ReplayBuffer(100)
        noise = explorers.AdditiveGaussian(scale=0.1,
                                           low=[-1.] * action_size,
                                           high=[1.] * action_size)

        agent = agents.TD3(
            actor,
            q_func1,
            q_func2,
            actor_optimizer,
            q_func1_optimizer,
            q_func2_optimizer,
            buffer,
            gamma=0.99,
            soft_update_tau=5e-3,
            explorer=noise,
            replay_start_size=10000,
            gpu=gpu,
            minibatch_size=100,
            burnin_action_func=None,
        )

        model_path, exists = download_model("TD3",
                                            "Hopper-v2",
                                            model_type=self.pretrained_type)
        agent.load(model_path)
        if os.environ.get('CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED'):
            assert exists
Example #6
0
    def _test_load_sac(self, gpu):
        """Build a SAC agent shaped like the pretrained Hopper-v2 model and load it.

        The policy and both critics are 256-256 ReLU MLPs; the policy head
        emits a mean and a clipped log-scale per action dimension. If the
        environment variable ``CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED``
        is set, the ``exists`` flag returned by ``download_model`` must be
        true.
        """
        n_actions = 3
        hidden = 256

        winit = chainer.initializers.GlorotUniform()
        winit_policy_output = chainer.initializers.GlorotUniform(1.0)

        def concat_obs_and_action(obs, action):
            """Concat observation and action to feed the critic."""
            return F.concat((obs, action), axis=-1)

        def squashed_diagonal_gaussian_head(x):
            assert x.shape[-1] == n_actions * 2
            mean, log_scale = F.split_axis(x, 2, axis=1)
            var = F.exp(F.clip(log_scale, -20., 2.) * 2)
            return chainerrl.distribution.SquashedGaussianDistribution(
                mean, var=var)

        policy = chainer.Sequential(
            L.Linear(None, hidden, initialW=winit),
            F.relu,
            L.Linear(None, hidden, initialW=winit),
            F.relu,
            L.Linear(None, n_actions * 2, initialW=winit_policy_output),
            squashed_diagonal_gaussian_head,
        )
        policy_optimizer = optimizers.Adam(3e-4).setup(policy)

        def make_q_func_with_optimizer():
            critic = chainer.Sequential(
                concat_obs_and_action,
                L.Linear(None, hidden, initialW=winit),
                F.relu,
                L.Linear(None, hidden, initialW=winit),
                F.relu,
                L.Linear(None, 1, initialW=winit),
            )
            return critic, optimizers.Adam(3e-4).setup(critic)

        q_func1, q_func1_optimizer = make_q_func_with_optimizer()
        q_func2, q_func2_optimizer = make_q_func_with_optimizer()

        agent = agents.SoftActorCritic(
            policy,
            q_func1,
            q_func2,
            policy_optimizer,
            q_func1_optimizer,
            q_func2_optimizer,
            replay_buffer.ReplayBuffer(100),
            gamma=0.99,
            replay_start_size=1000,
            gpu=gpu,
            minibatch_size=256,
            burnin_action_func=None,
            entropy_target=-n_actions,
            temperature_optimizer=optimizers.Adam(3e-4),
        )

        model_path, exists = download_model("SAC",
                                            "Hopper-v2",
                                            model_type=self.pretrained_type)
        agent.load(model_path)
        if os.environ.get('CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED'):
            assert exists
Example #7
0
def main():
    """Train a TD3 agent on the Dakty pushing simulation environment.

    Parses command-line arguments, builds a deterministic tanh policy and two
    Q-functions, wraps them in ``chainerrl.agents.TD3`` and runs a
    hand-written training loop of ``n_episodes`` episodes, each capped at
    ``max_episode_len`` steps.

    NOTE(review): ``--env``, ``--steps``, ``--eval-n-runs``,
    ``--eval-interval`` and ``--demo`` are parsed but not used below; the
    environment id and loop lengths are hard-coded.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--env',
                        type=str,
                        default='Inv',
                        help='OpenAI Gym MuJoCo env to perform algorithm on.')
    parser.add_argument(
        '--outdir',
        type=str,
        default='results',
        help=
        'Directory path to save output files. it will be created if not existent.'
    )
    parser.add_argument('--seed',
                        type=int,
                        default=420,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load',
                        type=str,
                        default='',
                        help='Directory to load agent from.')
    parser.add_argument('--steps',
                        type=int,
                        default=10**5,
                        help='Total number of timesteps to train the agent.')
    parser.add_argument('--eval-n-runs',
                        type=int,
                        default=10,
                        help='Number of episodes run for each evaluation.')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=100,
                        help='Interval in timesteps between evaluations.')
    parser.add_argument(
        '--replay-start-size',
        type=int,
        default=1000,
        help='Minimum replay buffer size before performing gradient updates.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        help='Minibatch size')
    parser.add_argument('--logger-level',
                        type=int,
                        default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--render',
                        action='store_true',
                        help='Render env states in a GUI window.')
    parser.add_argument('--demo',
                        action='store_true',
                        help='Just run evaluation, not training.')
    parser.add_argument('--monitor',
                        action='store_true',
                        help='Wrap env with gym.wrappers.Monitor.')
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    def make_env(test):
        """Create the pushing env; ``test`` selects the evaluation seed."""
        env = gym.make(
            "DaktyPushingSimulationEnv-v0",
            level=5,
            simulation_backend="mujoco",
            control_frequency_in_hertz=100,
            state_space_components_to_be_used=None,
            alternate_env_object=None,
            discretization_factor_torque_control_space=None,
            model_as_function_for_pixel_to_latent_space_parsing=(None, None))

        print('\n############\n', env, '\n############\n')

        env.unwrapped.finger.set_resolution_quality('low')

        print('\n############\n', env, '\n############\n')

        # The previous code wrapped `env` in gym.wrappers.TimeLimit and
        # immediately unwrapped it again, which was a no-op; it was removed.

        # Use different random seeds for train and test envs
        env_seed = 2**32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = chainerrl.wrappers.Monitor(env, args.outdir)
        if args.render and not test:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    timestep_limit = env.spec.tags.get(
        'wrapper_config.TimeLimit.max_episode_steps')
    obs_space = env.observation_space
    action_space = env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    action_size = action_space.low.size

    winit = chainer.initializers.LeCunUniform(3**-0.5)

    # Policy: obs -> tanh-bounded action of size `action_size`.
    policy = chainer.Sequential(
        L.Linear(None, 128, initialW=winit),
        F.relu,
        L.Linear(None, 64, initialW=winit),
        F.relu,
        L.Linear(None, action_size, initialW=winit),
        F.tanh,
        chainerrl.distribution.ContinuousDeterministicDistribution,
    )
    policy_optimizer = optimizers.Adam(3e-4).setup(policy)

    def make_q_func_with_optimizer():
        """Build one (obs, action) -> scalar critic and its Adam optimizer."""
        q_func = chainer.Sequential(
            concat_obs_and_action,
            L.Linear(None, 128, initialW=winit),
            F.relu,
            L.Linear(None, 64, initialW=winit),
            F.relu,
            L.Linear(None, 1, initialW=winit),
        )
        q_func_optimizer = optimizers.Adam().setup(q_func)
        return q_func, q_func_optimizer

    # Two identically-shaped critics, each with its own optimizer.
    q_func1, q_func1_optimizer = make_q_func_with_optimizer()
    q_func2, q_func2_optimizer = make_q_func_with_optimizer()

    print('\n\n-------------------\n', obs_space.low.shape,
          '\n-------------------\n')

    # Draw the computational graph and save it in the output directory.
    fake_obs = chainer.Variable(policy.xp.zeros_like(obs_space.low,
                                                     dtype=np.float32)[None],
                                name='observation')
    fake_action = chainer.Variable(policy.xp.zeros_like(
        action_space.low, dtype=np.float32)[None],
                                   name='action')
    chainerrl.misc.draw_computational_graph([policy(fake_obs)],
                                            os.path.join(
                                                args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph([q_func1(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func1'))
    chainerrl.misc.draw_computational_graph([q_func2(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func2'))

    rbuf = replay_buffer.ReplayBuffer(10**5)

    explorer = explorers.AdditiveGaussian(scale=0.1,
                                          low=action_space.low,
                                          high=action_space.high)

    def burnin_action_func():
        """Select random actions until model is updated one or more times."""
        return np.random.uniform(action_space.low,
                                 action_space.high).astype(np.float32)

    # TD3 (http://arxiv.org/abs/1802.09477); network sizes and batch size
    # here differ from the paper.
    agent = chainerrl.agents.TD3(
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        rbuf,
        gamma=0.99,
        soft_update_tau=5e-3,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        gpu=args.gpu,
        minibatch_size=args.batch_size,
        burnin_action_func=burnin_action_func,
    )

    if len(args.load) > 0:
        agent.load(args.load)

    sys.stdout.flush()

    print('\nbeginning training\n')

    n_episodes = 10000
    max_episode_len = 5000
    for i in range(1, n_episodes + 1):
        obs = env.reset()
        reward = 0
        done = False
        R = 0  # return (sum of rewards)
        t = 0  # time step

        # One progress bar per episode; the context manager closes it when
        # the episode ends so bars do not pile up.
        with tqdm(total=max_episode_len) as pbar:
            while not done and t < max_episode_len:
                pbar.update(1)

                # Uncomment to watch the behaviour
                # env.render()
                action = agent.act_and_train(obs, reward)
                obs, reward, done, _ = env.step(action)

                R += reward
                t += 1

        print('episode:', i, 'R:', R, 'statistics:', agent.get_statistics())
        agent.stop_episode_and_train(obs, reward, done)
    print('Finished.')
Example #8
0
    def _test_load_ddpg(self, gpu):
        """Build a DDPG agent shaped like the pretrained Hopper-v2 model and load it.

        The lazily-sized links are initialized by one forward pass on a zero
        observation (11 dims) and zero action (3 dims) before the optimizers
        are set up. If ``CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED`` is set,
        the ``exists`` flag returned by ``download_model`` must be true.
        """
        from chainerrl.agents.ddpg import DDPGModel

        def concat_obs_and_action(obs, action):
            return F.concat((obs, action), axis=-1)

        obs_size = 11
        action_size = 3
        winit = chainer.initializers.LeCunUniform(3**-0.5)

        critic = chainer.Sequential(
            concat_obs_and_action,
            L.Linear(None, 400, initialW=winit),
            F.relu,
            L.Linear(None, 300, initialW=winit),
            F.relu,
            L.Linear(None, 1, initialW=winit),
        )
        actor = chainer.Sequential(
            L.Linear(None, 400, initialW=winit),
            F.relu,
            L.Linear(None, 300, initialW=winit),
            F.relu,
            L.Linear(None, action_size, initialW=winit),
            F.tanh,
            chainerrl.distribution.ContinuousDeterministicDistribution,
        )
        model = DDPGModel(q_func=critic, policy=actor)

        # Dummy forward passes (actor first, then critic) fix the input sizes
        # of the `None`-sized Linear links.
        xp = model.xp
        dummy_obs = chainer.Variable(
            xp.zeros((1, obs_size), dtype=np.float32), name='observation')
        dummy_action = chainer.Variable(
            xp.zeros((1, action_size), dtype=np.float32), name='action')
        actor(dummy_obs)
        critic(dummy_obs, dummy_action)

        actor_opt = optimizers.Adam()
        critic_opt = optimizers.Adam()
        actor_opt.setup(model['policy'])
        critic_opt.setup(model['q_function'])

        noise = explorers.AdditiveGaussian(scale=0.1,
                                           low=[-1.] * action_size,
                                           high=[1.] * action_size)

        agent = agents.DDPG(
            model,
            actor_opt,
            critic_opt,
            replay_buffer.ReplayBuffer(100),
            gamma=0.99,
            explorer=noise,
            replay_start_size=1000,
            target_update_method='soft',
            target_update_interval=1,
            update_interval=1,
            soft_update_tau=5e-3,
            n_times_update=1,
            gpu=gpu,
            minibatch_size=100,
            burnin_action_func=None,
        )

        model_path, exists = download_model("DDPG",
                                            "Hopper-v2",
                                            model_type=self.pretrained_type)
        agent.load(model_path)
        if os.environ.get('CHAINERRL_ASSERT_DOWNLOADED_MODEL_IS_CACHED'):
            assert exists
Example #9
0
 def make_replay_buffer(self, env):
     """Return a new ``replay_buffer.ReplayBuffer`` of capacity ``10 ** 5``.

     ``env`` is accepted for interface compatibility and is not used.
     """
     capacity = 10 ** 5
     return replay_buffer.ReplayBuffer(capacity)
def main():
    """Train (or, with ``--demo``, only evaluate) a CategoricalDQN agent.

    The environment is a gym env chosen by ``--env`` (CartPole-v1 by default).
    Observations are cast to float32; training rewards are scaled by
    ``--reward-scale-factor``. The replay buffer is uniform or prioritized,
    and transition- or episode-based, depending on the
    ``--prioritized-replay`` / ``--episodic-replay`` flags.
    """
    import logging
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--env', type=str, default='CartPole-v1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--final-exploration-steps', type=int, default=1000)
    parser.add_argument('--start-epsilon', type=float, default=1.0)
    parser.add_argument('--end-epsilon', type=float, default=0.1)
    parser.add_argument('--demo', action='store_true', default=False)
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--steps', type=int, default=10**8)
    parser.add_argument('--prioritized-replay', action='store_true')
    parser.add_argument('--episodic-replay', action='store_true')
    parser.add_argument('--replay-start-size', type=int, default=50)
    parser.add_argument('--target-update-interval', type=int, default=100)
    parser.add_argument('--target-update-method', type=str, default='hard')
    parser.add_argument('--soft-update-tau', type=float, default=1e-2)
    parser.add_argument('--update-interval', type=int, default=1)
    parser.add_argument('--eval-n-runs', type=int, default=100)
    parser.add_argument('--eval-interval', type=int, default=1000)
    parser.add_argument('--n-hidden-channels', type=int, default=12)
    parser.add_argument('--n-hidden-layers', type=int, default=3)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--minibatch-size', type=int, default=None)
    parser.add_argument('--render-train', action='store_true')
    parser.add_argument('--render-eval', action='store_true')
    parser.add_argument('--monitor', action='store_true')
    parser.add_argument('--reward-scale-factor', type=float, default=1.0)
    args = parser.parse_args()

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env(test):
        env = gym.make(args.env)
        # Use a different seed for the evaluation env than for training.
        env_seed = 2**32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        if args.monitor:
            env = chainerrl.wrappers.Monitor(env, args.outdir)
        if not test:
            # Scale rewards (and thus returns) to a reasonable range so that
            # training is easier
            env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
        if ((args.render_eval and test) or (args.render_train and not test)):
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    # Read the episode cap from the spec attribute (as the SAC example in
    # this file does) rather than the ``tags`` dict, which newer gym
    # releases no longer provide.
    timestep_limit = env.spec.max_episode_steps
    obs_size = env.observation_space.low.size
    action_space = env.action_space

    # Support of the return distribution: n_atoms evenly spaced values
    # between v_min and v_max.
    n_atoms = 51
    v_max = 500
    v_min = 0

    n_actions = action_space.n
    q_func = q_functions.DistributionalFCStateQFunctionWithDiscreteAction(
        obs_size,
        n_actions,
        n_atoms,
        v_min,
        v_max,
        n_hidden_channels=args.n_hidden_channels,
        n_hidden_layers=args.n_hidden_layers)
    # Use epsilon-greedy for exploration
    explorer = explorers.LinearDecayEpsilonGreedy(args.start_epsilon,
                                                  args.end_epsilon,
                                                  args.final_exploration_steps,
                                                  action_space.sample)

    opt = optimizers.Adam(1e-3)
    opt.setup(q_func)

    rbuf_capacity = 5 * 10 ** 4
    # Episodic updates consume whole episodes, hence the smaller default
    # minibatch.
    if args.minibatch_size is None:
        args.minibatch_size = 4 if args.episodic_replay else 32
    if args.prioritized_replay:
        # Number of agent updates expected over the whole run.
        betasteps = (args.steps - args.replay_start_size) \
            // args.update_interval
        if args.episodic_replay:
            rbuf = replay_buffer.PrioritizedEpisodicReplayBuffer(
                rbuf_capacity, betasteps=betasteps)
        else:
            rbuf = replay_buffer.PrioritizedReplayBuffer(rbuf_capacity,
                                                         betasteps=betasteps)
    elif args.episodic_replay:
        rbuf = replay_buffer.EpisodicReplayBuffer(rbuf_capacity)
    else:
        rbuf = replay_buffer.ReplayBuffer(rbuf_capacity)

    agent = chainerrl.agents.CategoricalDQN(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=args.gamma,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        target_update_interval=args.target_update_interval,
        update_interval=args.update_interval,
        minibatch_size=args.minibatch_size,
        target_update_method=args.target_update_method,
        soft_update_tau=args.soft_update_tau,
        episodic_update=args.episodic_replay,
        episodic_update_len=16)

    if args.load:
        agent.load(args.load)

    eval_env = make_env(test=True)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            eval_env=eval_env,
            train_max_episode_len=timestep_limit)
def main():
    """Train (or, with ``--demo``, only evaluate) CategoricalDQN on Atari.

    The network is a Nature-DQN convolutional head followed by a
    distributional Q-value layer with 51 atoms on [-10, 10]. The optimizer
    hyperparameters are those cited in the comment below.
    ``--max-episode-len`` caps the length of every episode, in training,
    periodic evaluation and demo evaluation alike.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4')
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 31)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--demo', action='store_true', default=False)
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--use-sdl', action='store_true', default=False)
    parser.add_argument('--final-exploration-frames', type=int, default=10**6)
    parser.add_argument('--final-epsilon', type=float, default=0.1)
    parser.add_argument('--eval-epsilon', type=float, default=0.05)
    parser.add_argument('--steps', type=int, default=10**7)
    parser.add_argument(
        '--max-episode-len',
        type=int,
        default=5 * 60 * 60 // 4,  # 5 minutes with 60/4 fps
        help='Maximum number of steps for each episode.')
    parser.add_argument('--replay-start-size', type=int, default=5 * 10**4)
    parser.add_argument('--target-update-interval', type=int, default=10**4)
    parser.add_argument('--eval-interval', type=int, default=10**5)
    parser.add_argument('--update-interval', type=int, default=4)
    parser.add_argument('--eval-n-runs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--logging-level',
                        type=int,
                        default=20,
                        help='Logging level. 10:DEBUG, 20:INFO etc.')
    parser.add_argument('--render',
                        action='store_true',
                        default=False,
                        help='Render env states in a GUI window.')
    parser.add_argument('--monitor',
                        action='store_true',
                        default=False,
                        help='Monitor env. Videos and additional information'
                        ' are saved as output files.')
    args = parser.parse_args()

    import logging
    logging.basicConfig(level=args.logging_level)

    # Set a random seed used in ChainerRL.
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    # Set different random seeds for train and test envs.
    train_seed = args.seed
    test_seed = 2**31 - 1 - args.seed

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = test_seed if test else train_seed
        # Life-loss episode ends and reward clipping apply to training only.
        env = atari_wrappers.wrap_deepmind(atari_wrappers.make_atari(args.env),
                                           episode_life=not test,
                                           clip_rewards=not test)
        env.seed(int(env_seed))
        if test:
            # Randomize actions like epsilon-greedy in evaluation as well
            env = chainerrl.wrappers.RandomizeAction(env, args.eval_epsilon)
        if args.monitor:
            env = gym.wrappers.Monitor(
                env, args.outdir, mode='evaluation' if test else 'training')
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)

    n_actions = env.action_space.n

    # Support of the return distribution: n_atoms values in [v_min, v_max].
    n_atoms = 51
    v_max = 10
    v_min = -10
    q_func = chainerrl.links.Sequence(
        chainerrl.links.NatureDQNHead(),
        chainerrl.q_functions.DistributionalFCStateQFunctionWithDiscreteAction(
            None,
            n_actions,
            n_atoms,
            v_min,
            v_max,
            n_hidden_channels=0,
            n_hidden_layers=0),
    )

    # Draw the computational graph and save it in the output directory.
    chainerrl.misc.draw_computational_graph(
        [q_func(np.zeros((4, 84, 84), dtype=np.float32)[None])],
        os.path.join(args.outdir, 'model'))

    # Use the same hyper parameters as https://arxiv.org/abs/1707.06887
    opt = chainer.optimizers.Adam(2.5e-4, eps=1e-2 / args.batch_size)
    opt.setup(q_func)

    rbuf = replay_buffer.ReplayBuffer(10**6)

    explorer = explorers.LinearDecayEpsilonGreedy(
        1.0, args.final_epsilon, args.final_exploration_frames,
        lambda: np.random.randint(n_actions))

    def phi(x):
        # Feature extractor: scale uint8 pixels to [0, 1] float32.
        return np.asarray(x, dtype=np.float32) / 255

    agent = chainerrl.agents.CategoricalDQN(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        target_update_interval=args.target_update_interval,
        update_interval=args.update_interval,
        # --batch-size already sets Adam's eps above; use it for the actual
        # minibatch too so the two stay consistent.
        minibatch_size=args.batch_size,
        batch_accumulator='mean',
        phi=phi,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=args.max_episode_len)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            save_best_so_far_agent=False,
            train_max_episode_len=args.max_episode_len,
            eval_env=eval_env,
        )
def main():
    """Roll out a pretrained DQN policy on Atari and record demonstrations.

    The weights come from ``--load``. A DQN agent is wrapped around them;
    its optimizer and one-slot replay buffer exist only because the agent
    constructor requires them. The agent then acts for ``--steps`` steps.
    ``experiments.collect_demonstrations`` writes the trajectories to
    ``outdir/demos.pickle``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--env', type=str, default='BreakoutNoFrameskip-v4',
        help='OpenAI Atari domain to perform algorithm on.')
    parser.add_argument(
        '--outdir', type=str, default='results',
        help='Directory path to save output files.'
        ' If it does not exist, it will be created.')
    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed [0, 2 ** 31)')
    parser.add_argument(
        '--gpu', type=int, default=0,
        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load', type=str, default=None, required=True)
    parser.add_argument(
        '--logging-level', type=int, default=20,
        help='Logging level. 10:DEBUG, 20:INFO etc.')
    parser.add_argument(
        '--render', action='store_true', default=False,
        help='Render env states in a GUI window.')
    parser.add_argument(
        '--monitor', action='store_true', default=False,
        help='Monitor env. Videos and additional information'
        ' are saved as output files.')
    parser.add_argument(
        '--steps', type=int, default=5 * 10**7,
        help='Total number of demo timesteps to collect')
    args = parser.parse_args()

    import logging
    logging.basicConfig(level=args.logging_level)

    # Seed ChainerRL's random number generators.
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    args.outdir = experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env():
        # No frame cap, no life-loss resets and no reward clipping:
        # demonstrations should reflect full, unclipped episodes.
        raw_env = atari_wrappers.make_atari(args.env, max_frames=None)
        wrapped = atari_wrappers.wrap_deepmind(raw_env,
                                               episode_life=False,
                                               clip_rewards=False)
        wrapped.seed(int(args.seed))
        # Take a random action 1% of the time, epsilon-greedy style.
        wrapped = chainerrl.wrappers.RandomizeAction(wrapped, 0.01)
        if args.monitor:
            wrapped = chainerrl.wrappers.Monitor(
                wrapped, args.outdir, mode='evaluation')
        if args.render:
            wrapped = chainerrl.wrappers.Render(wrapped)
        return wrapped

    env = make_env()
    n_actions = env.action_space.n

    # Nature DQN conv trunk -> linear layer with one output per action.
    trunk = links.NatureDQNHead()
    value_layer = L.Linear(512, n_actions)
    q_func = links.Sequence(trunk, value_layer, DiscreteActionValue)

    # Save a rendering of the network graph next to the other outputs.
    dummy_obs = np.zeros((4, 84, 84), dtype=np.float32)[None]
    chainerrl.misc.draw_computational_graph(
        [q_func(dummy_obs)], os.path.join(args.outdir, 'model'))

    # Placeholders: the agent constructor needs an optimizer and a replay
    # buffer even though no learning happens here.
    opt = optimizers.RMSpropGraves()
    opt.setup(q_func)
    rbuf = replay_buffer.ReplayBuffer(1)

    def phi(x):
        # Convert raw uint8 frames to float32 in [0, 1].
        return np.asarray(x, dtype=np.float32) / 255

    agent = agents.DQN(
        q_func, opt, rbuf,
        gpu=args.gpu,
        gamma=0.99,
        explorer=None,
        replay_start_size=1,
        minibatch_size=1,
        target_update_interval=None,
        clip_delta=True,
        update_interval=4,
        phi=phi,
    )
    agent.load(args.load)

    # saves demos to outdir/demos.pickle
    experiments.collect_demonstrations(
        agent=agent,
        env=env,
        steps=args.steps,
        episodes=None,
        outdir=args.outdir,
        max_episode_len=None,
    )
def main():
    """Train (or, with ``--demo``, only evaluate) Soft Actor-Critic.

    Runs on the gym MuJoCo env chosen by ``--env``. Training uses
    ``--num-envs`` environments in subprocesses. The policy outputs a
    squashed diagonal Gaussian; two independent Q-functions are trained.
    Weights can be restored either from a local directory (``--load``) or
    from a downloaded pretrained model (``--load-pretrained``), but not both.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir',
                        type=str,
                        default='results',
                        help='Directory path to save output files.'
                        ' If it does not exist, it will be created.')
    parser.add_argument('--env',
                        type=str,
                        default='Hopper-v2',
                        help='OpenAI Gym MuJoCo env to perform algorithm on.')
    parser.add_argument('--num-envs',
                        type=int,
                        default=1,
                        help='Number of envs run in parallel.')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='GPU to use, set to -1 if no GPU.')
    parser.add_argument('--load',
                        type=str,
                        default='',
                        help='Directory to load agent from.')
    parser.add_argument('--steps',
                        type=int,
                        default=10**6,
                        help='Total number of timesteps to train the agent.')
    parser.add_argument('--eval-n-runs',
                        type=int,
                        default=10,
                        help='Number of episodes run for each evaluation.')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=5000,
                        help='Interval in timesteps between evaluations.')
    parser.add_argument('--replay-start-size',
                        type=int,
                        default=10000,
                        help='Minimum replay buffer size before ' +
                        'performing gradient updates.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        help='Minibatch size')
    parser.add_argument('--render',
                        action='store_true',
                        help='Render env states in a GUI window.')
    parser.add_argument('--demo',
                        action='store_true',
                        help='Just run evaluation, not training.')
    parser.add_argument('--load-pretrained',
                        action='store_true',
                        default=False)
    parser.add_argument('--pretrained-type',
                        type=str,
                        default="best",
                        choices=['best', 'final'])
    parser.add_argument('--monitor',
                        action='store_true',
                        help='Wrap env with gym.wrappers.Monitor.')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1000,
                        help='Interval in timesteps between outputting log'
                        ' messages during training')
    parser.add_argument('--logger-level',
                        type=int,
                        default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--policy-output-scale',
                        type=float,
                        default=1.,
                        help='Weight initialization scale of policy output.')
    parser.add_argument('--debug', action='store_true', help='Debug mode.')
    args = parser.parse_args()

    # Validate conflicting options up front with a real CLI error instead of
    # an ``assert`` (stripped under ``python -O``) after envs are built.
    if args.load and args.load_pretrained:
        parser.error('--load and --load-pretrained are mutually exclusive')

    logging.basicConfig(level=args.logger_level)

    if args.debug:
        chainer.set_debug(True)

    args.outdir = experiments.prepare_output_dir(args,
                                                 args.outdir,
                                                 argv=sys.argv)
    print('Output files are saved in {}'.format(args.outdir))

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu, ))

    # Set different random seeds for different subprocesses.
    # If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
    # If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
    process_seeds = np.arange(args.num_envs) + args.seed * args.num_envs
    assert process_seeds.max() < 2**32

    def make_env(process_idx, test):
        env = gym.make(args.env)
        # Unwrap TimeLimit wrapper
        assert isinstance(env, gym.wrappers.TimeLimit)
        env = env.env
        # Use different random seeds for train and test envs
        process_seed = int(process_seeds[process_idx])
        env_seed = 2**32 - 1 - process_seed if test else process_seed
        env.seed(env_seed)
        # Cast observations to float32 because our model uses float32
        env = chainerrl.wrappers.CastObservationToFloat32(env)
        # Normalize action space to [-1, 1]^n
        env = chainerrl.wrappers.NormalizeActionSpace(env)
        if args.monitor:
            env = gym.wrappers.Monitor(env, args.outdir)
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    def make_batch_env(test):
        return chainerrl.envs.MultiprocessVectorEnv([
            functools.partial(make_env, idx, test)
            for idx in range(args.num_envs)
        ])

    sample_env = make_env(process_idx=0, test=False)
    timestep_limit = sample_env.spec.max_episode_steps
    obs_space = sample_env.observation_space
    action_space = sample_env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    action_size = action_space.low.size

    winit = chainer.initializers.GlorotUniform()
    winit_policy_output = chainer.initializers.GlorotUniform(
        args.policy_output_scale)

    def squashed_diagonal_gaussian_head(x):
        # First half of the last axis is the mean, second half the log-scale;
        # the log-scale is clipped to [-20, 2] before building the variance.
        assert x.shape[-1] == action_size * 2
        mean, log_scale = F.split_axis(x, 2, axis=1)
        log_scale = F.clip(log_scale, -20., 2.)
        var = F.exp(log_scale * 2)
        return chainerrl.distribution.SquashedGaussianDistribution(mean,
                                                                   var=var)

    policy = chainer.Sequential(
        L.Linear(None, 256, initialW=winit),
        F.relu,
        L.Linear(None, 256, initialW=winit),
        F.relu,
        L.Linear(None, action_size * 2, initialW=winit_policy_output),
        squashed_diagonal_gaussian_head,
    )
    policy_optimizer = optimizers.Adam(3e-4).setup(policy)

    def make_q_func_with_optimizer():
        q_func = chainer.Sequential(
            concat_obs_and_action,
            L.Linear(None, 256, initialW=winit),
            F.relu,
            L.Linear(None, 256, initialW=winit),
            F.relu,
            L.Linear(None, 1, initialW=winit),
        )
        q_func_optimizer = optimizers.Adam(3e-4).setup(q_func)
        return q_func, q_func_optimizer

    q_func1, q_func1_optimizer = make_q_func_with_optimizer()
    q_func2, q_func2_optimizer = make_q_func_with_optimizer()

    # Draw the computational graph and save it in the output directory.
    fake_obs = chainer.Variable(policy.xp.zeros_like(obs_space.low,
                                                     dtype=np.float32)[None],
                                name='observation')
    fake_action = chainer.Variable(policy.xp.zeros_like(
        action_space.low, dtype=np.float32)[None],
                                   name='action')
    chainerrl.misc.draw_computational_graph([policy(fake_obs)],
                                            os.path.join(
                                                args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph([q_func1(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func1'))
    chainerrl.misc.draw_computational_graph([q_func2(fake_obs, fake_action)],
                                            os.path.join(
                                                args.outdir, 'q_func2'))

    rbuf = replay_buffer.ReplayBuffer(10**6)

    def burnin_action_func():
        """Select random actions until model is updated one or more times."""
        return np.random.uniform(action_space.low,
                                 action_space.high).astype(np.float32)

    # Hyperparameters in http://arxiv.org/abs/1802.09477
    agent = chainerrl.agents.SoftActorCritic(
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        rbuf,
        gamma=0.99,
        replay_start_size=args.replay_start_size,
        gpu=args.gpu,
        minibatch_size=args.batch_size,
        burnin_action_func=burnin_action_func,
        entropy_target=-action_size,
        temperature_optimizer=chainer.optimizers.Adam(3e-4),
    )

    # At most one of these is set (validated right after parsing).
    if args.load:
        agent.load(args.load)
    elif args.load_pretrained:
        agent.load(
            misc.download_model("SAC",
                                args.env,
                                model_type=args.pretrained_type)[0])

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=make_batch_env(test=True),
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit,
        )
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_batch_with_evaluation(
            agent=agent,
            env=make_batch_env(test=False),
            eval_env=make_batch_env(test=True),
            outdir=args.outdir,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
            log_interval=args.log_interval,
            max_episode_len=timestep_limit,
        )