예제 #1
0
 def multi_fps(n_workers, num_episodes=500):
     """Benchmark SubprocVecEnv stepping throughput and print the FPS.

     Replays ``ref_actions[reset_step:]`` (the same action for every worker)
     for ``num_episodes`` episodes and prints both the vectorized FPS and the
     effective per-environment FPS (``fps * n_workers``).

     NOTE(review): ``make_env_fn``, ``ref_actions`` and ``reset_step`` are not
     parameters; they are looked up in the enclosing/module scope and must be
     defined wherever this function is used.

     Args:
         n_workers: number of subprocess environments to step in parallel.
         num_episodes: number of replayed episodes (defaults to the original
             hard-coded 500).
     """
     vec_env = SubprocVecEnv([make_env_fn for _ in range(n_workers)])
     # Always shut the worker processes down, even if stepping raises.
     try:
         start_time = time.time()
         steps = 0
         for _ in range(num_episodes):
             vec_env.reset()
             for act in ref_actions[reset_step:]:
                 # Broadcast the single reference action to every worker.
                 vec_env.step(np.tile(act, (n_workers, 1)))
                 steps += 1
         elapsed = time.time() - start_time
         fps = steps / elapsed
         print(f'{n_workers}-worker FPS: {fps} EffectiveFPS: {fps*n_workers}')
     finally:
         vec_env.close()
예제 #2
0
def main(argv):
    """ Trains a model through backward RL.

    Loads the reference action sequence named by ``FLAGS.ref_actions_path``
    and builds ``FLAGS.num_workers`` ``RefTrackingEnv`` copies in a
    ``SubprocVecEnv`` plus one separate evaluation env.

    * With ``FLAGS.visualize`` set, the saved trainer config and checkpoint
      are loaded and the policy is visualized on the evaluation env.
    * Otherwise a SAC trainer is trained on a backward curriculum: task
      ``idx`` resets episodes at step ``len(ref_actions) - 1 - idx``, so
      training starts at the end of the clip and moves towards its start.
      The checkpoint is saved after every task.

    All environments are closed on exit, including when an error is raised.

    Args:
        argv: unused by this function.
    """
    ref_actions = np.load(os.path.join(DATA_DIR, FLAGS.ref_actions_path))
    clip_name, start_step = parse_path(FLAGS.ref_actions_path)

    def make_env_fn():
        """Create one reference-tracking env that initially resets at step 0."""
        return RefTrackingEnv(clip_name, ref_actions, start_step, reset_step=0)

    vec_env = SubprocVecEnv([make_env_fn for _ in range(FLAGS.num_workers)])
    eval_env = make_env_fn()

    config_class = SACTrainerConfig
    train_class = SACTrainer

    try:
        if FLAGS.visualize:
            # NOTE(review): training writes the config to
            # OUTPUT_DIR/FLAGS.config_path, but it is read back here from
            # FLAGS.config_path alone; confirm this asymmetry is intended.
            tconf = config_class.from_json(FLAGS.config_path)
            # The evaluation env doubles as the env to visualize; there is no
            # separate ``env`` object in this function.
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)
            trainer.load_checkpoint(
                os.path.join(OUTPUT_DIR, FLAGS.checkpoint_path))
            eval_env.visualize(trainer.policy, device='cpu')
        else:
            tconf = config_class.from_flags(FLAGS)
            tconf.to_json(os.path.join(OUTPUT_DIR, FLAGS.config_path))
            trainer = train_class(vec_env, eval_env, tconf, OUTPUT_DIR)

            # Generate the curriculum: reset_step runs from the last index of
            # ref_actions down to 0.
            for idx, reset_step in enumerate(reversed(range(len(ref_actions)))):
                # Modify the environments to reflect the new reset_step
                vec_env.set_attr('reset_step', reset_step)
                vec_env.reset()
                eval_env.reset_step = reset_step
                eval_env.reset()

                target_return = eval_env.ref_returns[reset_step]
                print(
                    f'Curriculum Task {idx}: reset_step {reset_step} target_return {target_return:.3f}'
                )

                trainer.train(target_return)
                trainer.save_checkpoint(
                    os.path.join(OUTPUT_DIR, FLAGS.checkpoint_path))
    finally:
        vec_env.close()
        eval_env.close()
예제 #3
0
파일: main.py 프로젝트: doerlbh/hrl-ep3
def _load_yaml_file(path):
    """Return the object parsed from the YAML file at ``path``.

    ``yaml.safe_load`` is used because ``yaml.load(stream)`` without an
    explicit ``Loader`` is deprecated since PyYAML 5.1 and raises a
    ``TypeError`` on PyYAML >= 6. It also avoids constructing arbitrary
    Python objects from a config file.
    """
    with open(path, 'r') as handle:
        return yaml.safe_load(handle)


def _prepare_log_dir(logpath):
    """Create an empty ``logpath`` directory, confirming before erasing one.

    Returns:
        True when ``logpath`` exists and is ready to be written to; False when
        it already existed and the user declined to erase it.
    """
    import shutil

    if os.path.isdir(logpath):
        # ``default=False`` is an argument of click.confirm, not of str.format.
        if not click.confirm(
                'Logs directory already exists in {}. Erase?'.format(logpath),
                default=False):
            return False
        # Python APIs instead of ``os.system('rm -rf ' + path)``, so paths
        # containing spaces or shell metacharacters are handled safely.
        shutil.rmtree(logpath)
    os.makedirs(logpath, exist_ok=True)
    return True


def _save_git_info(logpath):
    """Best-effort dump of ``git status``/``diff``/``show`` into ``logpath``.

    Writes ``git_status.txt``, ``git_diff.txt`` and ``git_show.txt``. A
    failing git command leaves its (possibly empty) file behind, and
    training continues, as with the former shell redirections.
    """
    import subprocess

    for subcommand in ('status', 'diff', 'show'):
        out_path = os.path.join(logpath, 'git_%s.txt' % subcommand)
        with open(out_path, 'w') as out:
            try:
                subprocess.run(['git', subcommand], stdout=out, check=False)
            except OSError as err:
                # git not installed / not executable: warn, but keep going.
                print('WARNING: could not run git %s: %s' % (subcommand, err))


def _open_alg_monitor(logpath, dummy_env, options):
    """Open ``<logpath>/Alg.Monitor.csv`` and write its JSON header line.

    Returns:
        ``(alg_f, alg_logger)``: the open text file (the caller must close it)
        and a ``csv.DictWriter`` over value_loss/action_loss/dist_entropy.
    """
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_logger = csv.DictWriter(
        alg_f, fieldnames=['value_loss', 'action_loss', 'dist_entropy'])
    alg_logger.writeheader()
    alg_f.flush()
    return alg_f, alg_logger


def main():
    """Train a flat (non-hierarchical) policy with A2C, PPO or ACKTR.

    Steps performed:

    * Parse the command line with the module-level ``parser``.
    * Load the YAML options (``args.path_opt``) and dashboard options
      (``args.vis_path_opt``).
    * Create the log directory, or resume from ``args.resume`` inside it.
    * Build the vectorized environments, the policy and the agent.
    * Run the update loop, saving a checkpoint every 100 updates and logging
      losses to ``<logpath>/Alg.Monitor.csv``.

    Raises:
        ValueError: if an option file is not given, or the algorithm / model
            mode combination is not supported.
        FileNotFoundError: if ``args.resume`` names a checkpoint that does not
            exist in the log directory.
    """
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options. Both files are used unconditionally below, so fail early
    # with a clear message instead of a NameError later on.
    if args.path_opt is None:
        raise ValueError('args.path_opt (options YAML file) must be given')
    if args.vis_path_opt is None:
        raise ValueError(
            'args.vis_path_opt (visualization options YAML file) must be given')
    options = _load_yaml_file(args.path_opt)
    vis_options = _load_yaml_file(args.vis_path_opt)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Validate the algorithm before it is used to look up alg_%s / optim_%s.
    # Explicit raises (not assert) so validation survives ``python -O``.
    if args.algo not in ('a2c', 'ppo', 'acktr'):
        raise ValueError('Unsupported algorithm: %s' % args.algo)

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    model_opt['time_scale'] = env_opt['time_scale']
    if model_opt['mode'] in ('baselinewtheta', 'phasewtheta'):
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']
        model_opt['theta_sz'] = env_opt['theta_sz']
    elif model_opt['mode'] in ('baseline_lowlevel', 'phase_lowlevel'):
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']

    # Check the model mode / algorithm combination
    supported_modes = (
        'baseline', 'baseline_reverse', 'phasesimple', 'phasewstate',
        'baselinewtheta', 'phasewtheta', 'baseline_lowlevel', 'phase_lowlevel',
        'interpolate', 'cyclic', 'maze_baseline', 'maze_baseline_wphase')
    if model_opt['mode'] not in supported_modes:
        raise ValueError('Unsupported model mode: %s' % model_opt['mode'])
    if model_opt['recurrent_policy'] and args.algo not in ('a2c', 'ppo'):
        raise ValueError('Recurrent policy is not implemented for ACKTR')

    # Set seed - just make the seed the trial number
    seed = args.trial
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    steps_per_update = alg_opt['num_processes'] * alg_opt['num_steps']
    num_updates = int(optim_opt['num_frames']
                      ) // alg_opt['num_steps'] // alg_opt['num_processes']
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if args.resume:
        resume_path = os.path.join(logpath, args.resume)
        if not os.path.isfile(resume_path):
            raise FileNotFoundError(
                'Checkpoint to resume from not found: %s' % resume_path)
        # Load exactly the file whose existence was just checked.
        ckpt = torch.load(resume_path)
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if not _prepare_log_dir(logpath):
            return
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        _save_git_info(logpath)

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # If interpolate mode, choose states
    if options['model']['mode'] == 'phase_lowlevel' and options['env'][
            'theta_space_mode'] == 'pretrain_interp':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = random.choice(all_states)
        s2 = random.choice(all_states)
        fixed_states = [s1, s2]
    elif model_opt['mode'] == 'interpolate':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = all_states[env_opt['s1_ind']]
        s2 = all_states[env_opt['s2_ind']]
        fixed_states = [s1, s2]
    else:
        fixed_states = None

    # Create environments
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose,
                 fixed_states) for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Get theta_sz for models (if applicable)
    dummy_env.reset()
    if model_opt['mode'] == 'baseline_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.theta_sz
    elif model_opt['mode'] == 'phase_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.env.theta_sz
    if 'theta_sz' in model_opt:
        env_opt['theta_sz'] = model_opt['theta_sz']

    # Get observation shape (first dimension scaled by the frame stack)
    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * env_opt['num_stack'], *obs_shape[1:])

    # Do vec normalize, but mask out what we don't want altered
    if len(envs.observation_space.shape) == 1:
        ignore_mask = np.zeros(envs.observation_space.shape)
        if env_opt['add_timestep']:
            ignore_mask[-1] = 1
        if model_opt['mode'] in ('baselinewtheta', 'phasewtheta',
                                 'baseline_lowlevel', 'phase_lowlevel'):
            theta_sz = env_opt['theta_sz']
            if env_opt['add_timestep']:
                ignore_mask[-(theta_sz + 1):] = 1
            else:
                ignore_mask[-theta_sz:] = 1
        if args.finetune_baseline:
            ignore_mask = dummy_env.unwrapped._get_obs_mask()
            freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
            if env_opt['add_timestep']:
                ignore_mask = np.concatenate([ignore_mask, [1]])
                freeze_mask = np.concatenate([freeze_mask, [0]])
            ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=True,
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     freeze_mask=freeze_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])
        else:
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=env_opt['add_timestep'],
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])

    # Set up algo monitoring; the file is closed in the ``finally`` below.
    alg_f, alg_logger = _open_alg_monitor(logpath, dummy_env, options)
    try:
        # Create the policy network
        actor_critic = Policy(obs_shape, envs.action_space, model_opt)
        if args.cuda:
            actor_critic.cuda()

        # Create the agent (args.algo was validated above)
        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR(actor_critic,
                                   alg_opt['value_loss_coef'],
                                   alg_opt['entropy_coef'],
                                   lr=optim_opt['lr'],
                                   eps=optim_opt['eps'],
                                   alpha=optim_opt['alpha'],
                                   max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            agent = algo.PPO(actor_critic,
                             alg_opt['clip_param'],
                             alg_opt['ppo_epoch'],
                             alg_opt['num_mini_batch'],
                             alg_opt['value_loss_coef'],
                             alg_opt['entropy_coef'],
                             lr=optim_opt['lr'],
                             eps=optim_opt['eps'],
                             max_grad_norm=optim_opt['max_grad_norm'])
        else:  # 'acktr'
            agent = algo.A2C_ACKTR(actor_critic,
                                   alg_opt['value_loss_coef'],
                                   alg_opt['entropy_coef'],
                                   acktr=True)
        rollouts = RolloutStorage(alg_opt['num_steps'],
                                  alg_opt['num_processes'], obs_shape,
                                  envs.action_space, actor_critic.state_size)
        current_obs = torch.zeros(alg_opt['num_processes'], *obs_shape)

        # Update agent with loaded checkpoint
        if args.resume:
            # This should update both the policy network and the optimizer
            agent.load_state_dict(ckpt['agent'])

            # Set ob_rms
            envs.ob_rms = ckpt['ob_rms']
        elif args.other_resume:
            ckpt = torch.load(args.other_resume)

            # This should update both the policy network
            agent.actor_critic.load_state_dict(ckpt['agent']['model'])

            # Set ob_rms
            envs.ob_rms = ckpt['ob_rms']
        elif args.finetune_baseline:
            # Load the model based on the trial number
            ckpt_base = options['lowlevel']['ckpt']
            ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % args.trial
            ckpt = torch.load(ckpt_file)

            # Make "input mask" that tells the model which inputs were the
            # same from before and should be copied
            oldinput_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()

            # This should update both the policy network
            agent.actor_critic.load_state_dict_special(
                ckpt['agent']['model'], oldinput_mask)

            # Set ob_rms
            old_rms = ckpt['ob_rms']
            old_size = old_rms.mean.size
            if env_opt['add_timestep']:
                old_size -= 1

            # Only copy the pro state part of it
            envs.ob_rms.mean[:old_size] = old_rms.mean[:old_size]
            envs.ob_rms.var[:old_size] = old_rms.var[:old_size]

        # Inline define our helper function for updating obs
        def update_current_obs(obs):
            """Shift the frame stack left and write ``obs`` into the last slot."""
            shape_dim0 = envs.observation_space.shape[0]
            obs = torch.from_numpy(obs).float()
            if env_opt['num_stack'] > 1:
                # clone(): source and destination partially overlap, which
                # recent PyTorch releases reject for in-place copies.
                current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:].clone()
            current_obs[:, -shape_dim0:] = obs

        # Reset our env and rollouts
        obs = envs.reset()
        update_current_obs(obs)
        rollouts.observations[0].copy_(current_obs)
        if args.cuda:
            current_obs = current_obs.cuda()
            rollouts.cuda()

        # These variables are used to compute average rewards for all processes.
        episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
        final_rewards = torch.zeros([alg_opt['num_processes'], 1])

        # Loop-invariant intervals. The configured save interval
        # (log_opt['save_interval'] * alg_opt['log_mult']) is overridden by a
        # hard-coded 100, kept as in the original.
        save_interval = 100
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']

        # Update loop
        start = time.time()
        for j in range(start_update, num_updates):
            for step in range(alg_opt['num_steps']):
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, states = actor_critic.act(
                        rollouts.observations[step], rollouts.states[step],
                        rollouts.masks[step])
                cpu_actions = action.squeeze(1).cpu().numpy()

                # Observe reward and next obs
                obs, reward, done, info = envs.step(cpu_actions)
                reward = torch.from_numpy(
                    np.expand_dims(np.stack(reward), 1)).float()
                episode_rewards += reward

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                final_rewards *= masks
                final_rewards += (1 - masks) * episode_rewards
                episode_rewards *= masks

                if args.cuda:
                    masks = masks.cuda()

                if current_obs.dim() == 4:
                    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                else:
                    current_obs *= masks

                update_current_obs(obs)
                rollouts.insert(current_obs, states, action, action_log_prob,
                                value, reward, masks)

            # Update model and rollouts
            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.observations[-1], rollouts.states[-1],
                    rollouts.masks[-1]).detach()
            rollouts.compute_returns(next_value, alg_opt['use_gae'],
                                     env_opt['gamma'], alg_opt['gae_tau'])
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
            rollouts.after_update()

            # Add algo updates here
            alg_logger.writerow({
                'value_loss': value_loss,
                'action_loss': action_loss,
                'dist_entropy': dist_entropy,
            })
            alg_f.flush()

            # Save checkpoints
            total_num_steps = (j + 1) * steps_per_update
            if j % save_interval == 0:
                # Save all of our important information
                save_checkpoint(logpath,
                                agent,
                                envs,
                                j,
                                total_num_steps,
                                args.save_every,
                                final=False)

            # Print log
            if j % log_interval == 0:
                end = time.time()
                # Count only steps taken in this run, so resumed runs do not
                # report an inflated FPS.
                steps_this_run = (total_num_steps -
                                  start_update * steps_per_update)
                print(
                    "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                    .format(options['logs']['exp_name'], j, total_num_steps,
                            int(steps_this_run / (end - start)),
                            final_rewards.mean(), final_rewards.median(),
                            final_rewards.min(), final_rewards.max(),
                            dist_entropy, value_loss, action_loss))

            # Do dashboard logging
            if args.vis and j % vis_interval == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    dashboard.visdom_plot()
                except IOError:
                    pass

        # Save final checkpoint. ``j``/``total_num_steps`` only exist if at
        # least one update ran (e.g. not when resuming an already-finished run).
        if start_update < num_updates:
            # NOTE(review): ``final=False`` matches the periodic saves above;
            # confirm whether save_checkpoint expects ``final=True`` here.
            save_checkpoint(logpath,
                            agent,
                            envs,
                            j,
                            total_num_steps,
                            args.save_every,
                            final=False)
    finally:
        # Close logging file
        alg_f.close()
예제 #4
0
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Load the lowlevel opt and
    lowlevel_optfile = options['lowlevel']['optfile']
    with open(lowlevel_optfile, 'r') as handle:
        ll_opt = yaml.load(handle)

    # Whether we should set ll policy to be deterministic or not
    ll_deterministic = options['lowlevel']['deterministic']

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    options[
        'lowlevel_opt'] = ll_opt  # Save low level options in option file (for logging purposes)

    # Pass necessary values in ll_opt
    assert (ll_opt['model']['mode'] in ['baseline_lowlevel', 'phase_lowlevel'])
    ll_opt['model']['theta_space_mode'] = ll_opt['env']['theta_space_mode']
    ll_opt['model']['time_scale'] = ll_opt['env']['time_scale']

    # If in many module mode, load the lowlevel policies we want
    if model_opt['mode'] == 'hierarchical_many':
        # Check asserts
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_space_mode in [
            'pretrain_interp', 'pretrain_any', 'pretrain_any_far',
            'pretrain_any_fromstart'
        ])
        assert (theta_obs_mode == 'pretrain')

        # Get the theta size
        theta_sz = options['lowlevel']['num_load']
        ckpt_base = options['lowlevel']['ckpt']

        # Load checkpoints
        lowlevel_ckpts = []
        for ll_ind in range(theta_sz):
            if args.change_ll_offset:
                ll_offset = theta_sz * args.trial
            else:
                ll_offset = 0
            lowlevel_ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % (
                ll_ind + ll_offset)
            assert (os.path.isfile(lowlevel_ckpt_file))
            lowlevel_ckpts.append(torch.load(lowlevel_ckpt_file))

    # Otherwise it's one ll polciy to load
    else:
        # Get theta_sz for low level model
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_obs_mode in ['ind', 'vector'])
        if theta_obs_mode == 'ind':
            if theta_space_mode == 'forward':
                theta_sz = 1
            elif theta_space_mode == 'simple_four':
                theta_sz = 4
            elif theta_space_mode == 'simple_eight':
                theta_sz = 8
            elif theta_space_mode == 'k_theta':
                theta_sz = ll_opt['env']['num_theta']
            elif theta_obs_mode == 'vector':
                theta_sz = 2
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError
        ll_opt['model']['theta_sz'] = theta_sz
        ll_opt['env']['theta_sz'] = theta_sz

        # Load the low level policy params
        lowlevel_ckpt = options['lowlevel']['ckpt']
        assert (os.path.isfile(lowlevel_ckpt))
        lowlevel_ckpt = torch.load(lowlevel_ckpt)
    hl_action_space = spaces.Discrete(theta_sz)

    # Check asserts
    assert (args.algo in ['a2c', 'ppo', 'acktr', 'dqn'])
    assert (optim_opt['hierarchical_mode']
            in ['train_highlevel', 'train_both'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'
                             ], 'Recurrent policy is not implemented for ACKTR'
    assert (model_opt['mode'] in ['hierarchical', 'hierarchical_many'])

    # Set seed - just make the seed the trial number
    seed = args.trial + 1000  # Make it different than lowlevel seed
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    num_updates = int(optim_opt['num_frames']) // alg_opt[
        'num_steps'] // alg_opt['num_processes'] // optim_opt['num_ll_steps']
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if len(args.resume) > 0:
        assert (os.path.isfile(os.path.join(logpath, args.resume)))
        ckpt = torch.load(os.path.join(logpath, 'ckpt.pth.tar'))
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if os.path.isdir(logpath):
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(
                        logpath, default=False)):
                os.system('rm -rf ' + logpath)
            else:
                return
        os.system('mkdir -p ' + logpath)
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
        os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
        os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # Create environments
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose)
        for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Check if we use timestep in low level
    if 'baseline' in ll_opt['model']['mode']:
        add_timestep = False
    elif 'phase' in ll_opt['model']['mode']:
        add_timestep = True
    else:
        raise NotImplementedError

    # Get shapes
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    s_pro_dummy = dummy_env.unwrapped._get_pro_obs()
    s_ext_dummy = dummy_env.unwrapped._get_ext_obs()
    if add_timestep:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz + 1, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0] + 1, )
    else:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0], )
    ll_obs_shape = (ll_obs_shape[0] * env_opt['num_stack'], *ll_obs_shape[1:])
    hl_obs_shape = (s_ext_dummy.shape[0], )
    hl_obs_shape = (hl_obs_shape[0] * env_opt['num_stack'], *hl_obs_shape[1:])

    # Do vec normalize, but mask out what we don't want altered
    # Also freeze all of the low level obs
    ignore_mask = dummy_env.env._get_obs_mask()
    freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
    freeze_mask = np.concatenate([freeze_mask, [0]])
    if ('normalize' in env_opt
            and not env_opt['normalize']) or args.algo == 'dqn':
        ignore_mask = 1 - freeze_mask
    if model_opt['mode'] == 'hierarchical_many':
        # Actually ignore both ignored values and the low level values
        # That filtering will happen later
        ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])
    else:
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])

    # Make our helper object for dealing with hierarchical observations
    hier_utils = HierarchyUtils(ll_obs_shape, hl_obs_shape, hl_action_space,
                                theta_sz, add_timestep)

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()
    ll_alg_filename = os.path.join(logpath, 'AlgLL.Monitor.csv')
    ll_alg_f = open(ll_alg_filename, "wt")
    ll_alg_f.write('# Alg Logging LL %s\n' %
                   json.dumps({
                       "t_start": time.time(),
                       'env_id': dummy_env.spec and dummy_env.spec.id,
                       'mode': options['model']['mode'],
                       'name': options['logs']['exp_name']
                   }))
    ll_alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    ll_alg_logger = csv.DictWriter(ll_alg_f, fieldnames=ll_alg_fields)
    ll_alg_logger.writeheader()
    ll_alg_f.flush()

    # Create the policy networks
    ll_action_space = envs.action_space
    if args.algo == 'dqn':
        model_opt['eps_start'] = optim_opt['eps_start']
        model_opt['eps_end'] = optim_opt['eps_end']
        model_opt['eps_decay'] = optim_opt['eps_decay']
        hl_policy = DQNPolicy(hl_obs_shape, hl_action_space, model_opt)
    else:
        hl_policy = Policy(hl_obs_shape, hl_action_space, model_opt)
    if model_opt['mode'] == 'hierarchical_many':
        ll_policy = ModularPolicy(ll_raw_obs_shape, ll_action_space, theta_sz,
                                  ll_opt)
    else:
        ll_policy = Policy(ll_obs_shape, ll_action_space, ll_opt['model'])
    # Load the previous ones here?
    if args.cuda:
        hl_policy.cuda()
        ll_policy.cuda()

    # Create the high level agent
    if args.algo == 'a2c':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  lr=optim_opt['lr'],
                                  eps=optim_opt['eps'],
                                  alpha=optim_opt['alpha'],
                                  max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        hl_agent = algo.PPO(hl_policy,
                            alg_opt['clip_param'],
                            alg_opt['ppo_epoch'],
                            alg_opt['num_mini_batch'],
                            alg_opt['value_loss_coef'],
                            alg_opt['entropy_coef'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  acktr=True)
    elif args.algo == 'dqn':
        hl_agent = algo.DQN(hl_policy,
                            env_opt['gamma'],
                            batch_size=alg_opt['batch_size'],
                            target_update=alg_opt['target_update'],
                            mem_capacity=alg_opt['mem_capacity'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])

    # Create the low level agent
    # If only training high level, make dummy agent (just does passthrough, doesn't change anything)
    if optim_opt['hierarchical_mode'] == 'train_highlevel':
        ll_agent = algo.Passthrough(ll_policy)
    elif optim_opt['hierarchical_mode'] == 'train_both':
        if args.algo == 'a2c':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      lr=optim_opt['ll_lr'],
                                      eps=optim_opt['eps'],
                                      alpha=optim_opt['alpha'],
                                      max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            ll_agent = algo.PPO(ll_policy,
                                alg_opt['clip_param'],
                                alg_opt['ll_ppo_epoch'],
                                alg_opt['num_mini_batch'],
                                alg_opt['value_loss_coef'],
                                alg_opt['entropy_coef'],
                                lr=optim_opt['ll_lr'],
                                eps=optim_opt['eps'],
                                max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'acktr':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      acktr=True)
    else:
        raise NotImplementedError

    # Make the rollout structures
    hl_rollouts = RolloutStorage(alg_opt['num_steps'],
                                 alg_opt['num_processes'], hl_obs_shape,
                                 hl_action_space, hl_policy.state_size)
    ll_rollouts = MaskingRolloutStorage(alg_opt['num_steps'],
                                        alg_opt['num_processes'], ll_obs_shape,
                                        ll_action_space, ll_policy.state_size)
    hl_current_obs = torch.zeros(alg_opt['num_processes'], *hl_obs_shape)
    ll_current_obs = torch.zeros(alg_opt['num_processes'], *ll_obs_shape)

    # Helper functions to update the current obs
    def update_hl_current_obs(obs):
        """Write the newest high-level observation into ``hl_current_obs`` in place.

        ``obs`` is a numpy array; it is converted to a float tensor and written
        into the trailing ``hl_obs_shape[0]`` columns of ``hl_current_obs``.
        When ``env_opt['num_stack'] > 1`` the existing contents are first shifted
        left by that same width.

        NOTE(review): if ``hl_obs_shape`` already includes the ``num_stack``
        multiplier (as the evaluation script computes it), the frame width here
        equals the whole buffer width and the shift does nothing; confirm before
        relying on ``num_stack > 1``.
        """
        new_frame = torch.from_numpy(obs).float()
        frame_width = hl_obs_shape[0]
        stacking = env_opt['num_stack'] > 1
        if stacking:
            # Slide older frames toward the front to make room for the new one.
            hl_current_obs[:, :-frame_width] = hl_current_obs[:, frame_width:]
        hl_current_obs[:, -frame_width:] = new_frame

    def update_ll_current_obs(obs):
        """Write the newest low-level observation into ``ll_current_obs`` in place.

        ``obs`` is a numpy array; it is converted to a float tensor and written
        into the trailing ``ll_obs_shape[0]`` columns of ``ll_current_obs``.
        When ``env_opt['num_stack'] > 1`` the existing contents are first shifted
        left by that same width.

        NOTE(review): if ``ll_obs_shape`` already includes the ``num_stack``
        multiplier (as the evaluation script computes it), the frame width here
        equals the whole buffer width and the shift does nothing; confirm before
        relying on ``num_stack > 1``.
        """
        new_frame = torch.from_numpy(obs).float()
        frame_width = ll_obs_shape[0]
        stacking = env_opt['num_stack'] > 1
        if stacking:
            # Slide older frames toward the front to make room for the new one.
            ll_current_obs[:, :-frame_width] = ll_current_obs[:, frame_width:]
        ll_current_obs[:, -frame_width:] = new_frame

    # Update agent with loaded checkpoint
    if len(args.resume) > 0:
        # This should update both the policy network and the optimizer
        ll_agent.load_state_dict(ckpt['ll_agent'])
        hl_agent.load_state_dict(ckpt['hl_agent'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    else:
        if model_opt['mode'] == 'hierarchical_many':
            ll_agent.load_pretrained_policies(lowlevel_ckpts)
        else:
            # Load low level agent
            ll_agent.load_state_dict(lowlevel_ckpt['agent'])

            # Load ob_rms from low level (but need to reshape it)
            old_rms = lowlevel_ckpt['ob_rms']
            assert (old_rms.mean.shape[0] == ll_obs_shape[0])
            # Only copy the pro state part of it (not including thetas or count)
            envs.ob_rms.mean[:s_pro_dummy.
                             shape[0]] = old_rms.mean[:s_pro_dummy.shape[0]]
            envs.ob_rms.var[:s_pro_dummy.shape[0]] = old_rms.var[:s_pro_dummy.
                                                                 shape[0]]

    # Reset our env and rollouts
    raw_obs = envs.reset()
    hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
    ll_obs = hier_utils.placeholder_theta(raw_ll_obs, step_counts)
    update_hl_current_obs(hl_obs)
    update_ll_current_obs(ll_obs)
    hl_rollouts.observations[0].copy_(hl_current_obs)
    ll_rollouts.observations[0].copy_(ll_current_obs)
    ll_rollouts.recent_obs.copy_(ll_current_obs)
    if args.cuda:
        hl_current_obs = hl_current_obs.cuda()
        ll_current_obs = ll_current_obs.cuda()
        hl_rollouts.cuda()
        ll_rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
    final_rewards = torch.zeros([alg_opt['num_processes'], 1])

    # Update loop
    start = time.time()
    for j in range(start_update, num_updates):
        for step in range(alg_opt['num_steps']):
            # Step through high level action
            start_time = time.time()
            with torch.no_grad():
                hl_value, hl_action, hl_action_log_prob, hl_states = hl_policy.act(
                    hl_rollouts.observations[step], hl_rollouts.states[step],
                    hl_rollouts.masks[step])
            hl_cpu_actions = hl_action.squeeze(1).cpu().numpy()
            if args.profile:
                print('hl act %f' % (time.time() - start_time))

            # Get values to use for Q learning
            hl_state_dqn = hl_rollouts.observations[step]
            hl_action_dqn = hl_action

            # Update last ll observation with new theta
            for proc in range(alg_opt['num_processes']):
                # Update last observations in memory
                last_obs = ll_rollouts.observations[ll_rollouts.steps[proc],
                                                    proc]
                if hier_utils.has_placeholder(last_obs):
                    new_last_obs = hier_utils.update_theta(
                        last_obs, hl_cpu_actions[proc])
                    ll_rollouts.observations[ll_rollouts.steps[proc],
                                             proc].copy_(new_last_obs)

                # Update most recent observations (not necessarily the same)
                assert (hier_utils.has_placeholder(
                    ll_rollouts.recent_obs[proc]))
                new_last_obs = hier_utils.update_theta(
                    ll_rollouts.recent_obs[proc], hl_cpu_actions[proc])
                ll_rollouts.recent_obs[proc].copy_(new_last_obs)
            assert (ll_rollouts.observations.max().item() < float('inf')
                    and ll_rollouts.recent_obs.max().item() < float('inf'))

            # Given high level action, step through the low level actions
            death_step_mask = np.ones([alg_opt['num_processes'],
                                       1])  # 1 means still alive, 0 means dead
            hl_reward = torch.zeros([alg_opt['num_processes'], 1])
            hl_obs = [None for i in range(alg_opt['num_processes'])]
            for ll_step in range(optim_opt['num_ll_steps']):
                # Sample actions
                start_time = time.time()
                with torch.no_grad():
                    ll_value, ll_action, ll_action_log_prob, ll_states = ll_policy.act(
                        ll_rollouts.recent_obs,
                        ll_rollouts.recent_s,
                        ll_rollouts.recent_masks,
                        deterministic=ll_deterministic)
                ll_cpu_actions = ll_action.squeeze(1).cpu().numpy()
                if args.profile:
                    print('ll act %f' % (time.time() - start_time))

                # Observe reward and next obs
                raw_obs, ll_reward, done, info = envs.step(
                    ll_cpu_actions, death_step_mask)
                raw_hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(
                    raw_obs)
                ll_obs = []
                for proc in range(alg_opt['num_processes']):
                    if (ll_step
                            == optim_opt['num_ll_steps'] - 1) or done[proc]:
                        ll_obs.append(
                            hier_utils.placeholder_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([step_counts[proc]])))
                    else:
                        ll_obs.append(
                            hier_utils.append_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([hl_cpu_actions[proc]]),
                                np.array([step_counts[proc]])))
                ll_obs = np.concatenate(ll_obs, 0)
                ll_reward = torch.from_numpy(
                    np.expand_dims(np.stack(ll_reward), 1)).float()
                episode_rewards += ll_reward
                hl_reward += ll_reward

                # Update values for Q learning and update replay memory
                time.time()
                hl_next_state_dqn = torch.from_numpy(raw_hl_obs)
                hl_reward_dqn = ll_reward
                hl_isdone_dqn = done
                if args.algo == 'dqn':
                    hl_agent.update_memory(hl_state_dqn, hl_action_dqn,
                                           hl_next_state_dqn, hl_reward_dqn,
                                           hl_isdone_dqn, death_step_mask)
                hl_state_dqn = hl_next_state_dqn
                if args.profile:
                    print('dqn memory %f' % (time.time() - start_time))

                # Update high level observations (only take most recent obs if we haven't see a done before now and thus the value is valid)
                for proc, raw_hl in enumerate(raw_hl_obs):
                    if death_step_mask[proc].item() > 0:
                        hl_obs[proc] = np.array([raw_hl])

                # If done then clean the history of observations
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                final_rewards *= masks
                final_rewards += (
                    1 - masks
                ) * episode_rewards  # TODO - actually not sure if I broke this logic, but this value is not used anywhere
                episode_rewards *= masks

                # TODO - I commented this out, which possibly breaks things if num_stack > 1. Fix later if necessary
                #if args.cuda:
                #    masks = masks.cuda()
                #if current_obs.dim() == 4:
                #    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                #else:
                #    current_obs *= masks

                # Update low level observations
                update_ll_current_obs(ll_obs)

                # Update low level rollouts
                ll_rollouts.insert(ll_current_obs, ll_states, ll_action,
                                   ll_action_log_prob, ll_value, ll_reward,
                                   masks, death_step_mask)

                # Update which ones have stepped to the end and shouldn't be updated next time in the loop
                death_step_mask *= masks

            # Update high level rollouts
            hl_obs = np.concatenate(hl_obs, 0)
            update_hl_current_obs(hl_obs)
            hl_rollouts.insert(hl_current_obs, hl_states, hl_action,
                               hl_action_log_prob, hl_value, hl_reward, masks)

            # Check if we want to update lowlevel policy
            if ll_rollouts.isfull and all([
                    not hier_utils.has_placeholder(
                        ll_rollouts.observations[ll_rollouts.steps[proc],
                                                 proc])
                    for proc in range(alg_opt['num_processes'])
            ]):
                # Update low level policy
                assert (ll_rollouts.observations.max().item() < float('inf'))
                if optim_opt['hierarchical_mode'] == 'train_both':
                    with torch.no_grad():
                        ll_next_value = ll_policy.get_value(
                            ll_rollouts.observations[-1],
                            ll_rollouts.states[-1],
                            ll_rollouts.masks[-1]).detach()
                    ll_rollouts.compute_returns(ll_next_value,
                                                alg_opt['use_gae'],
                                                env_opt['gamma'],
                                                alg_opt['gae_tau'])
                    ll_value_loss, ll_action_loss, ll_dist_entropy = ll_agent.update(
                        ll_rollouts)
                else:
                    ll_value_loss = 0
                    ll_action_loss = 0
                    ll_dist_entropy = 0
                ll_rollouts.after_update()

                # Update logger
                alg_info = {}
                alg_info['value_loss'] = ll_value_loss
                alg_info['action_loss'] = ll_action_loss
                alg_info['dist_entropy'] = ll_dist_entropy
                ll_alg_logger.writerow(alg_info)
                ll_alg_f.flush()

        # Update high level policy
        start_time = time.time()
        assert (hl_rollouts.observations.max().item() < float('inf'))
        if args.algo == 'dqn':
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                alg_opt['updates_per_step']
            )  # TODO - maybe log this loss properly
        else:
            with torch.no_grad():
                hl_next_value = hl_policy.get_value(
                    hl_rollouts.observations[-1], hl_rollouts.states[-1],
                    hl_rollouts.masks[-1]).detach()
            hl_rollouts.compute_returns(hl_next_value, alg_opt['use_gae'],
                                        env_opt['gamma'], alg_opt['gae_tau'])
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                hl_rollouts)
        hl_rollouts.after_update()
        if args.profile:
            print('hl update %f' % (time.time() - start_time))

        # Update alg monitor for high level
        alg_info = {}
        alg_info['value_loss'] = hl_value_loss
        alg_info['action_loss'] = hl_action_loss
        alg_info['dist_entropy'] = hl_dist_entropy
        alg_logger.writerow(alg_info)
        alg_f.flush()

        # Save checkpoints
        total_num_steps = (j + 1) * alg_opt['num_processes'] * alg_opt[
            'num_steps'] * optim_opt['num_ll_steps']
        if 'save_interval' in alg_opt:
            save_interval = alg_opt['save_interval']
        else:
            save_interval = 100
        if j % save_interval == 0:
            # Save all of our important information
            start_time = time.time()
            save_checkpoint(logpath, ll_agent, hl_agent, envs, j,
                            total_num_steps)
            if args.profile:
                print('save checkpoint %f' % (time.time() - start_time))

        # Print log
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        if j % log_interval == 0:
            end = time.time()
            print(
                "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(options['logs']['exp_name'], j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(),
                        hl_dist_entropy, hl_value_loss, hl_action_loss))

        # Do dashboard logging
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']
        if args.vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                dashboard.visdom_plot()
            except IOError:
                pass

    # Save final checkpoint
    save_checkpoint(logpath, ll_agent, hl_agent, envs, j, total_num_steps)

    # Close logging file
    alg_f.close()
    ll_alg_f.close()
예제 #5
0
                     'eval/CMU_069_02_start_step_0.npy'))
    reset_step = len(ref_actions) - 16
    env = RefTrackingEnv(
        clip_name='CMU_069_02',
        ref_actions=ref_actions,
        start_step=0,
        reset_step=reset_step,
    )

    # Check set_attr for reset_step is correct
    make_env_fn = lambda: RefTrackingEnv(
        'CMU_069_02', ref_actions, 0, reset_step=len(ref_actions) - 2)
    vec_env = SubprocVecEnv([make_env_fn for _ in range(2)])
    # Change the reset step and make sure we don't get an epsiode termination until expected
    vec_env.set_attr('reset_step', reset_step)
    vec_env.reset()
    for idx, act in enumerate(ref_actions[reset_step:]):
        acts = np.tile(act, (2, 1))
        obs, rew, done, info = vec_env.step(acts)
        # Ensure we're not done until the last step
        if idx < len(ref_actions[reset_step:]) - 1:
            assert not np.any(done)
    # Make sure we're done on the last step
    assert np.all(done)
    vec_env.close()

    # Check for correctness
    for _ in range(2):
        obs = env.reset()
        # Check that the reset is working properly and we get the expected sequences
        # of observations and rewards from executing the reference actions.
예제 #6
0
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle)
    print('## args'); pprint(vars(args))
    print('## options'); pprint(options)

    # Load the lowlevel opt and 
    lowlevel_optfile = options['lowlevel']['optfile']
    with open(lowlevel_optfile, 'r') as handle:
        ll_opt = yaml.load(handle)

    # Whether we should set ll policy to be deterministic or not
    ll_deterministic = options['lowlevel']['deterministic']

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = 0
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim'] 
    options['lowlevel_opt'] = ll_opt    # Save low level options in option file (for logging purposes)

    # Pass necessary values in ll_opt
    assert(ll_opt['model']['mode'] in ['baseline_lowlevel', 'phase_lowlevel'])
    ll_opt['model']['theta_space_mode'] = ll_opt['env']['theta_space_mode']
    ll_opt['model']['time_scale'] = ll_opt['env']['time_scale']

    # If in many module mode, load the lowlevel policies we want
    if model_opt['mode'] == 'hierarchical_many':
        # Check asserts
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert(theta_space_mode in ['pretrain_interp', 'pretrain_any', 'pretrain_any_far', 'pretrain_any_fromstart'])
        assert(theta_obs_mode == 'pretrain')

        # Get the theta size
        theta_sz = options['lowlevel']['num_load']
        ckpt_base = options['lowlevel']['ckpt']

        # Load checkpoints
        #lowlevel_ckpts = []
        #for ll_ind in range(theta_sz):
        #    lowlevel_ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % ll_ind
        #    assert(os.path.isfile(lowlevel_ckpt_file))
        #    lowlevel_ckpts.append(torch.load(lowlevel_ckpt_file))

    # Otherwise it's one ll polciy to load
    else:
        # Get theta_sz for low level model
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert(theta_obs_mode in ['ind', 'vector'])
        if theta_obs_mode == 'ind': 
            if theta_space_mode == 'forward':
                theta_sz = 1
            elif theta_space_mode == 'simple_four':
                theta_sz = 4
            elif theta_space_mode == 'simple_eight':
                theta_sz = 8
            elif theta_space_mode == 'k_theta':
                theta_sz = ll_opt['env']['num_theta']
            elif theta_obs_mode == 'vector':
                theta_sz = 2
            else:
                raise NotImplementedError            
        else:
            raise NotImplementedError
        ll_opt['model']['theta_sz'] = theta_sz
        ll_opt['env']['theta_sz'] = theta_sz

        # Load the low level policy params
        #lowlevel_ckpt = options['lowlevel']['ckpt']            
        #assert(os.path.isfile(lowlevel_ckpt))
        #lowlevel_ckpt = torch.load(lowlevel_ckpt)
    hl_action_space = spaces.Discrete(theta_sz)

    # Check asserts
    assert(args.algo in ['a2c', 'ppo', 'acktr', 'dqn'])
    assert(optim_opt['hierarchical_mode'] in ['train_highlevel', 'train_both'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'], 'Recurrent policy is not implemented for ACKTR'
    assert(model_opt['mode'] in ['hierarchical', 'hierarchical_many'])

    # Set seed - just make the seed the trial number
    seed = args.seed + 1000    # Make it different than lowlevel seed
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    # Set logging / load previous checkpoint
    logpath = args.logdir

    # Make directory, check before overwriting
    assert not os.path.isdir(logpath), "Give a new directory to save so we don't overwrite anything"
    os.system('mkdir -p ' + logpath)

    # Load checkpoint
    assert(os.path.isfile(args.ckpt))
    if args.cuda:
        ckpt = torch.load(args.ckpt)
    else:
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

    # Save options and args
    with open(os.path.join(logpath, os.path.basename(args.path_opt)), 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    # Save git info as well
    os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
    os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
    os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))
    
    # Set up plotting dashboard
    dashboard = Dashboard(options, vis_options, logpath, vis=args.vis, port=args.port)

    # Create environments
    envs = [make_env(args.env_name, seed, i, logpath, options, not args.no_verbose) for i in range(1)]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Check if we use timestep in low level
    if 'baseline' in ll_opt['model']['mode']:
        add_timestep = False
    elif 'phase' in ll_opt['model']['mode']:
        add_timestep = True
    else:
        raise NotImplementedError

    # Get shapes
    dummy_env = make_env(args.env_name, seed, 0, logpath, options, not args.no_verbose) 
    dummy_env = dummy_env()
    s_pro_dummy = dummy_env.unwrapped._get_pro_obs()
    s_ext_dummy = dummy_env.unwrapped._get_ext_obs()
    if add_timestep:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz + 1,)
        ll_raw_obs_shape =(s_pro_dummy.shape[0] + 1,)
    else:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz,)
        ll_raw_obs_shape = (s_pro_dummy.shape[0],)
    ll_obs_shape = (ll_obs_shape[0] * env_opt['num_stack'], *ll_obs_shape[1:])
    hl_obs_shape = (s_ext_dummy.shape[0],)
    hl_obs_shape = (hl_obs_shape[0] * env_opt['num_stack'], *hl_obs_shape[1:])
   
    # Do vec normalize, but mask out what we don't want altered
    # Also freeze all of the low level obs
    ignore_mask = dummy_env.env._get_obs_mask()
    freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
    freeze_mask = np.concatenate([freeze_mask, [0]])
    if ('normalize' in env_opt and not env_opt['normalize']) or args.algo == 'dqn':
        ignore_mask = 1 - freeze_mask
    if model_opt['mode'] == 'hierarchical_many':
        # Actually ignore both ignored values and the low level values
        # That filtering will happen later
        ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
        envs = ObservationFilter(envs, ret=alg_opt['norm_ret'], has_timestep=True, noclip=env_opt['step_plus_noclip'], ignore_mask=ignore_mask, freeze_mask=freeze_mask, time_scale=env_opt['time_scale'], gamma=env_opt['gamma'], train=False)
    else:
        envs = ObservationFilter(envs, ret=alg_opt['norm_ret'], has_timestep=True, noclip=env_opt['step_plus_noclip'], ignore_mask=ignore_mask, freeze_mask=freeze_mask, time_scale=env_opt['time_scale'], gamma=env_opt['gamma'], train=False)
    raw_env = envs.venv.envs[0]

    # Make our helper object for dealing with hierarchical observations
    hier_utils = HierarchyUtils(ll_obs_shape, hl_obs_shape, hl_action_space, theta_sz, add_timestep)

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n'%json.dumps({"t_start": time.time(), 'env_id' : dummy_env.spec and dummy_env.spec.id, 'mode': options['model']['mode'], 'name': options['logs']['exp_name']}))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()
    ll_alg_filename = os.path.join(logpath, 'AlgLL.Monitor.csv')
    ll_alg_f = open(ll_alg_filename, "wt")
    ll_alg_f.write('# Alg Logging LL %s\n'%json.dumps({"t_start": time.time(), 'env_id' : dummy_env.spec and dummy_env.spec.id, 'mode': options['model']['mode'], 'name': options['logs']['exp_name']}))
    ll_alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    ll_alg_logger = csv.DictWriter(ll_alg_f, fieldnames=ll_alg_fields)
    ll_alg_logger.writeheader()
    ll_alg_f.flush()

    # Create the policy networks
    ll_action_space = envs.action_space 
    if args.algo == 'dqn':
        model_opt['eps_start'] = optim_opt['eps_start']
        model_opt['eps_end'] = optim_opt['eps_end']
        model_opt['eps_decay'] = optim_opt['eps_decay']
        hl_policy = DQNPolicy(hl_obs_shape, hl_action_space, model_opt)
    else:
        hl_policy = Policy(hl_obs_shape, hl_action_space, model_opt)
    if model_opt['mode'] == 'hierarchical_many':
        ll_policy = ModularPolicy(ll_raw_obs_shape, ll_action_space, theta_sz, ll_opt)
    else:
        ll_policy = Policy(ll_obs_shape, ll_action_space, ll_opt['model'])
    # Load the previous ones here?
    if args.cuda:
        hl_policy.cuda()
        ll_policy.cuda()

    # Create the high level agent
    if args.algo == 'a2c':
        hl_agent = algo.A2C_ACKTR(hl_policy, alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'], lr=optim_opt['lr'],
                               eps=optim_opt['eps'], alpha=optim_opt['alpha'],
                               max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        hl_agent = algo.PPO(hl_policy, alg_opt['clip_param'], alg_opt['ppo_epoch'], alg_opt['num_mini_batch'],
                         alg_opt['value_loss_coef'], alg_opt['entropy_coef'], lr=optim_opt['lr'],
                         eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        hl_agent = algo.A2C_ACKTR(hl_policy, alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'], acktr=True)
    elif args.algo == 'dqn':
        hl_agent = algo.DQN(hl_policy, env_opt['gamma'], batch_size=alg_opt['batch_size'], target_update=alg_opt['target_update'], 
                        mem_capacity=alg_opt['mem_capacity'], lr=optim_opt['lr'], eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])

    # Create the low level agent
    # If only training high level, make dummy agent (just does passthrough, doesn't change anything)
    if optim_opt['hierarchical_mode'] == 'train_highlevel':
        ll_agent = algo.Passthrough(ll_policy)
    elif optim_opt['hierarchical_mode'] == 'train_both':
        if args.algo == 'a2c':
            ll_agent = algo.A2C_ACKTR(ll_policy, alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'], lr=optim_opt['ll_lr'],
                                      eps=optim_opt['eps'], alpha=optim_opt['alpha'],
                                      max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            ll_agent = algo.PPO(ll_policy, alg_opt['clip_param'], alg_opt['ppo_epoch'], alg_opt['num_mini_batch'],
                                alg_opt['value_loss_coef'], alg_opt['entropy_coef'], lr=optim_opt['ll_lr'],
                                eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'acktr':
            ll_agent = algo.A2C_ACKTR(ll_policy, alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'], acktr=True)
    else:
        raise NotImplementedError

    # Make the rollout structures 
    # Kind of dumb hack to avoid having to deal with rollouts
    hl_rollouts = RolloutStorage(10000*args.num_ep, 1, hl_obs_shape, hl_action_space, hl_policy.state_size)
    ll_rollouts = MaskingRolloutStorage(alg_opt['num_steps'], 1, ll_obs_shape, ll_action_space, ll_policy.state_size)
    hl_current_obs = torch.zeros(1, *hl_obs_shape)
    ll_current_obs = torch.zeros(1, *ll_obs_shape)

    # Helper functions to update the current obs
    def update_hl_current_obs(obs):
        """Push the newest high-level observation into the ``hl_current_obs`` frame stack.

        ``obs`` is a numpy array holding one frame per process; it is converted
        to a float tensor and written into the last frame slot of
        ``hl_current_obs``. When ``env_opt['num_stack'] > 1`` the older frames
        are first shifted one slot toward the front.
        """
        num_stack = env_opt['num_stack']
        # hl_obs_shape[0] was multiplied by num_stack when the shapes were built,
        # so divide it back out to get the width of a single frame. Using the
        # stacked width made the shift below a no-op and the final write a shape
        # mismatch whenever num_stack > 1.
        shape_dim0 = hl_obs_shape[0] // num_stack
        obs = torch.from_numpy(obs).float()
        if num_stack > 1:
            # Source and destination overlap in memory; clone the source so the
            # copy is well-defined (newer PyTorch rejects partial overlap).
            hl_current_obs[:, :-shape_dim0] = hl_current_obs[:, shape_dim0:].clone()
        hl_current_obs[:, -shape_dim0:] = obs
    def update_ll_current_obs(obs):
        """Push the newest low-level observation into the ``ll_current_obs`` frame stack.

        ``obs`` is a numpy array holding one frame per process; it is converted
        to a float tensor and written into the last frame slot of
        ``ll_current_obs``. When ``env_opt['num_stack'] > 1`` the older frames
        are first shifted one slot toward the front.
        """
        num_stack = env_opt['num_stack']
        # ll_obs_shape[0] was multiplied by num_stack when the shapes were built,
        # so divide it back out to get the width of a single frame. Using the
        # stacked width made the shift below a no-op and the final write a shape
        # mismatch whenever num_stack > 1.
        shape_dim0 = ll_obs_shape[0] // num_stack
        obs = torch.from_numpy(obs).float()
        if num_stack > 1:
            # Source and destination overlap in memory; clone the source so the
            # copy is well-defined (newer PyTorch rejects partial overlap).
            ll_current_obs[:, :-shape_dim0] = ll_current_obs[:, shape_dim0:].clone()
        ll_current_obs[:, -shape_dim0:] = obs

    # Update agent with loaded checkpoint
    # This should update both the policy network and the optimizer
    ll_agent.load_state_dict(ckpt['ll_agent'])
    hl_agent.load_state_dict(ckpt['hl_agent'])

    # Set ob_rms
    envs.ob_rms = ckpt['ob_rms']
    
    # Reset our env and rollouts
    raw_obs = envs.reset()
    hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
    ll_obs = hier_utils.placeholder_theta(raw_ll_obs, step_counts)
    update_hl_current_obs(hl_obs)
    update_ll_current_obs(ll_obs)
    hl_rollouts.observations[0].copy_(hl_current_obs)
    ll_rollouts.observations[0].copy_(ll_current_obs)
    ll_rollouts.recent_obs.copy_(ll_current_obs)
    if args.cuda:
        hl_current_obs = hl_current_obs.cuda()
        ll_current_obs = ll_current_obs.cuda()
        hl_rollouts.cuda()
        ll_rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = []
    tabbed = False
    raw_data = []

    # Loop through episodes
    step = 0
    for ep in range(args.num_ep):
        if ep < args.num_vid:
            record = True
        else:
            record = False

        # Complete episode
        done = False
        frames = []
        ep_total_reward = 0
        num_steps = 0
        while not done:
            # Step through high level action
            start_time = time.time()
            with torch.no_grad():
                hl_value, hl_action, hl_action_log_prob, hl_states = hl_policy.act(hl_rollouts.observations[step], hl_rollouts.states[step], hl_rollouts.masks[step], deterministic=True)
            step += 1
            hl_cpu_actions = hl_action.squeeze(1).cpu().numpy()
 
            # Get values to use for Q learning
            hl_state_dqn = hl_rollouts.observations[step]
            hl_action_dqn = hl_action

            # Update last ll observation with new theta
            for proc in range(1):
                # Update last observations in memory
                last_obs = ll_rollouts.observations[ll_rollouts.steps[proc], proc]
                if hier_utils.has_placeholder(last_obs):         
                    new_last_obs = hier_utils.update_theta(last_obs, hl_cpu_actions[proc])
                    ll_rollouts.observations[ll_rollouts.steps[proc], proc].copy_(new_last_obs)

                # Update most recent observations (not necessarily the same)
                assert(hier_utils.has_placeholder(ll_rollouts.recent_obs[proc]))
                new_last_obs = hier_utils.update_theta(ll_rollouts.recent_obs[proc], hl_cpu_actions[proc])
                ll_rollouts.recent_obs[proc].copy_(new_last_obs)
            assert(ll_rollouts.observations.max().item() < float('inf') and ll_rollouts.recent_obs.max().item() < float('inf'))

            # Given high level action, step through the low level actions
            death_step_mask = np.ones([1, 1])    # 1 means still alive, 0 means dead  
            hl_reward = torch.zeros([1, 1])
            hl_obs = [None for i in range(1)]
            for ll_step in range(optim_opt['num_ll_steps']):
                num_steps += 1
                # Capture screenshot
                if record:
                    raw_env.render()
                    if not tabbed:
                        # GLFW TAB and RELEASE are hardcoded here
                        raw_env.unwrapped.viewer.cam.distance += 5
                        raw_env.unwrapped.viewer.cam.lookat[0] += 2.5
                        #raw_env.unwrapped.viewer.cam.lookat[1] += 2.5
                        raw_env.render()
                        tabbed = True
                    frames.append(raw_env.unwrapped.viewer._read_pixels_as_in_window())

                # Sample actions
                with torch.no_grad():
                    ll_value, ll_action, ll_action_log_prob, ll_states = ll_policy.act(ll_rollouts.recent_obs, ll_rollouts.recent_s, ll_rollouts.recent_masks, deterministic=True)
                ll_cpu_actions = ll_action.squeeze(1).cpu().numpy()

                # Observe reward and next obs
                raw_obs, ll_reward, done, info = envs.step(ll_cpu_actions, death_step_mask)
                raw_hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
                ll_obs = []
                for proc in range(alg_opt['num_processes']):
                    if (ll_step == optim_opt['num_ll_steps'] - 1) or done[proc]:
                        ll_obs.append(hier_utils.placeholder_theta(np.array([raw_ll_obs[proc]]), np.array([step_counts[proc]])))
                    else:
                        ll_obs.append(hier_utils.append_theta(np.array([raw_ll_obs[proc]]), np.array([hl_cpu_actions[proc]]), np.array([step_counts[proc]])))
                ll_obs = np.concatenate(ll_obs, 0) 
                ll_reward = torch.from_numpy(np.expand_dims(np.stack(ll_reward), 1)).float()
                hl_reward += ll_reward
                ep_total_reward += ll_reward.item()


                # Update high level observations (only take most recent obs if we haven't see a done before now and thus the value is valid)
                for proc, raw_hl in enumerate(raw_hl_obs):
                    if death_step_mask[proc].item() > 0:
                        hl_obs[proc] = np.array([raw_hl])

                # If done then clean the history of observations
                masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
                #final_rewards *= masks
                #final_rewards += (1 - masks) * episode_rewards  # TODO - actually not sure if I broke this logic, but this value is not used anywhere
                #episode_rewards *= masks

                # Done is actually a bool for eval since it's just one process
                done = done.item()
                if not done:
                    last_hl_obs = np.array(hl_obs)

                # TODO - I commented this out, which possibly breaks things if num_stack > 1. Fix later if necessary
                #if args.cuda:
                #    masks = masks.cuda()
                #if current_obs.dim() == 4:
                #    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                #else:
                #    current_obs *= masks

                # Update low level observations
                update_ll_current_obs(ll_obs)

                # Update low level rollouts
                ll_rollouts.insert(ll_current_obs, ll_states, ll_action, ll_action_log_prob, ll_value, ll_reward, masks, death_step_mask) 

                # Update which ones have stepped to the end and shouldn't be updated next time in the loop
                death_step_mask *= masks    

            # Update high level rollouts
            hl_obs = np.concatenate(hl_obs, 0) 
            update_hl_current_obs(hl_obs) 
            hl_rollouts.insert(hl_current_obs, hl_states, hl_action, hl_action_log_prob, hl_value, hl_reward, masks)

            # Check if we want to update lowlevel policy
            if ll_rollouts.isfull and all([not hier_utils.has_placeholder(ll_rollouts.observations[ll_rollouts.steps[proc], proc]) for proc in range(alg_opt['num_processes'])]): 
                # Update low level policy
                assert(ll_rollouts.observations.max().item() < float('inf')) 
                ll_value_loss = 0
                ll_action_loss = 0
                ll_dist_entropy = 0
                ll_rollouts.after_update()

                # Update logger
                #alg_info = {}
                #alg_info['value_loss'] = ll_value_loss
                #alg_info['action_loss'] = ll_action_loss
                #alg_info['dist_entropy'] = ll_dist_entropy
                #ll_alg_logger.writerow(alg_info)
                #ll_alg_f.flush() 

        # Update alg monitor for high level
        #alg_info = {}
        #alg_info['value_loss'] = hl_value_loss
        #alg_info['action_loss'] = hl_action_loss
        #alg_info['dist_entropy'] = hl_dist_entropy
        #alg_logger.writerow(alg_info)
        #alg_f.flush() 

        # Save video
        if record:
            for fr_ind, fr in enumerate(frames):
                scipy.misc.imsave(os.path.join(logpath, 'tmp_fr_%d.jpg' % fr_ind), fr)
            os.system("ffmpeg -r 20 -i %s/" % logpath + "tmp_fr_%01d.jpg -y " + "%s/results_ep%d.mp4" % (logpath, ep))
            os.system("rm %s/tmp_fr*.jpg" % logpath)
            
        # Do dashboard logging for each epsiode
        try:
            dashboard.visdom_plot()
        except IOError:
            pass

        # Print / dump reward for episode
        # DEBUG for thetas
        #print("Theta %d" % env.venv.envs[0].env.env.theta)
        print("Total reward for episode %d: %f" % (ep, ep_total_reward))
        print("Episode length: %d" % num_steps)
        print(last_hl_obs)
        print("----------")
        episode_rewards.append(ep_total_reward)

    # Close logging file
    alg_f.close()
    ll_alg_f.close()