Example #1
def init_agent(actor_critic, args):
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR_NOREWARD(actor_critic,
                                        args.value_loss_coef,
                                        args.entropy_coef,
                                        lr=args.lr,
                                        eps=args.eps,
                                        alpha=args.alpha,
                                        max_grad_norm=args.max_grad_norm,
                                        curiosity=args.curiosity,
                                        args=args)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR_NOREWARD(actor_critic,
                                        args.value_loss_coef,
                                        args.entropy_coef,
                                        lr=args.lr,
                                        eps=args.eps,
                                        alpha=args.alpha,
                                        max_grad_norm=args.max_grad_norm,
                                        acktr=True,
                                        curiosity=args.curiosity,
                                        args=args)
    return agent
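
A minimal standalone sketch of the same algorithm dispatch, written as a dict of constructors instead of an if/elif chain. The Stub* classes below are hypothetical stand-ins for the repo's algo.A2C_ACKTR_NOREWARD and algo.PPO so the snippet runs on its own; only the selection pattern is the point.

from argparse import Namespace

class StubA2C:
    """Hypothetical stand-in for algo.A2C_ACKTR_NOREWARD."""
    def __init__(self, actor_critic, value_loss_coef, entropy_coef, **kwargs):
        self.actor_critic = actor_critic
        self.kwargs = kwargs

class StubPPO(StubA2C):
    """Hypothetical stand-in for algo.PPO."""
    pass

def init_agent_sketch(actor_critic, args):
    # Each algo name maps to a zero-argument constructor; 'acktr' reuses the A2C class.
    builders = {
        'a2c': lambda: StubA2C(actor_critic, args.value_loss_coef, args.entropy_coef,
                               lr=args.lr, eps=args.eps, alpha=args.alpha),
        'ppo': lambda: StubPPO(actor_critic, args.value_loss_coef, args.entropy_coef,
                               lr=args.lr, eps=args.eps),
        'acktr': lambda: StubA2C(actor_critic, args.value_loss_coef, args.entropy_coef,
                                 acktr=True),
    }
    if args.algo not in builders:
        raise ValueError('Unknown algo: {}'.format(args.algo))
    return builders[args.algo]()

demo_args = Namespace(algo='ppo', value_loss_coef=0.5, entropy_coef=0.01,
                      lr=7e-4, eps=1e-5, alpha=0.99)
print(type(init_agent_sketch(actor_critic=None, args=demo_args)).__name__)  # StubPPO
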
Example #2
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:1" if args.cuda else "cpu")

    ## Experiment UID used to name the log file
    UID = 'exp_{}'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    step_log = []
    reward_log = []

    ## To be used to select the environment
    mode = 'normal'

    # encoder type
    encoder = 'sym_VAE'
    if encoder == 'symbolic':
        embedding_size = (18, )
    elif encoder == 'AE':
        embedding_size = (200, )
    elif encoder == 'VAE':
        embedding_size = (100, )
    elif encoder == 'sym_VAE':
        embedding_size = (118, )
    else:
        raise NotImplementedError('Unknown encoder: {}'.format(encoder))

    # load pre-trained AE
    #AE = VAEU([128,128])
    #model_path = '/hdd_c/data/miniWorld/trained_models/VAE/dataset_4/VAEU.pth'
    #AE = torch.load(model_path)
    #AE.eval()

    # load pre-trained VAE
    VAE = VAER([128, 128])
    model_path = '/hdd_c/data/miniWorld/trained_models/VAE/dataset_5/VAER.pth'
    VAE = torch.load(model_path).to(device)
    VAE.eval()

    # load pre-trained detector
    Detector_model = Detector
    model_path = '/hdd_c/data/miniWorld/trained_models/Detector/dataset_5/Detector_resnet18_e14.pth'
    Detector_model = torch.load(model_path).to(device)

    # load pre-trained RNN
    RNN_model = RNN(200, 128)
    model_path = '/hdd_c/data/miniWorld/trained_models/RNN/RNN1.pth'
    RNN_model = torch.load(model_path).to(device)
    RNN_model.eval()
    """
    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
    """

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    print(envs.observation_space.shape)

    #actor_critic = Policy(envs.observation_space.shape, envs.action_space,
    #    base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic = Policy(embedding_size,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    #rollouts = RolloutStorage(args.num_steps, args.num_processes,
    #                    envs.observation_space.shape, envs.action_space,
    #                    actor_critic.recurrent_hidden_state_size)
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              embedding_size, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    #print(obs.size())
    #obs = make_var(obs)
    print(obs.size())
    with torch.no_grad():
        if encoder == 'symbolic':

            z = Detector_model(obs)
            print(z.size())
            z = Detector_to_symbolic(z)
            rollouts.obs[0].copy_(z)
        elif encoder == 'AE':
            z = AE.encode(obs)
            rollouts.obs[0].copy_(z)
        elif encoder == 'VAE':
            z = VAE.encode(obs)[0]
            rollouts.obs[0].copy_(z)
        elif encoder == 'sym_VAE':
            z_vae = VAE.encode(obs)[0]
            z_sym = Detector_model(obs)
            z_sym = Detector_to_symbolic(z_sym)
            z = torch.cat((z_vae, z_sym), dim=1)
            rollouts.obs[0].copy_(z)
        else:
            raise NotImplementedError('Unknown encoder: {}'.format(encoder))

    #rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=100)

    start = time.time()
    for j in range(num_updates):
        #print(j)
        for step in range(args.num_steps):
            # Sample actions
            #print(step)
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            #print(action)
            with torch.no_grad():
                obs, reward, done, infos = envs.step(action)
                if encoder == 'symbolic':
                    #print(obs.size())
                    np.save(
                        '/hdd_c/data/miniWorld/training_obs_{}.npy'.format(
                            step),
                        obs.detach().cpu().numpy())
                    z = Detector_model(obs / 255.0)
                    z = Detector_to_symbolic(z)
                    #print(z)
                    np.save(
                        '/hdd_c/data/miniWorld/training_z_{}.npy'.format(step),
                        z.detach().cpu().numpy())
                elif encoder == 'AE':
                    z = AE.encode(obs)
                elif encoder == 'VAE':
                    z = VAE.encode(obs)[0]
                elif encoder == 'sym_VAE':
                    z_vae = VAE.encode(obs)[0]
                    z_sym = Detector_model(obs)
                    z_sym = Detector_to_symbolic(z_sym)
                    z = torch.cat((z_vae, z_sym), dim=1)
                else:
                    raise NotImplementedError('Unknown encoder: {}'.format(encoder))
                #obs = make_var(obs)
            """
            for info in infos:
                if 'episode' in info.keys():
                    print(reward)
                    episode_rewards.append(info['episode']['r'])
            """

            #             # FIXME: works only for environments with sparse rewards
            #             for idx, eps_done in enumerate(done):
            #                 if eps_done:
            #                     episode_rewards.append(reward[idx])

            # FIXME: works only for environments with sparse rewards
            for idx, eps_done in enumerate(done):
                if eps_done:
                    #print('done')
                    episode_rewards.append(infos[idx]['accumulated_reward'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            #rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks)
            rollouts.insert(z, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            print('Saving model')
            print()

            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs.venv, 'ob_rms') and envs.venv.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        #print(len(episode_rewards))

        step_log.append(total_num_steps)
        reward_log.append(np.mean(episode_rewards))
        step_log_np = np.asarray(step_log)
        reward_log_np = np.asarray(reward_log)
        np.savez_compressed('/hdd_c/data/miniWorld/log/{}.npz'.format(UID),
                            step=step_log_np,
                            reward=reward_log_np)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.2f}/{:.2f}, min/max reward {:.2f}/{:.2f}, success rate {:.2f}\n"
                .format(
                    j, total_num_steps, int(total_num_steps / (end - start)),
                    len(episode_rewards), np.mean(episode_rewards),
                    np.median(episode_rewards), np.min(episode_rewards),
                    np.max(episode_rewards),
                    np.count_nonzero(np.greater(episode_rewards, 0)) /
                    len(episode_rewards)))

        if args.eval_interval is not None and len(
                episode_rewards) > 1 and j % args.eval_interval == 0:
            eval_envs = make_vec_envs(args.env_name,
                                      args.seed + args.num_processes,
                                      args.num_processes, args.gamma,
                                      eval_log_dir, args.add_timestep, device,
                                      True)

            if eval_envs.venv.__class__.__name__ == "VecNormalize":
                eval_envs.venv.ob_rms = envs.venv.ob_rms

                # An ugly hack to remove updates
                def _obfilt(self, obs):
                    if self.ob_rms:
                        obs = np.clip((obs - self.ob_rms.mean) /
                                      np.sqrt(self.ob_rms.var + self.epsilon),
                                      -self.clipob, self.clipob)
                        return obs
                    else:
                        return obs

                eval_envs.venv._obfilt = types.MethodType(_obfilt, envs.venv)

            eval_episode_rewards = []

            obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args.num_processes, 1, device=device)

            while len(eval_episode_rewards) < 10:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                obs, reward, done, infos = eval_envs.step(action)
                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])
                for info in infos:
                    if 'episode' in info.keys():
                        eval_episode_rewards.append(info['episode']['r'])

            eval_envs.close()

            print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
                len(eval_episode_rewards), np.mean(eval_episode_rewards)))
        """
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
        """
    envs.close()
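
The rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau) call above delegates the GAE recursion to RolloutStorage. A self-contained sketch of that recursion on plain tensors follows; it assumes masks[t] is 0.0 where the episode ended at step t (the repo indexes its masks slightly differently), so it illustrates the formula rather than reproducing the repo's exact code.

import torch

def gae_returns(rewards, values, masks, next_value, gamma=0.99, tau=0.95):
    # rewards, values, masks: [T, N, 1]; next_value: [N, 1].
    # masks[t] == 0.0 means the episode ended at step t, so no bootstrapping past it.
    T = rewards.size(0)
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # [T + 1, N, 1]
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * tau * masks[t] * gae
        returns[t] = gae + values[t]
    return returns

T, N = 5, 2
r = gae_returns(torch.rand(T, N, 1), torch.rand(T, N, 1),
                torch.ones(T, N, 1), torch.rand(N, 1))
print(r.shape)  # torch.Size([5, 2, 1])
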
Example #3
# algorithm selection
if args.algo == 'a2c':
    agent = algo.A2C_ACKTR(actor_critic,
                           args.value_loss_coef,
                           args.entropy_coef,
                           lr=args.lr,
                           eps=args.eps,
                           alpha=args.alpha,
                           max_grad_norm=args.max_grad_norm)
elif args.algo == 'ppo':
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)
elif args.algo == 'acktr':
    agent = algo.A2C_ACKTR(actor_critic,
                           args.value_loss_coef,
                           args.entropy_coef,
                           acktr=True)

rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                          env.robot_dict[args.env_name].action_space,
                          actor_critic.state_size)
current_obs = torch.zeros(args.num_processes, *obs_shape)
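
The current_obs buffer allocated at the end of the fragment above is the frame-stacking window used by older versions of this code base (see update_current_obs in Example #7). A minimal self-contained sketch of that rolling update, with made-up sizes, is below.

import torch

num_processes, num_stack, frame_dim = 2, 4, 3
current_obs = torch.zeros(num_processes, num_stack * frame_dim)

def update_current_obs(new_frame):
    # Shift the stacked frames left by one slot, then write the newest frame at the end.
    if num_stack > 1:
        current_obs[:, :-frame_dim] = current_obs[:, frame_dim:].clone()
    current_obs[:, -frame_dim:] = new_frame

for t in range(3):
    update_current_obs(torch.full((num_processes, frame_dim), float(t)))
print(current_obs[0])  # oldest frames on the left, the newest frame (value 2.) on the right
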
Example #4
def main():

    print('Preparing parameters')

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    print('Creating envs: {}'.format(args.env_name))

    envs = test_mp_envs(args.env_name, args.num_processes)

    print('Creating network')
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    print('Initializing PPO')
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)

    print('Memory')
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = []

    num_episodes = [0 for _ in range(args.num_processes)]

    last_index = 0

    print('Starting ! ')

    start = time.time()
    for j in tqdm(range(num_updates)):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step], rollouts.masks[step])

            obs, reward, done, infos = envs.step(action)

            for info_num, info in enumerate(infos):
                if info_num == 0:
                    if 'episode' in info.keys():
                        episode_rewards.append(info['episode']['r'])
                        # end_episode_to_viz(writer, info, info_num, num_episodes[info_num])
                        num_episodes[info_num] += 1
                        plot_rewards(episode_rewards, args)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            rollouts.insert(obs, action, action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        losses = agent.update(rollouts)
        rollouts.after_update()
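
Each of these training loops turns the vectorized done flags into float masks with the same list comprehension. A vectorized equivalent, assuming done arrives as a NumPy bool array from the vec env, looks like this.

import numpy as np
import torch

done = np.array([False, True, False, False])  # hypothetical flags from a 4-process vec env

# List-comprehension form used in the examples above.
masks_a = torch.FloatTensor([[0.0] if d else [1.0] for d in done])

# Vectorized equivalent: 1.0 where the episode continues, 0.0 where it just ended.
masks_b = torch.from_numpy(1.0 - done.astype(np.float32)).unsqueeze(1)

assert torch.equal(masks_a, masks_b)
print(masks_b.squeeze(1))  # tensor([1., 0., 1., 1.])
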
Example #5
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                getattr(get_vec_normalize(envs), 'ob_rms', None)
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            eval_envs = make_vec_envs(args.env_name,
                                      args.seed + args.num_processes,
                                      args.num_processes, args.gamma,
                                      eval_log_dir, args.add_timestep, device,
                                      True)

            vec_norm = get_vec_normalize(eval_envs)
            if vec_norm is not None:
                vec_norm.eval()
                vec_norm.ob_rms = get_vec_normalize(envs).ob_rms

            eval_episode_rewards = []

            obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args.num_processes, 1, device=device)

            while len(eval_episode_rewards) < 10:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                obs, reward, done, infos = eval_envs.step(action)

                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])
                for info in infos:
                    if 'episode' in info.keys():
                        eval_episode_rewards.append(info['episode']['r'])

            eval_envs.close()

            print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
                len(eval_episode_rewards), np.mean(eval_episode_rewards)))

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
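
The ob_rms object copied from the training envs into the eval envs above is a running mean/std filter for observations. A minimal sketch of that idea (not the actual baselines RunningMeanStd class) is shown here, using the standard parallel mean/variance update.

import numpy as np

class RunningNorm:
    """Tracks a running mean/variance of observations and normalizes with them."""
    def __init__(self, shape, clip=10.0, eps=1e-8):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps
        self.clip = clip
        self.eps = eps

    def update(self, batch):
        # Parallel (Chan et al.) update from a batch of observations [N, *shape].
        b_mean, b_var, b_count = batch.mean(axis=0), batch.var(axis=0), batch.shape[0]
        delta = b_mean - self.mean
        tot = self.count + b_count
        new_mean = self.mean + delta * b_count / tot
        m2 = self.var * self.count + b_var * b_count \
            + delta ** 2 * self.count * b_count / tot
        self.mean, self.var, self.count = new_mean, m2 / tot, tot

    def normalize(self, obs):
        return np.clip((obs - self.mean) / np.sqrt(self.var + self.eps),
                       -self.clip, self.clip)

rn = RunningNorm(shape=(3,))
rn.update(np.random.randn(64, 3) * 5 + 2)
print(rn.normalize(np.array([2.0, 2.0, 2.0])))  # roughly zero after normalization
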
Example #6
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Load the lowlevel opt and
    lowlevel_optfile = options['lowlevel']['optfile']
    with open(lowlevel_optfile, 'r') as handle:
        ll_opt = yaml.load(handle)

    # Whether we should set ll policy to be deterministic or not
    ll_deterministic = options['lowlevel']['deterministic']

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    # Save low level options in option file (for logging purposes)
    options['lowlevel_opt'] = ll_opt

    # Pass necessary values in ll_opt
    assert (ll_opt['model']['mode'] in ['baseline_lowlevel', 'phase_lowlevel'])
    ll_opt['model']['theta_space_mode'] = ll_opt['env']['theta_space_mode']
    ll_opt['model']['time_scale'] = ll_opt['env']['time_scale']

    # If in many module mode, load the lowlevel policies we want
    if model_opt['mode'] == 'hierarchical_many':
        # Check asserts
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_space_mode in [
            'pretrain_interp', 'pretrain_any', 'pretrain_any_far',
            'pretrain_any_fromstart'
        ])
        assert (theta_obs_mode == 'pretrain')

        # Get the theta size
        theta_sz = options['lowlevel']['num_load']
        ckpt_base = options['lowlevel']['ckpt']

        # Load checkpoints
        lowlevel_ckpts = []
        for ll_ind in range(theta_sz):
            if args.change_ll_offset:
                ll_offset = theta_sz * args.trial
            else:
                ll_offset = 0
            lowlevel_ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % (
                ll_ind + ll_offset)
            assert (os.path.isfile(lowlevel_ckpt_file))
            lowlevel_ckpts.append(torch.load(lowlevel_ckpt_file))

    # Otherwise it's one ll policy to load
    else:
        # Get theta_sz for low level model
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_obs_mode in ['ind', 'vector'])
        if theta_obs_mode == 'ind':
            if theta_space_mode == 'forward':
                theta_sz = 1
            elif theta_space_mode == 'simple_four':
                theta_sz = 4
            elif theta_space_mode == 'simple_eight':
                theta_sz = 8
            elif theta_space_mode == 'k_theta':
                theta_sz = ll_opt['env']['num_theta']
            else:
                raise NotImplementedError
        elif theta_obs_mode == 'vector':
            theta_sz = 2
        else:
            raise NotImplementedError
        ll_opt['model']['theta_sz'] = theta_sz
        ll_opt['env']['theta_sz'] = theta_sz

        # Load the low level policy params
        lowlevel_ckpt = options['lowlevel']['ckpt']
        assert (os.path.isfile(lowlevel_ckpt))
        lowlevel_ckpt = torch.load(lowlevel_ckpt)
    hl_action_space = spaces.Discrete(theta_sz)

    # Check asserts
    assert (args.algo in ['a2c', 'ppo', 'acktr', 'dqn'])
    assert (optim_opt['hierarchical_mode']
            in ['train_highlevel', 'train_both'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'
                             ], 'Recurrent policy is not implemented for ACKTR'
    assert (model_opt['mode'] in ['hierarchical', 'hierarchical_many'])

    # Set seed - just make the seed the trial number
    seed = args.trial + 1000  # Make it different than lowlevel seed
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    num_updates = int(optim_opt['num_frames']) // alg_opt[
        'num_steps'] // alg_opt['num_processes'] // optim_opt['num_ll_steps']
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if len(args.resume) > 0:
        assert (os.path.isfile(os.path.join(logpath, args.resume)))
        ckpt = torch.load(os.path.join(logpath, 'ckpt.pth.tar'))
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if os.path.isdir(logpath):
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(
                        logpath),
                    default=False):
                os.system('rm -rf ' + logpath)
            else:
                return
        os.system('mkdir -p ' + logpath)
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
        os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
        os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # Create environments
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose)
        for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Check if we use timestep in low level
    if 'baseline' in ll_opt['model']['mode']:
        add_timestep = False
    elif 'phase' in ll_opt['model']['mode']:
        add_timestep = True
    else:
        raise NotImplementedError

    # Get shapes
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    s_pro_dummy = dummy_env.unwrapped._get_pro_obs()
    s_ext_dummy = dummy_env.unwrapped._get_ext_obs()
    if add_timestep:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz + 1, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0] + 1, )
    else:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0], )
    ll_obs_shape = (ll_obs_shape[0] * env_opt['num_stack'], *ll_obs_shape[1:])
    hl_obs_shape = (s_ext_dummy.shape[0], )
    hl_obs_shape = (hl_obs_shape[0] * env_opt['num_stack'], *hl_obs_shape[1:])

    # Do vec normalize, but mask out what we don't want altered
    # Also freeze all of the low level obs
    ignore_mask = dummy_env.env._get_obs_mask()
    freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
    freeze_mask = np.concatenate([freeze_mask, [0]])
    if ('normalize' in env_opt
            and not env_opt['normalize']) or args.algo == 'dqn':
        ignore_mask = 1 - freeze_mask
    if model_opt['mode'] == 'hierarchical_many':
        # Actually ignore both ignored values and the low level values
        # That filtering will happen later
        ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])
    else:
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])

    # Make our helper object for dealing with hierarchical observations
    hier_utils = HierarchyUtils(ll_obs_shape, hl_obs_shape, hl_action_space,
                                theta_sz, add_timestep)

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()
    ll_alg_filename = os.path.join(logpath, 'AlgLL.Monitor.csv')
    ll_alg_f = open(ll_alg_filename, "wt")
    ll_alg_f.write('# Alg Logging LL %s\n' %
                   json.dumps({
                       "t_start": time.time(),
                       'env_id': dummy_env.spec and dummy_env.spec.id,
                       'mode': options['model']['mode'],
                       'name': options['logs']['exp_name']
                   }))
    ll_alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    ll_alg_logger = csv.DictWriter(ll_alg_f, fieldnames=ll_alg_fields)
    ll_alg_logger.writeheader()
    ll_alg_f.flush()

    # Create the policy networks
    ll_action_space = envs.action_space
    if args.algo == 'dqn':
        model_opt['eps_start'] = optim_opt['eps_start']
        model_opt['eps_end'] = optim_opt['eps_end']
        model_opt['eps_decay'] = optim_opt['eps_decay']
        hl_policy = DQNPolicy(hl_obs_shape, hl_action_space, model_opt)
    else:
        hl_policy = Policy(hl_obs_shape, hl_action_space, model_opt)
    if model_opt['mode'] == 'hierarchical_many':
        ll_policy = ModularPolicy(ll_raw_obs_shape, ll_action_space, theta_sz,
                                  ll_opt)
    else:
        ll_policy = Policy(ll_obs_shape, ll_action_space, ll_opt['model'])
    # Load the previous ones here?
    if args.cuda:
        hl_policy.cuda()
        ll_policy.cuda()

    # Create the high level agent
    if args.algo == 'a2c':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  lr=optim_opt['lr'],
                                  eps=optim_opt['eps'],
                                  alpha=optim_opt['alpha'],
                                  max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        hl_agent = algo.PPO(hl_policy,
                            alg_opt['clip_param'],
                            alg_opt['ppo_epoch'],
                            alg_opt['num_mini_batch'],
                            alg_opt['value_loss_coef'],
                            alg_opt['entropy_coef'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  acktr=True)
    elif args.algo == 'dqn':
        hl_agent = algo.DQN(hl_policy,
                            env_opt['gamma'],
                            batch_size=alg_opt['batch_size'],
                            target_update=alg_opt['target_update'],
                            mem_capacity=alg_opt['mem_capacity'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])

    # Create the low level agent
    # If only training high level, make dummy agent (just does passthrough, doesn't change anything)
    if optim_opt['hierarchical_mode'] == 'train_highlevel':
        ll_agent = algo.Passthrough(ll_policy)
    elif optim_opt['hierarchical_mode'] == 'train_both':
        if args.algo == 'a2c':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      lr=optim_opt['ll_lr'],
                                      eps=optim_opt['eps'],
                                      alpha=optim_opt['alpha'],
                                      max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            ll_agent = algo.PPO(ll_policy,
                                alg_opt['clip_param'],
                                alg_opt['ll_ppo_epoch'],
                                alg_opt['num_mini_batch'],
                                alg_opt['value_loss_coef'],
                                alg_opt['entropy_coef'],
                                lr=optim_opt['ll_lr'],
                                eps=optim_opt['eps'],
                                max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'acktr':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      acktr=True)
    else:
        raise NotImplementedError

    # Make the rollout structures
    hl_rollouts = RolloutStorage(alg_opt['num_steps'],
                                 alg_opt['num_processes'], hl_obs_shape,
                                 hl_action_space, hl_policy.state_size)
    ll_rollouts = MaskingRolloutStorage(alg_opt['num_steps'],
                                        alg_opt['num_processes'], ll_obs_shape,
                                        ll_action_space, ll_policy.state_size)
    hl_current_obs = torch.zeros(alg_opt['num_processes'], *hl_obs_shape)
    ll_current_obs = torch.zeros(alg_opt['num_processes'], *ll_obs_shape)

    # Helper functions to update the current obs
    def update_hl_current_obs(obs):
        shape_dim0 = hl_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            hl_current_obs[:, :-shape_dim0] = hl_current_obs[:, shape_dim0:]
        hl_current_obs[:, -shape_dim0:] = obs

    def update_ll_current_obs(obs):
        shape_dim0 = ll_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            ll_current_obs[:, :-shape_dim0] = ll_current_obs[:, shape_dim0:]
        ll_current_obs[:, -shape_dim0:] = obs

    # Update agent with loaded checkpoint
    if len(args.resume) > 0:
        # This should update both the policy network and the optimizer
        ll_agent.load_state_dict(ckpt['ll_agent'])
        hl_agent.load_state_dict(ckpt['hl_agent'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    else:
        if model_opt['mode'] == 'hierarchical_many':
            ll_agent.load_pretrained_policies(lowlevel_ckpts)
        else:
            # Load low level agent
            ll_agent.load_state_dict(lowlevel_ckpt['agent'])

            # Load ob_rms from low level (but need to reshape it)
            old_rms = lowlevel_ckpt['ob_rms']
            assert (old_rms.mean.shape[0] == ll_obs_shape[0])
            # Only copy the pro state part of it (not including thetas or count)
            envs.ob_rms.mean[:s_pro_dummy.shape[0]] = \
                old_rms.mean[:s_pro_dummy.shape[0]]
            envs.ob_rms.var[:s_pro_dummy.shape[0]] = \
                old_rms.var[:s_pro_dummy.shape[0]]

    # Reset our env and rollouts
    raw_obs = envs.reset()
    hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
    ll_obs = hier_utils.placeholder_theta(raw_ll_obs, step_counts)
    update_hl_current_obs(hl_obs)
    update_ll_current_obs(ll_obs)
    hl_rollouts.observations[0].copy_(hl_current_obs)
    ll_rollouts.observations[0].copy_(ll_current_obs)
    ll_rollouts.recent_obs.copy_(ll_current_obs)
    if args.cuda:
        hl_current_obs = hl_current_obs.cuda()
        ll_current_obs = ll_current_obs.cuda()
        hl_rollouts.cuda()
        ll_rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
    final_rewards = torch.zeros([alg_opt['num_processes'], 1])

    # Update loop
    start = time.time()
    for j in range(start_update, num_updates):
        for step in range(alg_opt['num_steps']):
            # Step through high level action
            start_time = time.time()
            with torch.no_grad():
                hl_value, hl_action, hl_action_log_prob, hl_states = hl_policy.act(
                    hl_rollouts.observations[step], hl_rollouts.states[step],
                    hl_rollouts.masks[step])
            hl_cpu_actions = hl_action.squeeze(1).cpu().numpy()
            if args.profile:
                print('hl act %f' % (time.time() - start_time))

            # Get values to use for Q learning
            hl_state_dqn = hl_rollouts.observations[step]
            hl_action_dqn = hl_action

            # Update last ll observation with new theta
            for proc in range(alg_opt['num_processes']):
                # Update last observations in memory
                last_obs = ll_rollouts.observations[ll_rollouts.steps[proc],
                                                    proc]
                if hier_utils.has_placeholder(last_obs):
                    new_last_obs = hier_utils.update_theta(
                        last_obs, hl_cpu_actions[proc])
                    ll_rollouts.observations[ll_rollouts.steps[proc],
                                             proc].copy_(new_last_obs)

                # Update most recent observations (not necessarily the same)
                assert (hier_utils.has_placeholder(
                    ll_rollouts.recent_obs[proc]))
                new_last_obs = hier_utils.update_theta(
                    ll_rollouts.recent_obs[proc], hl_cpu_actions[proc])
                ll_rollouts.recent_obs[proc].copy_(new_last_obs)
            assert (ll_rollouts.observations.max().item() < float('inf')
                    and ll_rollouts.recent_obs.max().item() < float('inf'))

            # Given high level action, step through the low level actions
            death_step_mask = np.ones([alg_opt['num_processes'],
                                       1])  # 1 means still alive, 0 means dead
            hl_reward = torch.zeros([alg_opt['num_processes'], 1])
            hl_obs = [None for i in range(alg_opt['num_processes'])]
            for ll_step in range(optim_opt['num_ll_steps']):
                # Sample actions
                start_time = time.time()
                with torch.no_grad():
                    ll_value, ll_action, ll_action_log_prob, ll_states = ll_policy.act(
                        ll_rollouts.recent_obs,
                        ll_rollouts.recent_s,
                        ll_rollouts.recent_masks,
                        deterministic=ll_deterministic)
                ll_cpu_actions = ll_action.squeeze(1).cpu().numpy()
                if args.profile:
                    print('ll act %f' % (time.time() - start_time))

                # Observe reward and next obs
                raw_obs, ll_reward, done, info = envs.step(
                    ll_cpu_actions, death_step_mask)
                raw_hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(
                    raw_obs)
                ll_obs = []
                for proc in range(alg_opt['num_processes']):
                    if (ll_step
                            == optim_opt['num_ll_steps'] - 1) or done[proc]:
                        ll_obs.append(
                            hier_utils.placeholder_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([step_counts[proc]])))
                    else:
                        ll_obs.append(
                            hier_utils.append_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([hl_cpu_actions[proc]]),
                                np.array([step_counts[proc]])))
                ll_obs = np.concatenate(ll_obs, 0)
                ll_reward = torch.from_numpy(
                    np.expand_dims(np.stack(ll_reward), 1)).float()
                episode_rewards += ll_reward
                hl_reward += ll_reward

                # Update values for Q learning and update replay memory
                time.time()
                hl_next_state_dqn = torch.from_numpy(raw_hl_obs)
                hl_reward_dqn = ll_reward
                hl_isdone_dqn = done
                if args.algo == 'dqn':
                    hl_agent.update_memory(hl_state_dqn, hl_action_dqn,
                                           hl_next_state_dqn, hl_reward_dqn,
                                           hl_isdone_dqn, death_step_mask)
                hl_state_dqn = hl_next_state_dqn
                if args.profile:
                    print('dqn memory %f' % (time.time() - start_time))

                # Update high level observations (only take the most recent obs if we haven't seen a done before now and thus the value is valid)
                for proc, raw_hl in enumerate(raw_hl_obs):
                    if death_step_mask[proc].item() > 0:
                        hl_obs[proc] = np.array([raw_hl])

                # If done then clean the history of observations
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                final_rewards *= masks
                final_rewards += (
                    1 - masks
                ) * episode_rewards  # TODO - actually not sure if I broke this logic, but this value is not used anywhere
                episode_rewards *= masks

                # TODO - I commented this out, which possibly breaks things if num_stack > 1. Fix later if necessary
                #if args.cuda:
                #    masks = masks.cuda()
                #if current_obs.dim() == 4:
                #    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                #else:
                #    current_obs *= masks

                # Update low level observations
                update_ll_current_obs(ll_obs)

                # Update low level rollouts
                ll_rollouts.insert(ll_current_obs, ll_states, ll_action,
                                   ll_action_log_prob, ll_value, ll_reward,
                                   masks, death_step_mask)

                # Update which ones have stepped to the end and shouldn't be updated next time in the loop
                death_step_mask *= masks

            # Update high level rollouts
            hl_obs = np.concatenate(hl_obs, 0)
            update_hl_current_obs(hl_obs)
            hl_rollouts.insert(hl_current_obs, hl_states, hl_action,
                               hl_action_log_prob, hl_value, hl_reward, masks)

            # Check if we want to update lowlevel policy
            if ll_rollouts.isfull and all([
                    not hier_utils.has_placeholder(
                        ll_rollouts.observations[ll_rollouts.steps[proc],
                                                 proc])
                    for proc in range(alg_opt['num_processes'])
            ]):
                # Update low level policy
                assert (ll_rollouts.observations.max().item() < float('inf'))
                if optim_opt['hierarchical_mode'] == 'train_both':
                    with torch.no_grad():
                        ll_next_value = ll_policy.get_value(
                            ll_rollouts.observations[-1],
                            ll_rollouts.states[-1],
                            ll_rollouts.masks[-1]).detach()
                    ll_rollouts.compute_returns(ll_next_value,
                                                alg_opt['use_gae'],
                                                env_opt['gamma'],
                                                alg_opt['gae_tau'])
                    ll_value_loss, ll_action_loss, ll_dist_entropy = ll_agent.update(
                        ll_rollouts)
                else:
                    ll_value_loss = 0
                    ll_action_loss = 0
                    ll_dist_entropy = 0
                ll_rollouts.after_update()

                # Update logger
                alg_info = {}
                alg_info['value_loss'] = ll_value_loss
                alg_info['action_loss'] = ll_action_loss
                alg_info['dist_entropy'] = ll_dist_entropy
                ll_alg_logger.writerow(alg_info)
                ll_alg_f.flush()

        # Update high level policy
        start_time = time.time()
        assert (hl_rollouts.observations.max().item() < float('inf'))
        if args.algo == 'dqn':
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                alg_opt['updates_per_step']
            )  # TODO - maybe log this loss properly
        else:
            with torch.no_grad():
                hl_next_value = hl_policy.get_value(
                    hl_rollouts.observations[-1], hl_rollouts.states[-1],
                    hl_rollouts.masks[-1]).detach()
            hl_rollouts.compute_returns(hl_next_value, alg_opt['use_gae'],
                                        env_opt['gamma'], alg_opt['gae_tau'])
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                hl_rollouts)
        hl_rollouts.after_update()
        if args.profile:
            print('hl update %f' % (time.time() - start_time))

        # Update alg monitor for high level
        alg_info = {}
        alg_info['value_loss'] = hl_value_loss
        alg_info['action_loss'] = hl_action_loss
        alg_info['dist_entropy'] = hl_dist_entropy
        alg_logger.writerow(alg_info)
        alg_f.flush()

        # Save checkpoints
        total_num_steps = (j + 1) * alg_opt['num_processes'] * alg_opt[
            'num_steps'] * optim_opt['num_ll_steps']
        if 'save_interval' in alg_opt:
            save_interval = alg_opt['save_interval']
        else:
            save_interval = 100
        if j % save_interval == 0:
            # Save all of our important information
            start_time = time.time()
            save_checkpoint(logpath, ll_agent, hl_agent, envs, j,
                            total_num_steps)
            if args.profile:
                print('save checkpoint %f' % (time.time() - start_time))

        # Print log
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        if j % log_interval == 0:
            end = time.time()
            print(
                "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(options['logs']['exp_name'], j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(),
                        hl_dist_entropy, hl_value_loss, hl_action_loss))

        # Do dashboard logging
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']
        if args.vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                dashboard.visdom_plot()
            except IOError:
                pass

    # Save final checkpoint
    save_checkpoint(logpath, ll_agent, hl_agent, envs, j, total_num_steps)

    # Close logging file
    alg_f.close()
    ll_alg_f.close()
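
Example #6 writes its losses through csv.DictWriter monitors (Alg.Monitor.csv / AlgLL.Monitor.csv). A standalone sketch of that logging pattern, writing to an in-memory buffer instead of a file, follows.

import csv
import io
import json
import time

# In-memory buffer instead of the on-disk Alg.Monitor.csv used above.
f = io.StringIO()
f.write('# Alg Logging %s\n' % json.dumps({'t_start': time.time(), 'name': 'demo'}))
fields = ['value_loss', 'action_loss', 'dist_entropy']
logger = csv.DictWriter(f, fieldnames=fields)
logger.writeheader()

for update in range(3):
    logger.writerow({'value_loss': 0.5 / (update + 1),
                     'action_loss': 0.1,
                     'dist_entropy': 1.2})

print(f.getvalue())
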
Example #7
def main():
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    os.environ['OMP_NUM_THREADS'] = '1'
    #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    #os.environ['CUDA_VISIBLE_DEVICES'] = "9"
    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = [make_env(args.env_name, args.seed, i, args.log_dir, args.add_timestep)
                for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    if len(envs.observation_space.shape) == 3:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.hid_size,
                                 args.feat_size, args.recurrent_policy)
    else:
        assert not args.recurrent_policy, \
            "Recurrent policy is not implemented for the MLP controller"
        actor_critic = MLPPolicy(obs_shape[0], envs.action_space)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]
    if args.use_cell:
        hs = HistoryCell(obs_shape[0], actor_critic.feat_size, 2*actor_critic.hidden_size, 1)
        ft = FutureCell(obs_shape[0], actor_critic.feat_size, 2 * actor_critic.hidden_size, 1)
    else:
        hs = History(obs_shape[0], actor_critic.feat_size, actor_critic.hidden_size, 2, 1)
        ft = Future(obs_shape[0], actor_critic.feat_size, actor_critic.hidden_size, 2, 1)

    if args.cuda:
        actor_critic = actor_critic.cuda()
        hs = hs.cuda()
        ft = ft.cuda()
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, lr=args.lr,
                               eps=args.eps, alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic, hs, ft, args.clip_param, args.ppo_epoch,
                         args.num_mini_batch, args.value_loss_coef,
                         args.entropy_coef, args.hf_loss_coef,
                         ac_lr=args.lr, hs_lr=args.lr, ft_lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm,
                         num_processes=args.num_processes,
                         num_steps=args.num_steps,
                         use_cell=args.use_cell,
                         lenhs=args.lenhs, lenft=args.lenft,
                         plan=args.plan,
                         ac_intv=args.ac_interval,
                         hs_intv=args.hs_interval,
                         ft_intv=args.ft_interval)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size,
                              feat_size=512)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
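        # Shift the stacked frames one observation to the left and write the
        # newest frame into the last slot (standard frame stacking).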
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)


    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    rec_x = []
    rec_y = []
    file = open('./rec/' + args.env_name + '_' + args.method_name + '.txt', 'w')

    hs_info = torch.zeros(args.num_processes, 2 * actor_critic.hidden_size).cuda()
    hs_ind = torch.IntTensor(args.num_processes, 1).zero_()
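    # hs_info caches each process's history encoding; hs_ind records the rollout
    # step at which that process's current episode started, so the history window
    # built below never reaches back across an episode boundary.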

    epinfobuf = deque(maxlen=100)
    start_time = time.time()
    for j in range(num_updates):
        print('begin sample, time  {}'.format(time.strftime("%Hh %Mm %Ss",
                                                                time.gmtime(time.time() - start_time))))
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                rollouts.feat[step]=actor_critic.get_feat(rollouts.observations[step])
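                # The loop below re-encodes each process's history: the History
                # cell (or the pooling variant) is run over at most the last
                # lenhs feature vectors, starting no earlier than hs_ind.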

                if args.use_cell:
                    for i in range(args.num_processes):
                        h = torch.zeros(1, 2 * actor_critic.hid_size).cuda()
                        c = torch.zeros(1, 2 * actor_critic.hid_size).cuda()
                        start_ind = max(hs_ind[i],step+1-args.lenhs)
                        for ind in range(start_ind,step+1):
                            h,c=hs(rollouts.feat[ind,i].unsqueeze(0),h,c)
                        hs_info[i,:]=h.view(1,2*actor_critic.hid_size)
                        del h,c
                        gc.collect()
                else:
                    for i in range(args.num_processes):
                        start_ind = max(hs_ind[i], step + 1 - args.lenhs)
                        hs_info[i,:]=hs(rollouts.feat[start_ind:step+1,i])

                hidden_feat=actor_critic.cat(rollouts.feat[step],hs_info)
                value, action, action_log_prob, states = actor_critic.act(
                        hidden_feat,
                        rollouts.states[step])
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(cpu_actions)
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo:
                    epinfobuf.extend([maybeepinfo['r']])
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            hs_ind = ((1-masks)*(step+1)+masks*hs_ind.float()).int()

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, hs_ind,states.data, action.data, action_log_prob.data, value.data, reward, masks)
        with torch.no_grad():
            rollouts.feat[-1] = actor_critic.get_feat(rollouts.observations[-1])
            if args.use_cell:
                for i in range(args.num_processes):
                    h = torch.zeros(1, 2 * actor_critic.hid_size).cuda()
                    c = torch.zeros(1, 2 * actor_critic.hid_size).cuda()
                    start = max(hs_ind[i], step + 1 - args.lenhs)
                    for ind in range(start, step + 1):
                        h, c = hs(rollouts.feat[ind, i].unsqueeze(0), h, c)
                    hs_info[i, :] = h.view(1, 2 * actor_critic.hid_size)
                    del h,c
            else:
                for i in range(args.num_processes):
                    start_ind = max(hs_ind[i], step + 1 - args.lenhs)
                    hs_info[i, :] = hs(rollouts.feat[start_ind:step + 1, i])
            hidden_feat = actor_critic.cat(rollouts.feat[-1],hs_info)
            next_value = actor_critic.get_value(hidden_feat).detach()
        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
        rollouts.compute_ft_ind()

        print('begin update, time  {}'.format(time.strftime("%Hh %Mm %Ss",
                                     time.gmtime(time.time() - start_time))))
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        print('end update, time  {}'.format(time.strftime("%Hh %Mm %Ss",
                                                            time.gmtime(time.time() - start_time))))
        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [save_model,
                            hasattr(envs, 'ob_rms') and envs.ob_rms or None]

            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            v_mean,v_median,v_min,v_max = safe(epinfobuf)
            print("Updates {}, num timesteps {},time {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                format(j, total_num_steps,
                       time.strftime("%Hh %Mm %Ss",
                                     time.gmtime(time.time() - start_time)),
                       int(total_num_steps / (end - start_time)),
                       v_mean, v_median, v_min, v_max,
                       dist_entropy,
                       value_loss, action_loss))

            if not np.isnan(v_mean):
                rec_x.append(total_num_steps)
                rec_y.append(v_mean)
                file.write(str(total_num_steps))
                file.write(' ')
                file.write(str(v_mean))
                file.write('\n')

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
    plot_line(rec_x, rec_y, './imgs/' + args.env_name + '_' + args.method_name + '.png', args.method_name,
              args.env_name, args.num_frames)
    file.close()
Exemple #8
0
def main():
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    if args.render_game:
        mp.set_start_method('spawn')

    torch.set_num_threads(1)

    try:
        os.makedirs(args.log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
        for f in files:
            if os.path.isfile(f):
                os.remove(f)

    if 'MiniPacman' in args.env_name:
        from environment_model.mini_pacman.builder import MiniPacmanEnvironmentBuilder
        builder = MiniPacmanEnvironmentBuilder(args)
    else:
        from environment_model.latent_space.builder import LatentSpaceEnvironmentBuilder
        builder = LatentSpaceEnvironmentBuilder(args)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        visdom_plotter = VisdomPlotterA2C(viz, args.algo == 'i2a')

    if 'MiniPacman' in args.env_name:
        from gym_envs.envs_mini_pacman import make_custom_env
        envs = [
            make_custom_env(args.env_name,
                            args.seed,
                            i,
                            args.log_dir,
                            grey_scale=args.grey_scale)
            for i in range(args.num_processes)
        ]
    elif args.algo == 'i2a' or args.train_on_200x160_pixel:
        from gym_envs.envs_ms_pacman import make_env_ms_pacman
        envs = [
            make_env_ms_pacman(env_id=args.env_name,
                               seed=args.seed,
                               rank=i,
                               log_dir=args.log_dir,
                               grey_scale=False,
                               stack_frames=1,
                               skip_frames=4)
            for i in range(args.num_processes)
        ]
    else:
        from envs import make_env
        envs = [
            make_env(args.env_name, args.seed, i, args.log_dir,
                     args.add_timestep) for i in range(args.num_processes)
        ]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    if args.algo == 'i2a':
        actor_critic = builder.build_i2a_model(envs, args)
    elif 'MiniPacman' in args.env_name:
        actor_critic = builder.build_a2c_model(envs)
    elif args.train_on_200x160_pixel:
        from a2c_models.atari_model import AtariModel
        actor_critic = A2C_PolicyWrapper(
            AtariModel(obs_shape=obs_shape,
                       action_space=envs.action_space.n,
                       use_cuda=args.cuda))
    else:
        actor_critic = Policy(obs_shape, envs.action_space,
                              args.recurrent_policy)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.load_model:
        load_path = os.path.join(args.save_dir, args.algo)
        load_path = os.path.join(load_path, args.env_name + ".pt")
        if os.path.isfile(load_path):
            # if args.cuda:
            saved_state = torch.load(load_path,
                                     map_location=lambda storage, loc: storage)
            actor_critic.load_state_dict(saved_state)
        else:
            print("Can not load model ", load_path, ". File does not exists")
            return

    log_file = os.path.join(os.path.join(args.save_dir, args.algo),
                            args.env_name + ".log")
    if not os.path.exists(log_file) or not args.load_model:
        print("Log file: ", log_file)
        with open(log_file, 'w') as the_file:
            the_file.write('command line args: ' + " ".join(sys.argv) + '\n')

    if args.cuda:
        actor_critic.cuda()

    if args.render_game:
        load_path = os.path.join(args.save_dir, args.algo)
        test_process = TestPolicy(model=copy.deepcopy(actor_critic),
                                  load_path=load_path,
                                  args=args)

    if args.algo == 'i2a':
        agent = I2A_ALGO(actor_critic=actor_critic,
                         obs_shape=obs_shape,
                         action_shape=action_shape,
                         args=args)
    elif args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.algo == 'i2a':
        rollouts = I2A_RolloutStorage(args.num_steps, args.num_processes,
                                      obs_shape, envs.action_space,
                                      actor_critic.state_size)
    else:
        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  obs_shape, envs.action_space,
                                  actor_critic.state_size)

    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            if args.algo == 'i2a':
                # Sample actions
                value, action, action_log_prob, states, policy_action_prob, rollout_action_prob = actor_critic.act(
                    rollouts.observations[step].clone(), rollouts.states[step],
                    rollouts.masks[step])
            else:
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, states = actor_critic.act(
                        rollouts.observations[step], rollouts.states[step],
                        rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
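            # final_rewards keeps the return of the last finished episode per
            # process; episode_rewards accumulates the ongoing episode and is
            # zeroed when that episode ends.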
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            if args.algo == "i2a":
                rollouts.insert(current_obs, states, action, action_log_prob,
                                value, reward, masks, policy_action_prob,
                                rollout_action_prob)
            else:
                rollouts.insert(current_obs, states, action, action_log_prob,
                                value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        if args.algo == 'i2a':
            value_loss, action_loss, dist_entropy, distill_loss = agent.update(
                rollouts=rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if args.vis:
            distill_loss_data = distill_loss if args.algo == 'i2a' else None
            visdom_plotter.append(dist_entropy,
                                  final_rewards.numpy().flatten(), value_loss,
                                  action_loss, distill_loss_data)

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            torch.save(save_model.state_dict(),
                       os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps

            reward_info = "mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"\
                .format(final_rewards.mean(), final_rewards.median(), final_rewards.min(), final_rewards.max())

            distill_loss = ", distill_loss {:.5f}".format(
                distill_loss) if args.algo == 'i2a' else ""
            loss_info = "value loss {:.5f}, policy loss {:.5f}{}"\
                .format(value_loss, action_loss, distill_loss)

            entropy_info = "entropy {:.5f}".format(dist_entropy)

            info = "Updates {}, num timesteps {}, FPS {}, {}, {}, {}, time {:.5f} min"\
                    .format(j, total_num_steps, int(total_num_steps / (end - start)),
                            reward_info, entropy_info, loss_info, (end - start) / 60.)

            with open(log_file, 'a') as the_file:
                the_file.write(info + '\n')

            print(info)
        if args.vis and j % args.vis_interval == 0:
            frames = j * args.num_processes * args.num_steps
            visdom_plotter.plot(frames)
Exemple #9
0
def main():
    import random
    import gym_micropolis
    import game_of_life

    args = get_args()
    args.log_dir = args.save_dir + '/logs'
    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    graph_name = args.save_dir.replace('trained_models/', '').replace('/', ' ')

    actor_critic = False
    agent = False
    past_steps = 0
    try:
        os.makedirs(args.log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
        for f in files:
            if args.overwrite:
                os.remove(f)
            else:
                pass
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        win_eval = None
    if 'GameOfLife' in args.env_name:
        print('env name: {}'.format(args.env_name))
        num_actions = 1
    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         args.add_timestep,
                         device,
                         False,
                         None,
                         args=args)

    if isinstance(envs.observation_space, gym.spaces.Discrete):
        num_inputs = envs.observation_space.n
    elif isinstance(envs.observation_space, gym.spaces.Box):
        if len(envs.observation_space.shape) == 3:
            in_w = envs.observation_space.shape[1]
            in_h = envs.observation_space.shape[2]
        else:
            in_w = 1
            in_h = 1
        num_inputs = envs.observation_space.shape[0]
    if isinstance(envs.action_space, gym.spaces.Discrete):
        out_w = 1
        out_h = 1
        if 'Micropolis' in args.env_name:  #otherwise it's set
            if args.power_puzzle:
                num_actions = 1
            else:
                num_actions = 19  # TODO: have this already from env
        elif 'GameOfLife' in args.env_name:
            num_actions = 1
        else:
            num_actions = envs.action_space.n
    elif isinstance(envs.action_space, gym.spaces.Box):
        if len(envs.action_space.shape) == 3:
            out_w = envs.action_space.shape[1]
            out_h = envs.action_space.shape[2]
        elif len(envs.action_space.shape) == 1:
            out_w = 1
            out_h = 1
        num_actions = envs.action_space.shape[-1]

    if args.auto_expand:
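        # Build the policy one recursion smaller than requested so its shapes
        # match the saved checkpoint; base.auto_expand() adds the extra
        # recursion after the weights are loaded, and n_recs is restored right
        # after construction below.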
        args.n_recs -= 1
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'map_width': args.map_width,
                              'num_actions': num_actions,
                              'recurrent': args.recurrent_policy,
                              'in_w': in_w,
                              'in_h': in_h,
                              'num_inputs': num_inputs,
                              'out_w': out_w,
                              'out_h': out_h
                          },
                          curiosity=args.curiosity,
                          algo=args.algo,
                          model=args.model,
                          args=args)
    if args.auto_expand:
        args.n_recs += 1

    evaluator = None

    if not agent:
        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR_NOREWARD(actor_critic,
                                            args.value_loss_coef,
                                            args.entropy_coef,
                                            lr=args.lr,
                                            eps=args.eps,
                                            alpha=args.alpha,
                                            max_grad_norm=args.max_grad_norm,
                                            curiosity=args.curiosity,
                                            args=args)
        elif args.algo == 'ppo':
            agent = algo.PPO(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm)
        elif args.algo == 'acktr':
            agent = algo.A2C_ACKTR_NOREWARD(actor_critic,
                                            args.value_loss_coef,
                                            args.entropy_coef,
                                            lr=args.lr,
                                            eps=args.eps,
                                            alpha=args.alpha,
                                            max_grad_norm=args.max_grad_norm,
                                            acktr=True,
                                            curiosity=args.curiosity,
                                            args=args)

#saved_model = os.path.join(args.save_dir, args.env_name + '.pt')
    if args.load_dir:
        saved_model = os.path.join(args.load_dir, args.env_name + '.tar')
    else:
        saved_model = os.path.join(args.save_dir, args.env_name + '.tar')
    vec_norm = get_vec_normalize(envs)
    if os.path.exists(saved_model) and not args.overwrite:
        checkpoint = torch.load(saved_model)
        saved_args = checkpoint['args']
        actor_critic.load_state_dict(checkpoint['model_state_dict'])
        actor_critic.to(device)
        actor_critic.cuda()
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.auto_expand:
            if not args.n_recs - saved_args.n_recs == 1:
                print('can expand by 1 rec only from saved model')
                raise Exception
            actor_critic.base.auto_expand()
            print('expanded net: \n{}'.format(actor_critic.base))
        past_steps = checkpoint['past_steps']
        ob_rms = checkpoint['ob_rms']

        past_steps = next(iter(
            agent.optimizer.state_dict()['state'].values()))['step']
        print('Resuming from step {}'.format(past_steps))

        #print(type(next(iter((torch.load(saved_model))))))
        #actor_critic, ob_rms = \
        #        torch.load(saved_model)
        #agent = \
        #    torch.load(os.path.join(args.save_dir, args.env_name + '_agent.pt'))
        #if not agent.optimizer.state_dict()['state'].values():
        #    past_steps = 0
        #else:

        #    raise Exception

        if vec_norm is not None:
            vec_norm.eval()
            vec_norm.ob_rms = ob_rms
    actor_critic.to(device)

    if 'LSTM' in args.model:
        recurrent_hidden_state_size = actor_critic.base.get_recurrent_state_size(
        )
    else:
        recurrent_hidden_state_size = actor_critic.recurrent_hidden_state_size
    if args.curiosity:
        rollouts = CuriosityRolloutStorage(
            args.num_steps,
            args.num_processes,
            envs.observation_space.shape,
            envs.action_space,
            recurrent_hidden_state_size,
            actor_critic.base.feature_state_size(),
            args=args)
    else:
        rollouts = RolloutStorage(args.num_steps,
                                  args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  recurrent_hidden_state_size,
                                  args=args)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    model = actor_critic.base
    reset_eval = False
    plotter = None
    for j in range(past_steps, num_updates):
        if reset_eval:
            print('post eval reset')
            obs = envs.reset()
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            reset_eval = False
        if args.model == 'FractalNet' and args.drop_path:
            model.set_drop_path()
        if args.model == 'fixed' and model.RAND:
            model.num_recursions = random.randint(1, model.map_width * 2)
        if args.model == 'FractalNet':
            n_cols = model.n_cols
            if args.rule == 'wide1' and args.n_recs > 3:
                col_step = 3
            else:
                col_step = 1
        else:
            n_cols = 0
            col_step = 1
        player_act = None
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                if args.render:
                    if args.num_processes == 1:
                        if not ('Micropolis' in args.env_name
                                or 'GameOfLife' in args.env_name):
                            envs.venv.venv.render()
                        else:
                            pass
                    else:
                        if not ('Micropolis' in args.env_name
                                or 'GameOfLife' in args.env_name):
                            envs.render()
                            envs.venv.venv.render()
                        else:
                            pass
                        #envs.venv.venv.remotes[0].send(('render', None))
                        #envs.venv.venv.remotes[0].recv()
                value, action, action_log_probs, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                    player_act=player_act,
                    icm_enabled=args.curiosity,
                    deterministic=False)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            player_act = None
            if args.render:
                if infos[0]:
                    if 'player_move' in infos[0].keys():
                        player_act = infos[0]['player_move']
            if args.curiosity:
                # run icm
                with torch.no_grad():
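                    # NOTE: action_bin (presumably a one-hot encoding of the
                    # sampled action) is not defined in this snippet; it is
                    # assumed to come from elsewhere in the original source.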

                    feature_state, feature_state_pred, action_dist_pred = actor_critic.icm_act(
                        (rollouts.obs[step], obs, action_bin))

                intrinsic_reward = args.eta * (
                    (feature_state - feature_state_pred).pow(2)).sum() / 2.
                if args.no_reward:
                    reward = 0
                reward += intrinsic_reward.cpu()

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            if args.curiosity:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks,
                                feature_state, feature_state_pred, action_bin,
                                action_dist_pred)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = agent.update(
                rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n \
dist entropy {:.1f}, val/act loss {:.1f}/{:.1f},".format(
                    j, total_num_steps,
                    int((total_num_steps -
                         past_steps * args.num_processes * args.num_steps) /
                        (end - start)), len(episode_rewards),
                    np.mean(episode_rewards), np.median(episode_rewards),
                    np.min(episode_rewards), np.max(episode_rewards),
                    dist_entropy, value_loss, action_loss))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".format(
                    fwd_loss, inv_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args,
                                      actor_critic,
                                      device,
                                      envs=envs,
                                      vec_norm=vec_norm)

            model = evaluator.actor_critic.base

            col_idx = [-1, *range(0, n_cols, col_step)]
            for i in col_idx:
                evaluator.evaluate(column=i)
        #num_eval_frames = (args.num_frames // (args.num_steps * args.eval_interval * args.num_processes)) * args.num_processes *  args.max_step
        # making sure the evaluator plots the '-1'st column (the overall net)

            if args.vis and j % args.vis_interval == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    win_eval = evaluator.plotter.visdom_plot(
                        viz,
                        win_eval,
                        evaluator.eval_log_dir,
                        graph_name,
                        args.algo,
                        args.num_frames,
                        n_graphs=col_idx)
                except IOError:
                    pass
        #elif args.model == 'fixed' and model.RAND:
        #    for i in model.eval_recs:
        #        evaluator.evaluate(num_recursions=i)
        #    win_eval = visdom_plot(viz, win_eval, evaluator.eval_log_dir, graph_name,
        #                           args.algo, args.num_frames, n_graphs=model.eval_recs)
        #else:
        #    evaluator.evaluate(column=-1)
        #    win_eval = visdom_plot(viz, win_eval, evaluator.eval_log_dir, graph_name,
        #                  args.algo, args.num_frames)
            reset_eval = True

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
            save_model = copy.deepcopy(actor_critic)
            save_agent = copy.deepcopy(agent)
            if args.cuda:
                save_model.cpu()
            optim_save = save_agent.optimizer.state_dict()

            # experimental:
            torch.save({
                'past_steps': next(iter(
                    agent.optimizer.state_dict()['state'].values()))['step'],
                'model_state_dict': save_model.state_dict(),
                'optimizer_state_dict': optim_save,
                'ob_rms': ob_rms,
                'args': args,
            }, os.path.join(save_path, args.env_name + ".tar"))

        #save_model = [save_model,
        #              getattr(get_vec_normalize(envs), 'ob_rms', None)]

        #torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
        #save_agent = copy.deepcopy(agent)

        #torch.save(save_agent, os.path.join(save_path, args.env_name + '_agent.pt'))
        #torch.save(actor_critic.state_dict(), os.path.join(save_path, args.env_name + "_weights.pt"))

        if args.vis and j % args.vis_interval == 0:
            if plotter is None:
                plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                          args.algo, args.num_frames)
            except IOError:
                pass
def main():
    args = get_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(dump_dir):
        os.makedirs(dump_dir)

    logging.basicConfig(
        filename=log_dir + 'train.log',
        level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print(args)
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_eval_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")

    g_masks = torch.ones(num_scenes).float().to(device)

    best_g_reward = -np.inf

    if args.eval:
        episode_success = []
        episode_spl = []
        episode_dist = []
        for _ in range(args.num_processes):
            episode_success.append(deque(maxlen=num_episodes))
            episode_spl.append(deque(maxlen=num_episodes))
            episode_dist.append(deque(maxlen=num_episodes))

    else:
        episode_success = deque(maxlen=1000)
        episode_spl = deque(maxlen=1000)
        episode_dist = deque(maxlen=1000)

    finished = np.zeros((args.num_processes))
    wait_env = np.zeros((args.num_processes))

    g_episode_rewards = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    torch.set_grad_enabled(False)

    # Initialize map variables:
    # Full map consists of multiple channels containing the following:
    # 1. Obstacle Map
    # 2. Explored Area
    # 3. Current Agent Location
    # 4. Past Agent Locations
    # 5,6,7,.. : Semantic Categories
    nc = args.num_sem_categories + 4  # num channels
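    # For example, assuming num_sem_categories = 16 (an assumed value), nc = 20:
    # channel 0 is the obstacle map, 1 the explored area, 2 the current
    # location, 3 the past locations, and channels 4..19 hold one mask per
    # semantic category.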

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w = int(full_w / args.global_downscaling)
    local_h = int(full_h / args.global_downscaling)
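    # e.g. assuming map_size_cm = 2400, map_resolution = 5 (cm per cell) and
    # global_downscaling = 2, the full map is 480x480 cells and the local map
    # 240x240 cells.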

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, nc, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, nc, local_w,
                            local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    # Planner pose inputs have 7 dimensions
    # 1-3 store continuous global agent location
    # 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

    def get_local_map_boundaries(agent_loc, local_sizes, full_sizes):
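        # Center a local_w x local_h window on the agent's cell and clamp it so
        # it stays inside the full map; with global_downscaling == 1 the local
        # map is simply the full map.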
        loc_r, loc_c = agent_loc
        local_w, local_h = local_sizes
        full_w, full_h = full_sizes

        if args.global_downscaling > 1:
            gx1, gy1 = loc_r - local_w // 2, loc_c - local_h // 2
            gx2, gy2 = gx1 + local_w, gy1 + local_h
            if gx1 < 0:
                gx1, gx2 = 0, local_w
            if gx2 > full_w:
                gx1, gx2 = full_w - local_w, full_w

            if gy1 < 0:
                gy1, gy2 = 0, local_h
            if gy2 > full_h:
                gy1, gy2 = full_h - local_h, full_h
        else:
            gx1, gx2, gy1, gy2 = 0, full_w, 0, full_h

        return [gx1, gx2, gy1, gy2]

    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                            int(c * 100.0 / args.map_resolution)]

            full_map[e, 2:4, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                              (local_w, local_h),
                                              (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                          lmb[e][0] * args.map_resolution / 100.0, 0.]

        for e in range(num_scenes):
            local_map[e] = full_map[e, :,
                                    lmb[e, 0]:lmb[e, 1],
                                    lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                torch.from_numpy(origins[e]).to(device).float()

    def init_map_and_pose_for_env(e):
        full_map[e].fill_(0.)
        full_pose[e].fill_(0.)
        full_pose[e, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose[e].cpu().numpy()
        planner_pose_inputs[e, :3] = locs
        r, c = locs[1], locs[0]
        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)]

        full_map[e, 2:4, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

        lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                          (local_w, local_h),
                                          (full_w, full_h))

        planner_pose_inputs[e, 3:] = lmb[e]
        origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                      lmb[e][0] * args.map_resolution / 100.0, 0.]

        local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]]
        local_pose[e] = full_pose[e] - \
            torch.from_numpy(origins[e]).to(device).float()

    def update_intrinsic_rew(e):
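        # Intrinsic reward = newly explored area (channel 1) gained after
        # merging the local map back into the full map, converted from map
        # cells to m^2.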
        prev_explored_area = full_map[e, 1].sum(1).sum(0)
        full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
            local_map[e]
        curr_explored_area = full_map[e, 1].sum(1).sum(0)
        intrinsic_rews[e] = curr_explored_area - prev_explored_area
        intrinsic_rews[e] *= (args.map_resolution / 100.)**2  # to m^2

    init_map_and_pose()

    # Global policy observation space
    ngc = 8 + args.num_sem_categories
    es = 2
    g_observation_space = gym.spaces.Box(0, 1,
                                         (ngc,
                                          local_w,
                                          local_h), dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0, high=0.99,
                                    shape=(2,), dtype=np.float32)

    # Global policy recurrent layer size
    g_hidden_size = args.global_hidden_size

    # Semantic Mapping
    sem_map_module = Semantic_Mapping(args).to(device)
    sem_map_module.eval()

    # Global policy
    g_policy = RL_Policy(g_observation_space.shape, g_action_space,
                         model_type=1,
                         base_kwargs={'recurrent': args.use_recurrent_global,
                                      'hidden_size': g_hidden_size,
                                      'num_sem_categories': ngc - 8
                                      }).to(device)
    g_agent = algo.PPO(g_policy, args.clip_param, args.ppo_epoch,
                       args.num_mini_batch, args.value_loss_coef,
                       args.entropy_coef, lr=args.lr, eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    global_input = torch.zeros(num_scenes, ngc, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()
    intrinsic_rews = torch.zeros(num_scenes).to(device)
    extras = torch.zeros(num_scenes, 2)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps,
                                      num_scenes, g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      es).to(device)

    if args.load != "0":
        print("Loading model {}".format(args.load))
        state_dict = torch.load(args.load,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if args.eval:
        g_policy.eval()

    # Predict semantic map from frame 1
    poses = torch.from_numpy(np.asarray(
        [infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)])
    ).float().to(device)

    _, local_map, _, local_pose = \
        sem_map_module(obs, poses, local_map, local_pose)

    # Compute Global policy input
    locs = local_pose.cpu().numpy()
    global_input = torch.zeros(num_scenes, ngc, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)]

        local_map[e, 2:4, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map[:, 0:4, :, :].detach()
    global_input[:, 4:8, :, :] = nn.MaxPool2d(args.global_downscaling)(
        full_map[:, 0:4, :, :])
    global_input[:, 8:, :, :] = local_map[:, 4:, :, :].detach()
    goal_cat_id = torch.from_numpy(np.asarray(
        [infos[env_idx]['goal_cat_id'] for env_idx
         in range(num_scenes)]))

    extras = torch.zeros(num_scenes, 2)
    extras[:, 0] = global_orientation[:, 0]
    extras[:, 1] = goal_cat_id

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(extras)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )

    cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w), int(action[1] * local_h)]
                    for action in cpu_actions]
    global_goals = [[min(x, int(local_w - 1)), min(y, int(local_h - 1))]
                    for x, y in global_goals]
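    # The 2-D global action is squashed to (0, 1) with a sigmoid and scaled to
    # cell coordinates in the local map (clamped to valid indices); this cell
    # becomes the long-term goal handed to the local planner.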

    goal_maps = [np.zeros((local_w, local_h)) for _ in range(num_scenes)]

    for e in range(num_scenes):
        goal_maps[e][global_goals[e][0], global_goals[e][1]] = 1

    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
        p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]
        p_input['goal'] = goal_maps[e]  # global_goals[e]
        p_input['new_goal'] = 1
        p_input['found_goal'] = 0
        p_input['wait'] = wait_env[e] or finished[e]
        if args.visualize or args.print_images:
            local_map[e, -1, :, :] = 1e-5
            p_input['sem_map_pred'] = local_map[e, 4:, :, :
                                                ].argmax(0).cpu().numpy()

    obs, _, done, infos = envs.plan_act_and_preprocess(planner_inputs)

    start = time.time()
    g_reward = 0

    torch.set_grad_enabled(False)
    spl_per_category = defaultdict(list)
    success_per_category = defaultdict(list)

    for step in range(args.num_training_frames // args.num_processes + 1):
        if finished.sum() == args.num_processes:
            break

        g_step = (step // args.num_local_steps) % args.num_global_steps
        l_step = step % args.num_local_steps
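        # l_step indexes local planner steps within the current global step; a
        # new long-term goal is sampled from the global policy once every
        # num_local_steps environment steps (see the Global Policy block below).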

        # ------------------------------------------------------------------
        # Reinitialize variables when episode ends
        l_masks = torch.FloatTensor([0 if x else 1
                                     for x in done]).to(device)
        g_masks *= l_masks

        for e, x in enumerate(done):
            if x:
                spl = infos[e]['spl']
                success = infos[e]['success']
                dist = infos[e]['distance_to_goal']
                spl_per_category[infos[e]['goal_name']].append(spl)
                success_per_category[infos[e]['goal_name']].append(success)
                if args.eval:
                    episode_success[e].append(success)
                    episode_spl[e].append(spl)
                    episode_dist[e].append(dist)
                    if len(episode_success[e]) == num_episodes:
                        finished[e] = 1
                else:
                    episode_success.append(success)
                    episode_spl.append(spl)
                    episode_dist.append(dist)
                wait_env[e] = 1.
                update_intrinsic_rew(e)
                init_map_and_pose_for_env(e)
        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Semantic Mapping Module
        poses = torch.from_numpy(np.asarray(
            [infos[env_idx]['sensor_pose'] for env_idx
             in range(num_scenes)])
        ).float().to(device)

        _, local_map, _, local_pose = \
            sem_map_module(obs, poses, local_map, local_pose)

        locs = local_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs + origins
        local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                            int(c * 100.0 / args.map_resolution)]
            local_map[e, 2:4, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.

        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Global Policy
        if l_step == args.num_local_steps - 1:
            # For every global step, update the full and local maps
            for e in range(num_scenes):
                if wait_env[e] == 1:  # New episode
                    wait_env[e] = 0.
                else:
                    update_intrinsic_rew(e)

                full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                    local_map[e]
                full_pose[e] = local_pose[e] + \
                    torch.from_numpy(origins[e]).to(device).float()

                locs = full_pose[e].cpu().numpy()
                r, c = locs[1], locs[0]
                loc_r, loc_c = [int(r * 100.0 / args.map_resolution),
                                int(c * 100.0 / args.map_resolution)]

                lmb[e] = get_local_map_boundaries((loc_r, loc_c),
                                                  (local_w, local_h),
                                                  (full_w, full_h))

                planner_pose_inputs[e, 3:] = lmb[e]
                origins[e] = [lmb[e][2] * args.map_resolution / 100.0,
                              lmb[e][0] * args.map_resolution / 100.0, 0.]

                local_map[e] = full_map[e, :,
                                        lmb[e, 0]:lmb[e, 1],
                                        lmb[e, 2]:lmb[e, 3]]
                local_pose[e] = full_pose[e] - \
                    torch.from_numpy(origins[e]).to(device).float()

            locs = local_pose.cpu().numpy()
            for e in range(num_scenes):
                global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
            global_input[:, 0:4, :, :] = local_map[:, 0:4, :, :]
            global_input[:, 4:8, :, :] = \
                nn.MaxPool2d(args.global_downscaling)(
                    full_map[:, 0:4, :, :])
            global_input[:, 8:, :, :] = local_map[:, 4:, :, :].detach()
            goal_cat_id = torch.from_numpy(np.asarray(
                [infos[env_idx]['goal_cat_id'] for env_idx
                 in range(num_scenes)]))
            extras[:, 0] = global_orientation[:, 0]
            extras[:, 1] = goal_cat_id

            # Get exploration reward and metrics
            g_reward = torch.from_numpy(np.asarray(
                [infos[env_idx]['g_reward'] for env_idx in range(num_scenes)])
            ).float().to(device)
            g_reward += args.intrinsic_rew_coeff * intrinsic_rews.detach()

            g_process_rewards += g_reward.cpu().numpy()
            g_total_rewards = g_process_rewards * \
                (1 - g_masks.cpu().numpy())
            g_process_rewards *= g_masks.cpu().numpy()
            per_step_g_rewards.append(np.mean(g_reward.cpu().numpy()))

            if np.sum(g_total_rewards) != 0:
                for total_rew in g_total_rewards:
                    if total_rew != 0:
                        g_episode_rewards.append(total_rew)

            # Add samples to global policy storage
            if step == 0:
                g_rollouts.obs[0].copy_(global_input)
                g_rollouts.extras[0].copy_(extras)
            else:
                g_rollouts.insert(
                    global_input, g_rec_states,
                    g_action, g_action_log_prob, g_value,
                    g_reward, g_masks, extras
                )

            # Sample long-term goal from global policy
            g_value, g_action, g_action_log_prob, g_rec_states = \
                g_policy.act(
                    g_rollouts.obs[g_step + 1],
                    g_rollouts.rec_states[g_step + 1],
                    g_rollouts.masks[g_step + 1],
                    extras=g_rollouts.extras[g_step + 1],
                    deterministic=False
                )
            cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
            global_goals = [[int(action[0] * local_w),
                             int(action[1] * local_h)]
                            for action in cpu_actions]
            global_goals = [[min(x, int(local_w - 1)),
                             min(y, int(local_h - 1))]
                            for x, y in global_goals]

            g_reward = 0
            g_masks = torch.ones(num_scenes).float().to(device)

        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Update long-term goal if target object is found
        found_goal = [0 for _ in range(num_scenes)]
        goal_maps = [np.zeros((local_w, local_h)) for _ in range(num_scenes)]

        for e in range(num_scenes):
            goal_maps[e][global_goals[e][0], global_goals[e][1]] = 1

        for e in range(num_scenes):
            cn = infos[e]['goal_cat_id'] + 4
            if local_map[e, cn, :, :].sum() != 0.:
                cat_semantic_map = local_map[e, cn, :, :].cpu().numpy()
                cat_semantic_scores = cat_semantic_map
                cat_semantic_scores[cat_semantic_scores > 0] = 1.
                goal_maps[e] = cat_semantic_scores
                found_goal[e] = 1
        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Take action and get next observation
        planner_inputs = [{} for e in range(num_scenes)]
        for e, p_input in enumerate(planner_inputs):
            p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
            p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
            p_input['pose_pred'] = planner_pose_inputs[e]
            p_input['goal'] = goal_maps[e]  # global_goals[e]
            p_input['new_goal'] = l_step == args.num_local_steps - 1
            p_input['found_goal'] = found_goal[e]
            p_input['wait'] = wait_env[e] or finished[e]
            if args.visualize or args.print_images:
                local_map[e, -1, :, :] = 1e-5
                p_input['sem_map_pred'] = local_map[e, 4:, :,
                                                    :].argmax(0).cpu().numpy()

        obs, _, done, infos = envs.plan_act_and_preprocess(planner_inputs)
        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Training
        torch.set_grad_enabled(True)
        if g_step % args.num_global_steps == args.num_global_steps - 1 \
                and l_step == args.num_local_steps - 1:
            if not args.eval:
                g_next_value = g_policy.get_value(
                    g_rollouts.obs[-1],
                    g_rollouts.rec_states[-1],
                    g_rollouts.masks[-1],
                    extras=g_rollouts.extras[-1]
                ).detach()

                g_rollouts.compute_returns(g_next_value, args.use_gae,
                                           args.gamma, args.tau)
                g_value_loss, g_action_loss, g_dist_entropy = \
                    g_agent.update(g_rollouts)
                g_value_losses.append(g_value_loss)
                g_action_losses.append(g_action_loss)
                g_dist_entropies.append(g_dist_entropy)
            g_rollouts.after_update()

        torch.set_grad_enabled(False)
        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Logging
        if step % args.log_interval == 0:
            end = time.time()
            time_elapsed = time.gmtime(end - start)
            log = " ".join([
                "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                "num timesteps {},".format(step * num_scenes),
                "FPS {},".format(int(step * num_scenes / (end - start)))
            ])

            log += "\n\tRewards:"

            if len(g_episode_rewards) > 0:
                log += " ".join([
                    " Global step mean/med rew:",
                    "{:.4f}/{:.4f},".format(
                        np.mean(per_step_g_rewards),
                        np.median(per_step_g_rewards)),
                    " Global eps mean/med/min/max eps rew:",
                    "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                        np.mean(g_episode_rewards),
                        np.median(g_episode_rewards),
                        np.min(g_episode_rewards),
                        np.max(g_episode_rewards))
                ])

            if args.eval:
                total_success = []
                total_spl = []
                total_dist = []
                for e in range(args.num_processes):
                    for acc in episode_success[e]:
                        total_success.append(acc)
                    for dist in episode_dist[e]:
                        total_dist.append(dist)
                    for spl in episode_spl[e]:
                        total_spl.append(spl)

                if len(total_spl) > 0:
                    log += " ObjectNav succ/spl/dtg:"
                    log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
                        np.mean(total_success),
                        np.mean(total_spl),
                        np.mean(total_dist),
                        len(total_spl))
            else:
                if len(episode_success) > 100:
                    log += " ObjectNav succ/spl/dtg:"
                    log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
                        np.mean(episode_success),
                        np.mean(episode_spl),
                        np.mean(episode_dist),
                        len(episode_spl))

            log += "\n\tLosses:"
            if len(g_value_losses) > 0 and not args.eval:
                log += " ".join([
                    " Policy Loss value/action/dist:",
                    "{:.3f}/{:.3f}/{:.3f},".format(
                        np.mean(g_value_losses),
                        np.mean(g_action_losses),
                        np.mean(g_dist_entropies))
                ])

            print(log)
            logging.info(log)
        # ------------------------------------------------------------------

        # ------------------------------------------------------------------
        # Save best models
        if (step * num_scenes) % args.save_interval < num_scenes:
            if len(g_episode_rewards) >= 1000 and \
                    (np.mean(g_episode_rewards) >= best_g_reward) \
                    and not args.eval:
                torch.save(g_policy.state_dict(),
                           os.path.join(log_dir, "model_best.pth"))
                best_g_reward = np.mean(g_episode_rewards)

        # Save periodic models
        if (step * num_scenes) % args.save_periodic < num_scenes:
            total_steps = step * num_scenes
            if not args.eval:
                torch.save(g_policy.state_dict(),
                           os.path.join(dump_dir,
                                        "periodic_{}.pth".format(total_steps)))
        # ------------------------------------------------------------------

    # Print and save model performance numbers during evaluation
    if args.eval:
        print("Dumping eval details...")
        
        total_success = []
        total_spl = []
        total_dist = []
        for e in range(args.num_processes):
            for acc in episode_success[e]:
                total_success.append(acc)
            for dist in episode_dist[e]:
                total_dist.append(dist)
            for spl in episode_spl[e]:
                total_spl.append(spl)

        if len(total_spl) > 0:
            log = "Final ObjectNav succ/spl/dtg:"
            log += " {:.3f}/{:.3f}/{:.3f}({:.0f}),".format(
                np.mean(total_success),
                np.mean(total_spl),
                np.mean(total_dist),
                len(total_spl))

        print(log)
        logging.info(log)
            
        # Save the spl per category
        log = "Success | SPL per category\n"
        for key in success_per_category:
            log += "{}: {} | {}\n".format(key,
                                          sum(success_per_category[key]) /
                                          len(success_per_category[key]),
                                          sum(spl_per_category[key]) /
                                          len(spl_per_category[key]))

        print(log)
        logging.info(log)

        with open('{}/{}_spl_per_cat_pred_thr.json'.format(
                dump_dir, args.split), 'w') as f:
            json.dump(spl_per_category, f)

        with open('{}/{}_success_per_cat_pred_thr.json'.format(
                dump_dir, args.split), 'w') as f:
            json.dump(success_per_category, f)
Exemple #11
0
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    run_id = "alpha{}".format(args.gcn_alpha)
    if args.use_logger:
        from utils import Logger
        folder = "{}/{}".format(args.folder, run_id)
        logger = Logger(algo_name=args.algo,
                        environment_name=args.env_name,
                        folder=folder,
                        seed=args.seed)
        logger.save_args(args)

        print("---------------------------------------")
        print('Saving to', logger.save_folder)
        print("---------------------------------------")

    else:
        print("---------------------------------------")
        print('NOTE : NOT SAVING RESULTS')
        print("---------------------------------------")
    all_rewards = []

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          args.env_name,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              actor_critic.base.output_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    ############################
    # GCN Model and optimizer
    from pygcn.train import update_graph
    from pygcn.models import GCN, GAT, SAGE
    assert args.gnn in ['gcn', 'gat', 'sage']

    if args.gnn == 'gat':
        gcn_model = GAT(nfeat=actor_critic.base.output_size,
                        nhid=args.gcn_hidden)
    elif args.gnn == 'sage':
        gcn_model = SAGE(nfeat=actor_critic.base.output_size,
                         nhid=args.gcn_hidden)
    elif args.gnn == 'gcn':
        gcn_model = GCN(nfeat=actor_critic.base.output_size,
                        nhid=args.gcn_hidden)

    gcn_model.to(device)
    gcn_optimizer = optim.Adam(gcn_model.parameters(),
                               lr=args.gcn_lr,
                               weight_decay=args.gcn_weight_decay)
    gcn_loss = nn.NLLLoss()
    gcn_states = [[] for _ in range(args.num_processes)]
    Gs = [nx.Graph() for _ in range(args.num_processes)]
    node_ptrs = [0 for _ in range(args.num_processes)]
    rew_states = [[] for _ in range(args.num_processes)]
    ############################

    episode_rewards = deque(maxlen=100)
    avg_fwdloss = deque(maxlen=100)
    rew_rms = RunningMeanStd(shape=())
    delay_rew = torch.zeros([args.num_processes, 1])
    delay_step = torch.zeros([args.num_processes])

    start = time.time()
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                (value, action, action_log_prob,
                 recurrent_hidden_states, hidden_states) = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            delay_rew += reward
            delay_step += 1

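            # Delayed-reward bookkeeping: the accumulated reward is released only
            # every args.reward_freq steps or at episode end; at all other steps
            # the agent receives a reward of 0.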
            for idx, (info, hid,
                      eps_done) in enumerate(zip(infos, hidden_states, done)):

                if eps_done or delay_step[idx] == args.reward_freq:
                    reward[idx] = delay_rew[idx]
                    delay_rew[idx] = delay_step[idx] = 0
                else:
                    reward[idx] = 0

                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

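                # Build a per-process trajectory graph: each hidden state becomes a
                # node, consecutive steps are linked by an edge, and nodes where a
                # reward is released (or the episode ends) are recorded. At episode
                # end the GCN is updated on this graph and the buffers are reset.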
                if args.gcn_alpha < 1.0:
                    gcn_states[idx].append(hid)
                    node_ptrs[idx] += 1
                    if not eps_done:
                        Gs[idx].add_edge(node_ptrs[idx] - 1, node_ptrs[idx])
                    if reward[idx] != 0. or eps_done:
                        rew_states[idx].append(
                            [node_ptrs[idx] - 1, reward[idx]])
                    if eps_done:
                        adj = nx.adjacency_matrix(Gs[idx]) if len(Gs[idx].nodes) \
                            else sp.csr_matrix(np.eye(1, dtype='int64'))
                        update_graph(gcn_model, gcn_optimizer,
                                     torch.stack(gcn_states[idx]), adj,
                                     rew_states[idx], gcn_loss, args, envs)
                        gcn_states[idx] = []
                        Gs[idx] = nx.Graph()
                        node_ptrs[idx] = 0
                        rew_states[idx] = []

            # If done then clean the history of observations.
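            # (A mask of 0.0 stops return bootstrapping and resets the recurrent
            # state for the corresponding process.)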
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks,
                            hidden_states)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau, gcn_model, args.gcn_alpha)
        agent.update(rollouts)
        rollouts.after_update()

        ####################### Saving and book-keeping #######################
        if (j % int(num_updates / 5.) == 0
                or j == num_updates - 1) and args.save_dir != "":
            print('Saving model')
            print()

            save_dir = "{}/{}/{}".format(args.save_dir, args.folder, run_id)
            save_path = os.path.join(save_dir, args.algo, 'seed' +
                                     str(args.seed)) + '_iter' + str(j)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            save_gcn = gcn_model
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()
                save_gcn = copy.deepcopy(gcn_model).cpu()

            save_model = [
                save_gcn, save_model,
                hasattr(envs.venv, 'ob_rms') and envs.venv.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + "ac.pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print("Updates {}, num timesteps {}, FPS {} \n Last {}\
             training episodes: mean/median reward {:.2f}/{:.2f},\
              min/max reward {:.2f}/{:.2f}, success rate {:.2f}\n".format(
                j,
                total_num_steps,
                int(total_num_steps / (end - start)),
                len(episode_rewards),
                np.mean(episode_rewards),
                np.median(episode_rewards),
                np.min(episode_rewards),
                np.max(episode_rewards),
                np.count_nonzero(np.greater(episode_rewards, 0)) /
                len(episode_rewards),
            ))

            all_rewards.append(np.mean(episode_rewards))
            if args.use_logger:
                logger.save_task_results(all_rewards)
        ####################### Saving and book-keeping #######################

    envs.close()
Exemple #12
0
def main():
    writer = SummaryWriter()
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    best_score = 0

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False, 4, args.carl_wrapper)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          args.activation,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    assert (args.algo == 'a2c')
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    beta_device = (torch.ones(args.num_processes, 1)).to(device)
    masks_device = torch.ones(args.num_processes, 1).to(device)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    obs = obs / 255
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    g_step = 0
    for j in range(num_updates):
        for step in range(args.num_steps):
            # sample actions
            g_step += 1
            eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
                -1. * g_step / EPS_DECAY)
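            # eps_threshold decays exponentially from EPS_START toward EPS_END with
            # time constant EPS_DECAY, measured in global steps.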
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, ori_dist_entropy = actor_critic.act(
                    rollouts.obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                    deterministic=True)
            ori_dist_entropy = ori_dist_entropy.cpu().unsqueeze(1)

            # select action based on epsilon greedy
            rand_val = torch.rand(action.shape).to(device)
            eps_mask = (rand_val >= eps_threshold).type(torch.int64)
            rand_action = torch.LongTensor([
                envs.action_space.sample() for i in range(args.num_processes)
            ]).unsqueeze(1).to(device)
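            # eps_mask is 1 where the policy action is kept and 0 where the random
            # action is substituted (epsilon-greedy exploration).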
            action = eps_mask * action + (1 - eps_mask) * rand_action
            obs, reward, done, infos = envs.step(action)
            obs = obs / 255

            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            if args.log_evaluation:
                writer.add_scalar('analysis/reward', reward[0], g_step)
                writer.add_scalar('analysis/entropy',
                                  ori_dist_entropy[0].item(), g_step)
                writer.add_scalar('analysis/eps', eps_threshold, g_step)
                if done[0]:
                    writer.add_scalar('analysis/done', 1, g_step)

            # save model
            for idx in range(len(infos)):
                info = infos[idx]
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    steps_done = g_step * args.num_processes + idx
                    writer.add_scalar('data/reward', info['episode']['r'],
                                      steps_done)
                    mean_rewards = np.mean(episode_rewards)
                    writer.add_scalar('data/avg_reward', mean_rewards,
                                      steps_done)
                    if mean_rewards > best_score:
                        best_score = mean_rewards
                        save_model = actor_critic
                        if args.cuda:
                            save_model = copy.deepcopy(actor_critic).cpu()
                        torch.save(
                            save_model,
                            os.path.join(save_path, args.env_name + ".pt"))

            # update storage
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, beta_device)

        with torch.no_grad():
            masks_device.copy_(masks)
            next_value = actor_critic.get_value(obs, recurrent_hidden_states,
                                                masks_device)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Exemple #13
0
def main():
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    os.environ['OMP_NUM_THREADS'] = '1'

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    names = getListOfGames("train")

    envs = [make_env_train(names[i], args.seed, i, args.log_dir)
                for i in range(len(names))]
                
    # TODO: note that args.num_processes is overridden here to match the number of training games
    args.num_processes = len(envs)
    # REMEMBER: this was changed!

    if len(envs) > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    obs_shape = envs.observation_space.shape
    #print(obs_shape)
    obs_shape = (obs_shape[0], *obs_shape[1:])
    #print(obs_shape)

    if len(envs.observation_space.shape) == 3:
        actor_critic = CNNPolicy(obs_shape[0], envs.action_space, args.recurrent_policy)
    else:
        assert not args.recurrent_policy, \
            "Recurrent policy is not implemented for the MLP controller"
        actor_critic = MLPPolicy(obs_shape[0], envs.action_space)

    # Make it parallel
    actor_critic = torch.nn.parallel.DataParallel(actor_critic).module

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, lr=args.lr,
                               eps=args.eps, alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch,
                         args.value_loss_coef, args.entropy_coef, lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
        # Make agent DataParallel
        agent = torch.nn.parallel.DataParallel(agent).module

    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, acktr=True)

    # Make rollouts DataParallel
    rollouts = torch.nn.parallel.DataParallel(RolloutStorage(args.num_steps, args.num_processes, obs_shape, envs.action_space, actor_critic.state_size)).module
    current_obs = torch.nn.parallel.DataParallel(torch.zeros(envs.nenvs, *obs_shape)).module

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        # if args.num_stack > 1:
        #     current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            value, action, action_log_prob, states = actor_critic.act(
                    Variable(rollouts.observations[step], volatile=True),
                    Variable(rollouts.states[step], volatile=True),
                    Variable(rollouts.masks[step], volatile=True))
            cpu_actions = action.data.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states.data, action.data, action_log_prob.data, value.data, reward, masks)

        next_value = actor_critic.get_value(Variable(rollouts.observations[-1], volatile=True),
                                            Variable(rollouts.states[-1], volatile=True),
                                            Variable(rollouts.masks[-1], volatile=True)).data

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [save_model,
                            hasattr(envs, 'ob_rms') and envs.ob_rms or None]

            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                format(j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       final_rewards.mean(),
                       final_rewards.median(),
                       final_rewards.min(),
                       final_rewards.max(), dist_entropy.data[0],
                       value_loss.data[0], action_loss.data[0]))
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
Exemple #14
0
def main():

    print('Preparing parameters')

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    print('Creating envs: {}'.format(args.env_name))

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    # input(envs)
    print('Creating network')
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    print('Initializing PPO')
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)

    print('Memory')
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    num_episodes = [0 for _ in range(args.num_processes)]

    if args.run_id == "debug":
        try:
            shutil.rmtree('./runs/debug')
        except OSError:
            pass

    writer = SummaryWriter("./runs/{}".format(args.run_id))
    with open('./runs/{}/recap.txt'.format(args.run_id), 'w') as file:
        file.write(str(actor_critic))

    last_index = 0

    print('Starting ! ')

    start = time.time()
    for j in tqdm(range(num_updates)):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step], rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info_num, info in enumerate(infos):
                if (info_num == 0):
                    if 'episode' in info.keys():
                        episode_rewards.append(info['episode']['r'])
                        end_episode_to_viz(writer, info, info_num,
                                           num_episodes[info_num])
                        num_episodes[info_num] += 1

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, action, action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        losses = agent.update(rollouts)
        rollouts.after_update()

        losses_to_viz(writer, losses, j)
        create_checkpoint(actor_critic, envs, args)
        last_index = global_rew_to_viz(writer, last_index)
Exemple #15
0
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    print('args.lr')
    print(args.lr)

    #     print('args.stat_decay')
    #     print(args.stat_decay)

    #     sys.exit()

    if args.algo == 'a2c':

        #         print('args.eps')
        #         print(args.eps)

        #         sys.exit()

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo in ['acktr']:
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               acktr=True,
                               stat_decay=args.stat_decay)
    elif args.algo in ['acktr-homo']:
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               acktr=True,
                               if_homo=True,
                               stat_decay=args.stat_decay)
    elif args.algo in ['acktr-homo-noEigen']:
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               acktr=True,
                               if_homo=True,
                               stat_decay=args.stat_decay,
                               if_eigen=False)
    elif args.algo in ['kbfgs']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               stat_decay=args.stat_decay)
    elif args.algo in ['kbfgs-homo']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               if_homo=True,
                               stat_decay=args.stat_decay)
    elif args.algo in ['kbfgs-homo-invertA']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               if_homo=True,
                               stat_decay=args.stat_decay,
                               if_invert_A=True)

    elif args.algo in ['kbfgs-homo-invertA-decoupledDecay']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               if_homo=True,
                               stat_decay_A=args.stat_decay_A,
                               stat_decay_G=args.stat_decay_G,
                               if_invert_A=True,
                               if_decoupled_decay=True)
    elif args.algo in ['kbfgs-homo-momentumGrad']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               if_homo=True,
                               if_momentumGrad=True,
                               stat_decay=args.stat_decay)
    elif args.algo in ['kbfgs-homo-noClip']:

        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               kbfgs=True,
                               if_homo=True,
                               if_clip=False,
                               stat_decay=args.stat_decay)
    else:
        print('unknown args.algo for ' + args.algo)
        sys.exit()

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    record_rewards = []

    record_num_steps = []

    print('num_updates')
    print(num_updates)

    total_num_steps = 0

    start = time.time()
    for j in range(num_updates):

        print('j')
        print(j)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:

                #                 print('info.keys()')
                #                 print(info.keys())

                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

                    print('info[episode][r]')
                    print(info['episode']['r'])

                    record_rewards.append(info['episode']['r'])

                    #                     print('total_num_steps')
                    #                     print(total_num_steps)

                    #                     print('total_num_steps + (step + 1) * args.num_processes')
                    #                     print(total_num_steps + (step + 1) * args.num_processes)

                    record_num_steps.append(total_num_steps +
                                            (step + 1) * args.num_processes)

#                     sys.exit()

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy, update_signal = agent.update(
            rollouts)

        if update_signal == -1:
            #             sys.exit()
            break

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                getattr(get_vec_normalize(envs), 'ob_rms', None)
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            eval_envs = make_vec_envs(args.env_name,
                                      args.seed + args.num_processes,
                                      args.num_processes, args.gamma,
                                      eval_log_dir, args.add_timestep, device,
                                      True)

            vec_norm = get_vec_normalize(eval_envs)
            if vec_norm is not None:
                vec_norm.eval()
                vec_norm.ob_rms = get_vec_normalize(envs).ob_rms

            eval_episode_rewards = []

            obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args.num_processes, 1, device=device)

            while len(eval_episode_rewards) < 10:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                obs, reward, done, infos = eval_envs.step(action)

                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])
                for info in infos:
                    if 'episode' in info.keys():
                        eval_episode_rewards.append(info['episode']['r'])

            eval_envs.close()

            print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
                len(eval_episode_rewards), np.mean(eval_episode_rewards)))

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass

    print('record_rewards')
    print(record_rewards)

    dir_with_params = (args.env_name + '/' +
                       args.algo + '/' +
                       'eps_' + str(args.eps) + '/' +
                       'lr_' + str(args.lr) + '/' +
                       'stat_decay_' + str(args.stat_decay) + '/')

    #     saving_dir = './result/' + args.env_name + '/' + args.algo + '/'
    saving_dir = './result/' + dir_with_params

    if not os.path.isdir(saving_dir):
        os.makedirs(saving_dir)

    import pickle

    with open(saving_dir + 'result.pkl', 'wb') as handle:
        pickle.dump(
            {
                'record_rewards': record_rewards,
                'record_num_steps': record_num_steps
            }, handle)

    print('args.log_dir')
    print(args.log_dir)

    print('os.listdir(args.log_dir)')
    print(os.listdir(args.log_dir))

    #     saving_dir_monitor = './result_monitor/' + args.env_name + '/' + args.algo + '/'

    saving_dir_monitor = './result_monitor/' + dir_with_params

    if os.path.isdir(saving_dir_monitor):
        import shutil

        shutil.rmtree(saving_dir_monitor)

    if not os.path.isdir(saving_dir_monitor):
        os.makedirs(saving_dir_monitor)

    print('saving_dir_monitor')
    print(saving_dir_monitor)

    import shutil

    for file_name in os.listdir(args.log_dir):

        full_file_name = os.path.join(args.log_dir, file_name)

        print('full_file_name')
        print(full_file_name)

        print('os.path.isfile(full_file_name)')
        print(os.path.isfile(full_file_name))

        if os.path.isfile(full_file_name):
            shutil.copy(full_file_name, saving_dir_monitor)

#     print('os.listdir(saving_dir_monitor)')
#     print(os.listdir(saving_dir_monitor))

#     print('len(os.listdir(saving_dir_monitor))')
#     print(len(os.listdir(saving_dir_monitor)))

#     print('args.num_processes')
#     print(args.num_processes)

    assert len(os.listdir(saving_dir_monitor)) == args.num_processes
Exemple #16
0
def train_a_gym_model(env, config):
    """We train gym-type RL problem using ppo given environment and configuration"""
    torch.set_num_threads(1)

    seed = config.get('seed', 1)
    log_dir = config.get('log_dir', '/tmp/gym')
    log_interval = config.get('log_interval', 10)
    save_interval = config.get('save_interval', 100)
    save_dir = config.get('save_dir', 'trained_models/ppo')
    add_timestep = config.get('add_timestep', False)
    num_processes = config.get('num_processes', 4)
    gamma = config.get('gamma', 0.99)
    num_stack = config.get('num_stack', 1)
    recurrent_policy = config.get('recurrent_policy', False)
    cuda = config.get('cuda', True)
    vis = config.get('vis', True)
    vis_interval = config.get('vis_interval', 100)
    env_name = config['env_name']
    save_step = config.get('save_step', None)
    if save_step is not None:
        next_save_step = save_step

    # clean the log folder, if necessary
    try:
        os.makedirs(log_dir)
    except OSError:
        files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)

    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    if vis:
        from visdom import Visdom
        port = config.get('port', 8097)
        viz = Visdom(port=port)
        win = None

    envs = [make_env(env, seed, i, log_dir, add_timestep)
            for i in range(num_processes)]

    if num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs, gamma=gamma)

    obs_shape = envs.observation_space.shape
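    # Frame stacking: the channel dimension is multiplied by num_stack.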
    obs_shape = (obs_shape[0] * num_stack, *obs_shape[1:])

    actor_critic = Policy(obs_shape, envs.action_space, recurrent_policy)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if cuda:
        actor_critic.cuda()

    clip_param = config.get('clip_param', 0.2)
    ppo_epoch = config.get('ppo_epoch', 4)
    num_mini_batch = config.get('num_mini_batch', 32)
    value_loss_coef = config.get('value_loss_coef', 0.5)
    entropy_coef = config.get('entropy_coef', 0.01)
    lr = config.get('lr', 1e-3)
    eps = config.get('eps', 1e-5)
    max_grad_norm = config.get('max_grad_norm', 0.5)
    use_gae = config.get('use_gae', False)
    tau = config.get('tau', 0.95)
    num_steps = config.get('num_steps', 100)
    num_frames = config.get('num_frames', 1e6)

    num_updates = int(num_frames) // num_steps // num_processes
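    # e.g. with the defaults above: 1e6 frames / (100 steps * 4 processes) = 2500 updates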

    agent = algo.PPO(actor_critic, clip_param, ppo_epoch, num_mini_batch,
                     value_loss_coef, entropy_coef, lr=lr,
                     eps=eps,
                     max_grad_norm=max_grad_norm)

    rollouts = RolloutStorage(num_steps, num_processes, obs_shape, envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(num_processes, *obs_shape)

    obs = envs.reset()
    update_current_obs(obs, current_obs, obs_shape, num_stack)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([num_processes, 1])
    final_rewards = torch.zeros([num_processes, 1])

    if cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    def save_the_model(num=None):
        """num is additional information"""
        # save it after training
        save_path = save_dir
        try:
            os.makedirs(save_path)
        except OSError:
            pass
        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if cuda:
            save_model = copy.deepcopy(actor_critic).cpu()
        save_model = [save_model,
                      hasattr(envs, 'ob_rms') and envs.ob_rms or None]
        if num is None:
            save_name = '%s.pt' % env_name
        else:
            save_name = '%s_at_%d.pt' % (env_name, int(num))
        torch.save(save_model, os.path.join(save_path, save_name))

    start = time.time()
    for j in range(1, 1 + num_updates):
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs, current_obs, obs_shape, num_stack)
            rollouts.insert(current_obs, states, action, action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % save_interval == 0 and save_dir != "":
            save_the_model()
            if save_step is not None:
                total_num_steps = j * num_processes * num_steps
                if total_num_steps > next_save_step:
                    save_the_model(total_num_steps)
                    next_save_step += save_step

        if j % log_interval == 0:
            end = time.time()
            total_num_steps = j * num_processes * num_steps
            print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                  format(j, total_num_steps,
                         int(total_num_steps / (end - start)),
                         final_rewards.mean(),
                         final_rewards.median(),
                         final_rewards.min(),
                         final_rewards.max(), dist_entropy,
                         value_loss, action_loss))
        if vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, log_dir, env_name,
                                  'ppo', num_frames)
            except IOError:
                pass
    # finally save model again
    save_the_model()
Exemple #17
0
        def run(self, cur_time, S_time_interval, S_send_data_size, S_chunk_len, S_rebuf, S_buffer_size, S_play_time_len,
                S_end_delay, S_decision_flag, S_buffer_flag, S_cdn_flag, end_of_video, cdn_newest_id, download_id,
                cdn_has_frame, IntialVars):
            torch.set_num_threads(1)
            device = torch.device("cuda:0" if args.cuda else "cpu")

            if args.vis:
                from visdom import Visdom
                viz = Visdom(port=args.port)
                win = None

            # The online env in AItrans; it should expose an observation space, an
            # action space, and so on. We should dig into envs.py in the GitHub repo
            # and extract the format of the observation and action spaces.
            envs = None  # TODO: construct the AItrans online environment here

            actor_critic = Policy(envs.observation_space.shape, envs.action_space,
                                  base_kwargs={'recurrent': args.recurrent_policy})
            actor_critic.to(device)

            # choose the algorithm, now we only have a2c
            if args.algo == 'a2c':
                agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                                       args.entropy_coef, lr=args.lr,
                                       eps=args.eps, alpha=args.alpha,
                                       max_grad_norm=args.max_grad_norm)
            elif args.algo == 'ppo':
                agent = algo.PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch,
                                 args.value_loss_coef, args.entropy_coef, lr=args.lr,
                                 eps=args.eps,
                                 max_grad_norm=args.max_grad_norm)
            elif args.algo == 'acktr':
                agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                                       args.entropy_coef, acktr=True)

            rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                      envs.observation_space.shape, envs.action_space,
                                      actor_critic.recurrent_hidden_state_size)

            # the initial observation
            obs = None  # TODO: obtain the initial observation from the online env
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)

            episode_reward = deque(maxlen=10)
            start = time.time()
            for j in range(num_updates):

                if args.use_linear_lr_decay:
                    # decrease learning rate linearly
                    if args.algo == "acktr":
                        # use optimizer's learning rate since it's hard-coded in kfac.py
                        update_linear_schedule(agent.optimizer, j, num_updates, agent.optimizer.lr)
                    else:
                        update_linear_schedule(agent.optimizer, j, num_updates, args.lr)

                if args.algo == 'ppo' and args.use_linear_lr_decay:
                    agent.clip_param = args.clip_param * (1 - j / float(num_updates))

                for step in range(args.num_steps):
                    # Sample actions
                    with torch.no_grad():
                        value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                            rollouts.obs[step],
                            rollouts.recurrent_hidden_states[step],
                            rollouts.masks[step])
Exemple #18
0
def main():
    writer = SummaryWriter()
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    best_score = 0

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    if args.reward_mode == 0:
        clip_rewards = True
    else:
        clip_rewards = False
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                        args.gamma, args.log_dir, args.add_timestep, device, False, 4, args.carl_wrapper, clip_rewards, args.track_primitive_reward)

    actor_critic = Policy(envs.observation_space.shape, envs.action_space, args.activation, args.complex_model,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, lr=args.lr,
                               eps=args.eps, alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch,
                         args.value_loss_coef, args.entropy_coef, lr=args.lr,
                               eps=args.eps,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic, args.value_loss_coef,
                               args.entropy_coef, acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                        envs.observation_space.shape, envs.action_space,
                        actor_critic.recurrent_hidden_state_size)

    # initialize the env and rollout storage
    obs = envs.reset()
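    # normalize raw uint8 pixel observations to [0, 1] before storing them in the rollout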
    obs = obs/255
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # bookkeeping variables
    episode_rewards = deque(maxlen=10) # store last 10 episode rewards
    g_step = 0 # global step
    reward_history = set() # record reward history (after reward rescaling)
    primitive_reward_history = set() # record original reward history (before reward rescaling)
    min_abs_reward = float('inf') # used in reward rescaling mode 2, works as the base
    masks_device = torch.ones(args.num_processes, 1).to(device)  # masks kept on the GPU
    reward_count = 0 # for reward density calculation
    reward_start_step = 0 # for reward density calculation
    insert_entropy = torch.ones(args.num_processes, 1)  # entropies inserted into the rollout
    avg_entropy = 0
    have_done = 0.0

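    # total feature-layer activations tracked per step (512 is assumed to be the width of the policy's final hidden layer)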
    num_feature_neurons = args.num_processes * 512
    for j in range(num_updates):
        if j == int((num_updates-1)*have_done):
            if args.save_intermediate_model:
                save_model = actor_critic
                if args.cuda:
                    save_model = copy.deepcopy(actor_critic).cpu()
                torch.save(save_model, os.path.join(save_path, args.env_name + str(have_done)+".pt")) 
            print("have done: ", have_done)
            have_done += 0.1

        for step in range(args.num_steps):
            # Sample actions
            g_step += 1
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, entropy, f_a = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

            if args.track_hidden_stats:
                # analyze the stats of f_a 
                mean_fa = torch.mean(f_a)
                num_nonzero = f_a.nonzero().size(0)
                mean_pos = mean_fa * num_feature_neurons / num_nonzero
                activation_ratio = f_a / mean_pos
                num_bigger_mean_fa = torch.sum(activation_ratio > 1).item()
                num_bigger_half_fa = torch.sum(activation_ratio > 0.5).item()
                writer.add_scalar('analysis/fa_mean_ratio', (num_nonzero - num_bigger_mean_fa)/num_nonzero, g_step)
                writer.add_scalar('analysis/fa_0.5_ratio', (num_nonzero - num_bigger_half_fa)/num_nonzero, g_step)
                writer.add_scalar('analysis/fa_active', num_nonzero/num_feature_neurons, g_step)

                # analyze the stats of entropy
                avg_entropy = 0.999*avg_entropy + 0.001*torch.mean(entropy).item()
                num_all = len(entropy.view(-1))
                entropy_ratio = entropy/avg_entropy
                num_larger_mean = sum(entropy_ratio > 1).item()
                num_larger_onehalf = sum(entropy_ratio > 1.5).item()
                num_larger_double = sum(entropy_ratio > 2).item()
                writer.add_scalar('analysis/entropy_mean_ratio', num_larger_mean/num_all, g_step)
                writer.add_scalar('analysis/entropy_1.5_ratio', num_larger_onehalf/num_all, g_step)
                writer.add_scalar('analysis/entropy_2_ratio', num_larger_double/num_all, g_step)

            # update entropy inserted into rollout when appropriate 
            if args.modulation and j > args.start_modulate * num_updates:
                insert_entropy = entropy.unsqueeze(1)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            obs = obs/255

            # reward rescaling
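            # mode 1: multiply by a fixed scale; mode 2: divide by the smallest non-zero |reward| seen so far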
            if args.reward_mode == 1:
                reward = reward * args.reward_scale
            elif args.reward_mode == 2:
                if j < args.change_base_reward * num_updates:
                    non_zeros = abs(reward[reward != 0])
                    if len(non_zeros) > 0:
                        min_abs_reward_step = torch.min(non_zeros).item()
                        if min_abs_reward > min_abs_reward_step:
                            min_abs_reward = min_abs_reward_step
                            print('new min abs reward: ', min_abs_reward, ' time: ', g_step)
                if min_abs_reward != float('inf'):
                    reward = reward/min_abs_reward

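            # masks are 1 while an episode is running and 0 at termination, so returns stop bootstrapping across episodes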
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            if args.log_evaluation:
                writer.add_scalar('analysis/entropy', entropy.mean().item(), g_step)
                if args.track_reward_density:   # track reward density, based on 0th process
                    reward_count += (reward[0] != 0)
                    if 'episode' in infos[0].keys():
                        writer.add_scalar('analysis/reward_density', reward_count/(g_step - reward_start_step), g_step)
                        reward_count = 0
                        reward_start_step = g_step
                if args.track_primitive_reward:   # track primitive reward (before rescaling)
                    for info in infos:
                        if 'new_reward' in info:
                            new_rewards = info['new_reward'] - primitive_reward_history
                            if len(new_rewards) > 0:
                                print('new primitive rewards: ', new_rewards, ' time: ', g_step)
                                primitive_reward_history = primitive_reward_history.union(info['new_reward'])
                if args.track_scaled_reward:  # track rewards after rescaling
                    for r in reward:
                        r = r.item()
                        if r not in reward_history:
                            print('new step rewards: ', r, g_step)
                            reward_history.add(r)


            for idx in range(len(infos)):
                info = infos[idx]
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    steps_done = g_step*args.num_processes + idx
                    writer.add_scalar('data/reward', info['episode']['r'], steps_done)
                    mean_rewards = np.mean(episode_rewards)
                    writer.add_scalar('data/avg_reward', mean_rewards, steps_done)
                    if mean_rewards > best_score:
                        best_score = mean_rewards
                        save_model = actor_critic
                        if args.cuda:
                            save_model = copy.deepcopy(actor_critic).cpu()
                        torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))                        
            rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks, insert_entropy)

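        # bootstrap the value of the last observation to close off the truncated rollout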
        with torch.no_grad():
            masks_device.copy_(masks)
            next_value = actor_critic.get_value(obs, recurrent_hidden_states, masks_device)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

        value_loss, action_loss, dist_entropy, value = agent.update(rollouts, args.modulation)

        if args.track_value_loss:
            writer.add_scalar('analysis/value_loss', value_loss, j)
            writer.add_scalar('analysis/value', value, j)
            writer.add_scalar('analysis/loss_ratio', value_loss/value, j)

        if args.modulation and args.track_lr and args.log_evaluation:
            writer.add_scalar('analysis/min_lr', torch.min(rollouts.lr).item(), j)
            writer.add_scalar('analysis/max_lr', torch.max(rollouts.lr).item(), j)
            writer.add_scalar('analysis/std_lr', torch.std(rollouts.lr).item(), j)
            writer.add_scalar('analysis/avg_lr', torch.mean(rollouts.lr).item(), j)

        rollouts.after_update()

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
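
The `update_linear_schedule` helper used with `args.use_linear_lr_decay` in these training loops is not shown in the excerpts; a minimal sketch, assuming the conventional linear annealing of every parameter group's learning rate toward zero over training, is:

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    # assumed behaviour: anneal the learning rate linearly from initial_lr at epoch 0 to ~0 at the end
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr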
Exemple #19
0
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    """
    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
    """

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=100)

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            """
            for info in infos:
                if 'episode' in info.keys():
                    print(reward)
                    episode_rewards.append(info['episode']['r'])
            """

            # FIXME: works only for environments with sparse rewards
            for idx, eps_done in enumerate(done):
                if eps_done:
                    episode_rewards.append(reward[idx])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            print('Saving model')
            print()

            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs.venv, 'ob_rms') and envs.venv.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.2f}/{:.2f}, min/max reward {:.2f}/{:.2f}, success rate {:.2f}\n"
                .format(
                    j, total_num_steps, int(total_num_steps / (end - start)),
                    len(episode_rewards), np.mean(episode_rewards),
                    np.median(episode_rewards), np.min(episode_rewards),
                    np.max(episode_rewards),
                    np.count_nonzero(np.greater(episode_rewards, 0)) /
                    len(episode_rewards)))

        if args.eval_interval is not None and len(
                episode_rewards) > 1 and j % args.eval_interval == 0:
            eval_envs = make_vec_envs(args.env_name,
                                      args.seed + args.num_processes,
                                      args.num_processes, args.gamma,
                                      eval_log_dir, args.add_timestep, device,
                                      True)

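            # reuse the training envs' running observation statistics so evaluation sees identically normalized inputs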
            if eval_envs.venv.__class__.__name__ == "VecNormalize":
                eval_envs.venv.ob_rms = envs.venv.ob_rms

                # An ugly hack to remove updates
                def _obfilt(self, obs):
                    if self.ob_rms:
                        obs = np.clip((obs - self.ob_rms.mean) /
                                      np.sqrt(self.ob_rms.var + self.epsilon),
                                      -self.clipob, self.clipob)
                        return obs
                    else:
                        return obs

                eval_envs.venv._obfilt = types.MethodType(_obfilt, envs.venv)

            eval_episode_rewards = []

            obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args.num_processes, 1, device=device)

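            # roll the deterministic policy until 10 evaluation episodes have finished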
            while len(eval_episode_rewards) < 10:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                obs, reward, done, infos = eval_envs.step(action)
                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])
                for info in infos:
                    if 'episode' in info.keys():
                        eval_episode_rewards.append(info['episode']['r'])

            eval_envs.close()

            print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
                len(eval_episode_rewards), np.mean(eval_episode_rewards)))
        """
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
        """

    envs.close()
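
Each rollout above ends with `rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)`. A minimal, self-contained sketch of the GAE branch of that computation (the helper name and the convention that `masks[t]` is 0 when the episode terminates at step `t` are assumptions for illustration) is:

import torch

def compute_gae_returns(rewards, values, masks, next_value, gamma, tau):
    # rewards, values, masks: [T, N, 1] tensors; next_value: [N, 1] bootstrap value for the last observation
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # V(s_0) .. V(s_T)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for step in reversed(range(rewards.size(0))):
        # TD error; the mask zeroes the bootstrap term when the episode ended at this step
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns[step] = gae + values[step]
    return returns
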
def main():
    import copy
    import glob
    import os
    import time
    import matplotlib.pyplot as plt

    import gym
    import numpy as np
    import torch
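    # the 'spawn' start method is required for safely using CUDA in subprocess workers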
    torch.multiprocessing.set_start_method('spawn')

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from gym.spaces import Discrete

    from arguments import get_args
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
    from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
    from baselines.common.vec_env.vec_normalize import VecNormalize
    from envs import make_env
    from img_env import ImgEnv, IMG_ENVS
    from model import Policy
    from storage import RolloutStorage
    from utils import update_current_obs, agent1_eval_episode, agent2_eval_episode
    from torchvision import transforms
    from visdom import Visdom

    import algo

    # viz = Visdom(port=8097)

    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    plot_rewards = []
    plot_policy_loss = []
    plot_value_loss = []
    # x = np.array([0])
    # y = np.array([0])
    # counter = 0
    # win = viz.line(
    #     X=x,
    #     Y=y,
    #     win="test1",
    #     name='Line1',
    #     opts=dict(
    #         title='Reward',
    #     )
    #     )
    # win2 = viz.line(
    #     X=x,
    #     Y=y,
    #     win="test2",
    #     name='Line2',
    #     opts=dict(
    #         title='Policy Loss',
    #     )
    #     )
    # win3 = viz.line(
    #     X=x,
    #     Y=y,
    #     win="test3",
    #     name='Line3',
    #     opts=dict(
    #         title='Value Loss',
    #     )
    #     )

    args = get_args()
    if args.no_cuda:
        args.cuda = False
    print(args)
    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    toprint = ['seed', 'lr', 'nat', 'resnet']
    if args.env_name in IMG_ENVS:
        toprint += ['window', 'max_steps']
    toprint.sort()
    name = args.tag
    args_param = vars(args)
    os.makedirs(os.path.join(args.out_dir, args.env_name), exist_ok=True)
    for arg in toprint:
        if arg in args_param and (args_param[arg] or arg in ['gamma', 'seed']):
            if args_param[arg] is True:
                name += '{}_'.format(arg)
            else:
                name += '{}{}_'.format(arg, args_param[arg])
    model_dir = os.path.join(args.out_dir, args.env_name, args.algo)
    os.makedirs(model_dir, exist_ok=True)

    results_dict = {'episodes': [], 'rewards': [], 'args': args}
    torch.set_num_threads(1)
    eval_env = make_env(args,
                        'cifar10',
                        args.seed,
                        1,
                        None,
                        args.add_timestep,
                        natural=args.nat,
                        train=False)
    envs = make_env(args,
                    'cifar10',
                    args.seed,
                    1,
                    None,
                    args.add_timestep,
                    natural=args.nat,
                    train=True)

    #print(envs)
    # envs = envs[0]

    # if args.num_processes > 1:
    #     envs = SubprocVecEnv(envs)
    # else:
    #     envs = DummyVecEnv(envs)
    # eval_env = DummyVecEnv(eval_env)
    # if len(envs.observation_space.shape) == 1:
    #     envs = VecNormalize(envs, gamma=args.gamma)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    actor_critic1 = Policy(obs_shape,
                           envs.action_space,
                           args.recurrent_policy,
                           dataset=args.env_name,
                           resnet=args.resnet,
                           pretrained=args.pretrained)

    actor_critic2 = Policy(obs_shape,
                           envs.action_space,
                           args.recurrent_policy,
                           dataset=args.env_name,
                           resnet=args.resnet,
                           pretrained=args.pretrained)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic1.cuda()
        actor_critic2.cuda()

    if args.algo == 'a2c':
        agent1 = algo.A2C_ACKTR(actor_critic1,
                                args.value_loss_coef,
                                args.entropy_coef,
                                lr=args.lr,
                                eps=args.eps,
                                alpha=args.alpha,
                                max_grad_norm=args.max_grad_norm)
        agent2 = algo.A2C_ACKTR(actor_critic2,
                                args.value_loss_coef,
                                args.entropy_coef,
                                lr=args.lr,
                                eps=args.eps,
                                alpha=args.alpha,
                                max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent1 = algo.PPO(actor_critic1,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm)
        agent2 = algo.PPO(actor_critic2,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent1 = algo.A2C_ACKTR(actor_critic1,
                                args.value_loss_coef,
                                args.entropy_coef,
                                acktr=True)
        agent2 = algo.A2C_ACKTR(actor_critic2,
                                args.value_loss_coef,
                                args.entropy_coef,
                                acktr=True)

    action_space = envs.action_space
    if args.env_name in IMG_ENVS:
        action_space = np.zeros(2)
    # obs_shape = envs.observation_space.shape
    agent1_rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                     obs_shape, action_space,
                                     actor_critic1.state_size)
    agent1_current_obs = torch.zeros(args.num_processes, *obs_shape)

    agent1_obs = envs.agent1_reset()
    update_current_obs(agent1_obs, agent1_current_obs, obs_shape,
                       args.num_stack)
    agent1_rollouts.observations[0].copy_(agent1_current_obs)

    agent2_rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                     obs_shape, action_space,
                                     actor_critic2.state_size)
    agent2_current_obs = torch.zeros(args.num_processes, *obs_shape)

    agent2_obs = envs.agent2_reset()
    update_current_obs(agent2_obs, agent2_current_obs, obs_shape,
                       args.num_stack)
    agent2_rollouts.observations[0].copy_(agent2_current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        agent1_current_obs = agent1_current_obs.cuda()
        agent2_current_obs = agent2_current_obs.cuda()
        agent1_rollouts.cuda()
        agent2_rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        #print("NEW J ITERATION: ", j)
        # envs.display_original(j)
        for step in range(args.num_steps):
            # print("STEP", step)
            # Sample actions
            with torch.no_grad():
                value1, action1, action_log_prob1, states1 = actor_critic1.act(
                    agent1_rollouts.observations[step],
                    agent1_rollouts.states[step], agent1_rollouts.masks[step])
                value2, action2, action_log_prob2, states2 = actor_critic2.act(
                    agent2_rollouts.observations[step],
                    agent2_rollouts.states[step], agent2_rollouts.masks[step])

            cpu_actions1 = action1.squeeze(1).cpu().numpy()
            cpu_actions2 = action2.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs1, reward1, done1, info1 = envs.agent1_step(cpu_actions1)
            obs2, reward2, done2, info2 = envs.agent2_step(cpu_actions2)

            # SIMPLE HEURISTIC 1
            # If either agent gets it correct, they are done.

            if done1 or done2:
                done1 = True
                done2 = True
                done = True
            else:
                done = False

            # envs.display_step(step, j)

            # print("OBS", obs)

            # print("REWARD", reward)
            # print("DONE", done)
            # print("INFO", info)

            reward1 = torch.from_numpy(np.expand_dims(np.stack([reward1]),
                                                      1)).float()
            reward2 = torch.from_numpy(np.expand_dims(np.stack([reward2]),
                                                      1)).float()
            reward = (reward1 + reward2)
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in [done]])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if agent1_current_obs.dim() == 4:
                agent1_current_obs *= masks.unsqueeze(2).unsqueeze(2)
                agent2_current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                agent1_current_obs *= masks
                agent2_current_obs *= masks

            update_current_obs(agent1_obs, agent1_current_obs, obs_shape,
                               args.num_stack)
            agent1_rollouts.insert(agent1_current_obs, states1, action1,
                                   action_log_prob1, value1, reward, masks)

            update_current_obs(agent2_obs, agent2_current_obs, obs_shape,
                               args.num_stack)
            agent2_rollouts.insert(agent2_current_obs, states2, action2,
                                   action_log_prob2, value2, reward, masks)

            # print("envs.curr_img SHAPE: ", envs.curr_img.shape)
            #display_state = envs.curr_img
            # display_state[:, envs.pos[0]:envs.pos[0]+envs.window, envs.pos[1]:envs.pos[1]+envs.window] = 5
            # display_state = custom_replace(display_state, 1, 0)
            # display_state[:, envs.pos[0]:envs.pos[0]+envs.window, envs.pos[1]:envs.pos[1]+envs.window] = \
            #     envs.curr_img[:, envs.pos[0]:envs.pos[0]+envs.window, envs.pos[1]:envs.pos[1]+envs.window]
            # img = transforms.ToPILImage()(display_state)
            # img.save("state_cifar/"+"state"+str(j)+"_"+str(step)+".png")

        with torch.no_grad():
            next_value1 = actor_critic1.get_value(
                agent1_rollouts.observations[-1], agent1_rollouts.states[-1],
                agent1_rollouts.masks[-1]).detach()
            next_value2 = actor_critic2.get_value(
                agent2_rollouts.observations[-1], agent2_rollouts.states[-1],
                agent2_rollouts.masks[-1]).detach()

        #print("GOT HERE")
        agent1_rollouts.compute_returns(next_value1, args.use_gae, args.gamma,
                                        args.tau)
        value_loss1, action_loss1, dist_entropy1 = agent1.update(
            agent1_rollouts)
        agent1_rollouts.after_update()

        agent2_rollouts.compute_returns(next_value2, args.use_gae, args.gamma,
                                        args.tau)
        value_loss2, action_loss2, dist_entropy2 = agent2.update(
            agent2_rollouts)
        agent2_rollouts.after_update()
        #print("GOT HERE2")

        if j % args.save_interval == 0:
            # print("SAVING")
            torch.save((actor_critic1.state_dict(), results_dict),
                       os.path.join(model_dir,
                                    name + 'cifar_model_ppo_ex2_agent1.pt'))
            torch.save((actor_critic2.state_dict(), results_dict),
                       os.path.join(model_dir,
                                    name + 'cifar_model_ppo_ex2_agent2.pt'))
            # print("SAVED")

        if j % args.log_interval == 0:
            # print("EVALUATING EPISODE")
            end = time.time()
            total_reward1 = agent1_eval_episode(eval_env, actor_critic1, args)
            total_reward2 = agent2_eval_episode(eval_env, actor_critic2, args)

            # print("EVALUATED EPISODE")

            total_reward = (total_reward1 + total_reward2)
            value_loss = (value_loss1 + value_loss2)
            action_loss = (action_loss1 + action_loss2)
            dist_entropy = (dist_entropy1 + dist_entropy2)

            results_dict['rewards'].append(total_reward)
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, reward {:.1f} entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        np.mean(results_dict['rewards'][-10:]), dist_entropy,
                        value_loss, action_loss))

            plot_rewards.append(np.mean(results_dict['rewards'][-10:]))
            plot_policy_loss.append(action_loss)
            plot_value_loss.append(value_loss)

    plt.plot(range(len(plot_rewards)), plot_rewards)
    plt.savefig("rewards_multi_1.png")
    plt.close()

    plt.plot(range(len(plot_policy_loss)), plot_policy_loss)
    plt.savefig("policyloss_multi_1.png")
    plt.close()

    plt.plot(range(len(plot_value_loss)), plot_value_loss)
    plt.savefig("valueloss_multi_1.png")
    plt.close()
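
The `final_rewards` / `episode_rewards` bookkeeping above (also used by the `run()` example further down) is easiest to follow on a tiny worked case; the numbers below are illustrative, with two processes of which only the second finishes its episode this step:

import torch

episode_rewards = torch.tensor([[1.0], [0.5]])   # running per-process return accumulators
final_rewards = torch.zeros(2, 1)                # return of the last completed episode per process
reward = torch.tensor([[0.2], [0.3]])
masks = torch.tensor([[1.0], [0.0]])             # 0 marks the process whose episode just ended

episode_rewards += reward                        # -> [[1.2], [0.8]]
final_rewards *= masks                           # keep old totals only where episodes are still running
final_rewards += (1 - masks) * episode_rewards   # -> [[0.0], [0.8]]: record the finished episode's return
episode_rewards *= masks                         # -> [[1.2], [0.0]]: reset the finished accumulator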
Exemple #21
0
def main():
    saved_model = os.path.join(args.save_dir, args.env_name + '.pt')
    if os.path.exists(saved_model) and not args.overwrite:
        actor_critic, ob_rms = \
                torch.load(saved_model)
        agent = \
            torch.load(os.path.join(args.save_dir, args.env_name + '_agent.pt'))
        for i in agent.optimizer.state_dict():
            print(dir(agent.optimizer))
            print(getattr(agent.optimizer, 'steps'))
            print(agent.optimizer.state_dict()[i])
        past_steps = agent.optimizer.steps
    else: 
        actor_critic = False
        agent = False
        past_steps = 0
        try:
            os.makedirs(args.log_dir)
        except OSError:
            files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
            for f in files:
                os.remove(f)
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        win_eval = None

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device, False, None,
                         args=args)

    if actor_critic:
        pass
        # vec_norm = get_vec_normalize(envs)
        # if vec_norm is not None:
        #     vec_norm.eval()
        #     vec_norm.ob_rms = ob_rms
    else:
        actor_critic = Policy(envs.observation_space.shape, envs.action_space,
            base_kwargs={'map_width': args.map_width, 'num_actions': 18, 'recurrent': args.recurrent_policy},
            curiosity=args.curiosity, algo=args.algo, model=args.model, args=args)
    actor_critic.to(device)

    evaluator = None

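    # only build a fresh agent if no checkpoint was loaded above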
    if not agent:
        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR_NOREWARD(actor_critic, args.value_loss_coef,
                                   args.entropy_coef, lr=args.lr,
                                   eps=args.eps, alpha=args.alpha,
                                   max_grad_norm=args.max_grad_norm,
                                   curiosity=args.curiosity, args=args)
        elif args.algo == 'ppo':
            agent = algo.PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch,
                             args.value_loss_coef, args.entropy_coef, lr=args.lr,
                                   eps=args.eps,
                                   max_grad_norm=args.max_grad_norm)
        elif args.algo == 'acktr':
            agent = algo.A2C_ACKTR_NOREWARD(actor_critic, args.value_loss_coef,
                                   args.entropy_coef, lr=args.lr,
                                   eps=args.eps, alpha=args.alpha,
                                   max_grad_norm=args.max_grad_norm,
                                   acktr=True,
                                   curiosity=args.curiosity, args=args)

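    # curiosity runs need extra storage for ICM feature states and their predictions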
    if args.curiosity:
        rollouts = CuriosityRolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            actor_critic.recurrent_hidden_state_size, actor_critic.base.feature_state_size(), args=args)
    else:
        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            actor_critic.recurrent_hidden_state_size, args=args)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    for j in range(num_updates - past_steps):
        if args.drop_path:
            actor_critic.base.get_drop_path()
        player_act = None
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_probs, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step],
                        player_act=player_act,
                        icm_enabled=args.curiosity)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            player_act = None
            if args.render:
                if infos[0]:
                    if 'player_move' in infos[0].keys():
                        player_act = infos[0]['player_move']

            if args.curiosity:
                # run icm
                with torch.no_grad():
                    feature_state, feature_state_pred, action_dist_pred = actor_critic.icm_act(
                        (rollouts.obs[step], obs, action_bin))

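                # intrinsic reward: scaled squared error of the forward model's feature prediction
                # (action_bin is assumed to be a one-hot/binary encoding of `action` defined elsewhere in the original code)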
                intrinsic_reward = args.eta * ((feature_state - feature_state_pred).pow(2)).sum() / 2.
                if args.no_reward:
                    reward = 0
                reward += intrinsic_reward.cpu()

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            if args.curiosity:
                rollouts.insert(obs, recurrent_hidden_states, action, action_log_probs, value, reward, masks,
                                feature_state, feature_state_pred, action_bin, action_dist_pred)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action, action_log_probs, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.recurrent_hidden_states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
        
        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = agent.update(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [save_model,
                          getattr(get_vec_normalize(envs), 'ob_rms', None)]

            torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
            save_agent = copy.deepcopy(agent)

            torch.save(save_agent, os.path.join(save_path, args.env_name + '_agent.pt'))
            # torch.save(actor_critic.state_dict(), os.path.join(save_path, args.env_name + "_weights.pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print("Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n \
dist entropy {:.1f}, val/act loss {:.1f}/{:.1f},".
                format(j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       len(episode_rewards),
                       np.mean(episode_rewards),
                       np.median(episode_rewards),
                       np.min(episode_rewards),
                       np.max(episode_rewards), dist_entropy,
                       value_loss, action_loss))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".
                format(
                       fwd_loss, inv_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args, actor_critic, device)


            if args.model == 'fractal':
                n_cols = evaluator.actor_critic.base.n_cols
                for i in range(-1, n_cols):
                    evaluator.evaluate(column=i)
                # num_eval_frames = (args.num_frames // (args.num_steps * args.eval_interval * args.num_processes)) * args.num_processes * args.max_step
                win_eval = visdom_plot(viz, win_eval, evaluator.eval_log_dir, args.env_name,
                              args.algo, args.num_frames, n_graphs=args.n_recs)
            else:
                evaluator.evaluate(column=None)



        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
def run(number_of_workers, log_dir, vis_title):
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    print("#######")
    print("num_updates: {}".format(num_updates))
    print("#######")

    try:
        os.makedirs(log_dir)
    except OSError:
        files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
        for f in files:
            os.remove(f)

    torch.set_num_threads(1)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    # Done: change make_env behaviour such that simple env is created; see custom_envs.py
    # args.env_name has to start with ng_; currently only WorkerMaintenanceEnv is working
    env_config = ENV_CONFIG.copy()
    # env_config['path_to_keras_expert_model'] = args.path_to_keras_expert_model
    env_config['number_of_workers'] = number_of_workers
    env_config['enable_0action_boost'] = args.enable_0action_boost
    envs = [
        make_env(args.env_name, args.seed, i, log_dir, args.add_timestep,
                 env_config) for i in range(args.num_processes)
    ]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs,
                            ob=not args.disable_env_normalize_ob,
                            ret=not args.disable_env_normalize_rw,
                            gamma=args.gamma)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    # Done: 2018/06/24. change Model in Policy to LSTM/GRU model (ref. CNN with gru); see model.py

    print("#######")
    print("action space.n : {}".format(envs.action_space.n))
    print("#######")
    actor_critic = Policy(obs_shape, envs.action_space, args.recurrent_policy)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

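    # frame stacking: shift older frames forward and write the newest observation into the last slot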
    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()
            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

            if args.enable_debug_info_print:
                print("#####")
                print("cpu_action: {}".format(cpu_actions))
                print("envs reward: {}".format(reward))
                print("info stats reward: {}".format(
                    info[0]["stats_relative_reward_regret"] +
                    info[0]["stats_relative_reward_penalty"]))
                print("final_rewards after masks: {}".format(final_rewards))
                print(
                    "episode_rewards after masks: {}".format(episode_rewards))

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs, 'ob_rms') and envs.ob_rms or None
            ]
            model_name = "{}-{}-{}_w{}-{}.pt".format(args.env_name, args.algo,
                                                     args.save_model_postfix,
                                                     number_of_workers, j)
            torch.save(save_model, os.path.join(save_path, model_name))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss))
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, log_dir, vis_title, args.algo,
                                  args.num_frames)
            except IOError:
                pass
    # save final policy
    save_path = os.path.join(args.save_dir, args.algo)
    try:
        os.makedirs(save_path)
    except OSError:
        pass

    # A really ugly way to save a model to CPU
    save_model = actor_critic
    if args.cuda:
        save_model = copy.deepcopy(actor_critic).cpu()

    save_model = [save_model, hasattr(envs, 'ob_rms') and envs.ob_rms or None]
    model_name = "{}-{}-{}_w{}-final.pt".format(args.env_name, args.algo,
                                                args.save_model_postfix,
                                                number_of_workers)
    torch.save(save_model, os.path.join(save_path, model_name))
    return True
Exemple #23
0
def main():
    # Setup Logging
    log_dir = "{}/models/{}/".format(args.dump_location, args.exp_name)
    dump_dir = "{}/dump/{}/".format(args.dump_location, args.exp_name)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists("{}/images/".format(dump_dir)):
        os.makedirs("{}/images/".format(dump_dir))

    logging.basicConfig(filename=log_dir + 'train.log', level=logging.INFO)
    print("Dumping at {}".format(log_dir))
    print(args)
    logging.info(args)

    # Logging and loss variables
    num_scenes = args.num_processes
    num_episodes = int(args.num_episodes)
    device = args.device = torch.device("cuda:0" if args.cuda else "cpu")
    policy_loss = 0

    best_cost = 100000
    costs = deque(maxlen=1000)
    exp_costs = deque(maxlen=1000)
    pose_costs = deque(maxlen=1000)

    g_masks = torch.ones(num_scenes).float().to(device)
    l_masks = torch.zeros(num_scenes).float().to(device)

    best_local_loss = np.inf
    best_g_reward = -np.inf

    if args.eval:
        traj_lengths = args.max_episode_length // args.num_local_steps
        explored_area_log = np.zeros((num_scenes, num_episodes, traj_lengths))
        explored_ratio_log = np.zeros((num_scenes, num_episodes, traj_lengths))

    g_episode_rewards = deque(maxlen=1000)

    l_action_losses = deque(maxlen=1000)

    g_value_losses = deque(maxlen=1000)
    g_action_losses = deque(maxlen=1000)
    g_dist_entropies = deque(maxlen=1000)

    per_step_g_rewards = deque(maxlen=1000)

    g_process_rewards = np.zeros((num_scenes))

    # Starting environments
    torch.set_num_threads(1)
    envs = make_vec_envs(args)
    obs, infos = envs.reset()

    # Initialize map variables
    ### Full map consists of 4 channels containing the following:
    ### 1. Obstacle Map
    ### 2. Explored Area
    ### 3. Current Agent Location
    ### 4. Past Agent Locations

    torch.set_grad_enabled(False)

    # Calculating full and local map sizes
    map_size = args.map_size_cm // args.map_resolution
    full_w, full_h = map_size, map_size
    local_w, local_h = int(full_w / args.global_downscaling), \
                       int(full_h / args.global_downscaling)

    # Initializing full and local map
    full_map = torch.zeros(num_scenes, 4, full_w, full_h).float().to(device)
    local_map = torch.zeros(num_scenes, 4, local_w, local_h).float().to(device)

    # Initial full and local pose
    full_pose = torch.zeros(num_scenes, 3).float().to(device)
    local_pose = torch.zeros(num_scenes, 3).float().to(device)

    # Origin of local map
    origins = np.zeros((num_scenes, 3))

    # Local Map Boundaries
    lmb = np.zeros((num_scenes, 4)).astype(int)

    ### Planner pose inputs have 7 dimensions
    ### 1-3 store continuous global agent location
    ### 4-7 store local map boundaries
    planner_pose_inputs = np.zeros((num_scenes, 7))

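    # place each agent at the center of its full map and cut out the local window around that position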
    def init_map_and_pose():
        full_map.fill_(0.)
        full_pose.fill_(0.)
        full_pose[:, :2] = args.map_size_cm / 100.0 / 2.0

        locs = full_pose.cpu().numpy()
        planner_pose_inputs[:, :3] = locs
        for e in range(num_scenes):
            r, c = locs[e, 1], locs[e, 0]
            loc_r, loc_c = [
                int(r * 100.0 / args.map_resolution),
                int(c * 100.0 / args.map_resolution)
            ]

            full_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.0

            lmb[e] = get_local_map_boundaries(
                (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

            planner_pose_inputs[e, 3:] = lmb[e]
            origins[e] = [
                lmb[e][2] * args.map_resolution / 100.0,
                lmb[e][0] * args.map_resolution / 100.0, 0.
            ]

        for e in range(num_scenes):
            local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                    lmb[e, 2]:lmb[e, 3]]
            local_pose[e] = full_pose[e] - \
                            torch.from_numpy(origins[e]).to(device).float()

    init_map_and_pose()

    # Global policy observation space
    g_observation_space = gym.spaces.Box(0,
                                         1, (8, local_w, local_h),
                                         dtype='uint8')

    # Global policy action space
    g_action_space = gym.spaces.Box(low=0.0,
                                    high=1.0,
                                    shape=(2, ),
                                    dtype=np.float32)

    # Local policy observation space
    l_observation_space = gym.spaces.Box(
        0, 255, (3, args.frame_width, args.frame_width), dtype='uint8')

    # Local and Global policy recurrent layer sizes
    l_hidden_size = args.local_hidden_size
    g_hidden_size = args.global_hidden_size

    # slam
    nslam_module = Neural_SLAM_Module(args).to(device)
    slam_optimizer = get_optimizer(nslam_module.parameters(),
                                   args.slam_optimizer)

    # Global policy
    g_policy = RL_Policy(g_observation_space.shape,
                         g_action_space,
                         base_kwargs={
                             'recurrent': args.use_recurrent_global,
                             'hidden_size': g_hidden_size,
                             'downscaling': args.global_downscaling
                         }).to(device)
    g_agent = algo.PPO(g_policy,
                       args.clip_param,
                       args.ppo_epoch,
                       args.num_mini_batch,
                       args.value_loss_coef,
                       args.entropy_coef,
                       lr=args.global_lr,
                       eps=args.eps,
                       max_grad_norm=args.max_grad_norm)

    # Local policy
    l_policy = Local_IL_Policy(
        l_observation_space.shape,
        envs.action_space.n,
        recurrent=args.use_recurrent_local,
        hidden_size=l_hidden_size,
        deterministic=args.use_deterministic_local).to(device)
    local_optimizer = get_optimizer(l_policy.parameters(),
                                    args.local_optimizer)

    # Storage
    g_rollouts = GlobalRolloutStorage(args.num_global_steps, num_scenes,
                                      g_observation_space.shape,
                                      g_action_space, g_policy.rec_state_size,
                                      1).to(device)

    slam_memory = FIFOMemory(args.slam_memory_size)

    # Loading model
    if args.load_slam != "0":
        print("Loading slam {}".format(args.load_slam))
        state_dict = torch.load(args.load_slam,
                                map_location=lambda storage, loc: storage)
        nslam_module.load_state_dict(state_dict)

    if not args.train_slam:
        nslam_module.eval()

    if args.load_global != "0":
        print("Loading global {}".format(args.load_global))
        state_dict = torch.load(args.load_global,
                                map_location=lambda storage, loc: storage)
        g_policy.load_state_dict(state_dict)

    if not args.train_global:
        g_policy.eval()

    if args.load_local != "0":
        print("Loading local {}".format(args.load_local))
        state_dict = torch.load(args.load_local,
                                map_location=lambda storage, loc: storage)
        l_policy.load_state_dict(state_dict)

    if not args.train_local:
        l_policy.eval()

    # Predict map from frame 1:
    poses = torch.from_numpy(
        np.asarray([
            infos[env_idx]['sensor_pose'] for env_idx in range(num_scenes)
        ])).float().to(device)

    _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
        nslam_module(obs, obs, poses, local_map[:, 0, :, :],
                     local_map[:, 1, :, :], local_pose)

    # Compute Global policy input
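    # The global input assembled below has 8 channels: channels 0-3 hold the
    # 4-channel local map and channels 4-7 the full map downscaled by
    # args.global_downscaling.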
    locs = local_pose.cpu().numpy()
    global_input = torch.zeros(num_scenes, 8, local_w, local_h)
    global_orientation = torch.zeros(num_scenes, 1).long()

    for e in range(num_scenes):
        r, c = locs[e, 1], locs[e, 0]
        loc_r, loc_c = [
            int(r * 100.0 / args.map_resolution),
            int(c * 100.0 / args.map_resolution)
        ]

        local_map[e, 2:, loc_r - 1:loc_r + 2, loc_c - 1:loc_c + 2] = 1.
        global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)

    global_input[:, 0:4, :, :] = local_map.detach()
    global_input[:, 4:, :, :] = nn.MaxPool2d(args.global_downscaling)(full_map)

    g_rollouts.obs[0].copy_(global_input)
    g_rollouts.extras[0].copy_(global_orientation)

    # Run Global Policy (global_goals = Long-Term Goal)
    g_value, g_action, g_action_log_prob, g_rec_states = \
        g_policy.act(
            g_rollouts.obs[0],
            g_rollouts.rec_states[0],
            g_rollouts.masks[0],
            extras=g_rollouts.extras[0],
            deterministic=False
        )
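    # The sampled continuous global action is squashed into [0, 1] with a
    # sigmoid below and scaled by the local map size to give the long-term
    # goal as a (row, col) cell of the local map.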

    cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
    global_goals = [[int(action[0] * local_w),
                     int(action[1] * local_h)] for action in cpu_actions]

    # Compute planner inputs
    planner_inputs = [{} for e in range(num_scenes)]
    for e, p_input in enumerate(planner_inputs):
        p_input['goal'] = global_goals[e]
        p_input['map_pred'] = global_input[e, 0, :, :].detach().cpu().numpy()
        p_input['exp_pred'] = global_input[e, 1, :, :].detach().cpu().numpy()
        p_input['pose_pred'] = planner_pose_inputs[e]

    # Output stores local goals as well as the ground-truth action
    output = envs.get_short_term_goal(planner_inputs)

    last_obs = obs.detach()
    local_rec_states = torch.zeros(num_scenes, l_hidden_size).to(device)
    start = time.time()

    total_num_steps = -1
    g_reward = 0

    torch.set_grad_enabled(False)

    for ep_num in range(num_episodes):
        for step in range(args.max_episode_length):
            total_num_steps += 1
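            # g_step indexes the current global step within the rollout,
            # eval_g_step counts global steps from 1 for evaluation logging, and
            # l_step is the position within the current local-policy window.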

            g_step = (step // args.num_local_steps) % args.num_global_steps
            eval_g_step = step // args.num_local_steps + 1
            l_step = step % args.num_local_steps

            # ------------------------------------------------------------------
            # Local Policy
            del last_obs
            last_obs = obs.detach()
            local_masks = l_masks
            local_goals = output[:, :-1].to(device).long()

            if args.train_local:
                torch.set_grad_enabled(True)

            action, action_prob, local_rec_states = l_policy(
                obs,
                local_rec_states,
                local_masks,
                extras=local_goals,
            )

            if args.train_local:
                action_target = output[:, -1].long().to(device)
                policy_loss += nn.CrossEntropyLoss()(action_prob,
                                                     action_target)
                torch.set_grad_enabled(False)
            l_action = action.cpu()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Env step
            obs, rew, done, infos = envs.step(l_action)
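            # l_masks is 0 for environments whose episode just ended and 1
            # otherwise; folding it into g_masks lets the global reward
            # bookkeeping below reset at episode boundaries.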

            l_masks = torch.FloatTensor([0 if x else 1
                                         for x in done]).to(device)
            g_masks *= l_masks
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Reinitialize variables when episode ends
            if step == args.max_episode_length - 1:  # Last episode step
                init_map_and_pose()
                del last_obs
                last_obs = obs.detach()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Neural SLAM Module
            if args.train_slam:
                # Add frames to memory
                for env_idx in range(num_scenes):
                    env_obs = obs[env_idx].to("cpu")
                    env_poses = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['sensor_pose'])).float().to("cpu")
                    env_gt_fp_projs = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_proj'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_fp_explored = torch.from_numpy(
                        np.asarray(infos[env_idx]['fp_explored'])).unsqueeze(
                            0).float().to("cpu")
                    env_gt_pose_err = torch.from_numpy(
                        np.asarray(
                            infos[env_idx]['pose_err'])).float().to("cpu")
                    slam_memory.push(
                        (last_obs[env_idx].cpu(), env_obs, env_poses),
                        (env_gt_fp_projs, env_gt_fp_explored, env_gt_pose_err))

            poses = torch.from_numpy(
                np.asarray([
                    infos[env_idx]['sensor_pose']
                    for env_idx in range(num_scenes)
                ])).float().to(device)

            _, _, local_map[:, 0, :, :], local_map[:, 1, :, :], _, local_pose = \
                nslam_module(last_obs, obs, poses, local_map[:, 0, :, :],
                             local_map[:, 1, :, :], local_pose, build_maps=True)

            locs = local_pose.cpu().numpy()
            planner_pose_inputs[:, :3] = locs + origins
            local_map[:, 2, :, :].fill_(0.)  # Resetting current location channel
            for e in range(num_scenes):
                r, c = locs[e, 1], locs[e, 0]
                loc_r, loc_c = [
                    int(r * 100.0 / args.map_resolution),
                    int(c * 100.0 / args.map_resolution)
                ]

                local_map[e, 2:, loc_r - 2:loc_r + 3, loc_c - 2:loc_c + 3] = 1.
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Global Policy
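            # The global policy acts once every num_local_steps environment
            # steps: the full/local maps and poses are recentered around the
            # agent, the exploration reward is collected, and a new long-term
            # goal is sampled.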
            if l_step == args.num_local_steps - 1:
                # For every global step, update the full and local maps
                for e in range(num_scenes):
                    full_map[e, :, lmb[e, 0]:lmb[e, 1], lmb[e, 2]:lmb[e, 3]] = \
                        local_map[e]
                    full_pose[e] = local_pose[e] + \
                                   torch.from_numpy(origins[e]).to(device).float()

                    locs = full_pose[e].cpu().numpy()
                    r, c = locs[1], locs[0]
                    loc_r, loc_c = [
                        int(r * 100.0 / args.map_resolution),
                        int(c * 100.0 / args.map_resolution)
                    ]

                    lmb[e] = get_local_map_boundaries(
                        (loc_r, loc_c), (local_w, local_h), (full_w, full_h))

                    planner_pose_inputs[e, 3:] = lmb[e]
                    origins[e] = [
                        lmb[e][2] * args.map_resolution / 100.0,
                        lmb[e][0] * args.map_resolution / 100.0, 0.
                    ]

                    local_map[e] = full_map[e, :, lmb[e, 0]:lmb[e, 1],
                                            lmb[e, 2]:lmb[e, 3]]
                    local_pose[e] = full_pose[e] - \
                                    torch.from_numpy(origins[e]).to(device).float()

                locs = local_pose.cpu().numpy()
                for e in range(num_scenes):
                    global_orientation[e] = int((locs[e, 2] + 180.0) / 5.)
                global_input[:, 0:4, :, :] = local_map
                global_input[:, 4:, :, :] = \
                    nn.MaxPool2d(args.global_downscaling)(full_map)

                if False:
                    for i in range(4):
                        ax[i].clear()
                        ax[i].set_yticks([])
                        ax[i].set_xticks([])
                        ax[i].set_yticklabels([])
                        ax[i].set_xticklabels([])
                        ax[i].imshow(global_input.cpu().numpy()[0, 4 + i])
                    plt.gcf().canvas.flush_events()
                    # plt.pause(0.1)
                    fig.canvas.start_event_loop(0.001)
                    plt.gcf().canvas.flush_events()

                # Get exploration reward and metrics
                g_reward = torch.from_numpy(
                    np.asarray([
                        infos[env_idx]['exp_reward']
                        for env_idx in range(num_scenes)
                    ])).float().to(device)

                if args.eval:
                    g_reward = g_reward * 50.0  # Convert reward to area in m2

                g_process_rewards += g_reward.cpu().numpy()
                g_total_rewards = g_process_rewards * \
                                  (1 - g_masks.cpu().numpy())
                g_process_rewards *= g_masks.cpu().numpy()
                per_step_g_rewards.append(np.mean(g_reward.cpu().numpy()))

                if np.sum(g_total_rewards) != 0:
                    for tr in g_total_rewards:
                        if tr != 0:
                            g_episode_rewards.append(tr)

                if args.eval:
                    exp_ratio = torch.from_numpy(
                        np.asarray([
                            infos[env_idx]['exp_ratio']
                            for env_idx in range(num_scenes)
                        ])).float()

                    for e in range(num_scenes):
                        explored_area_log[e, ep_num, eval_g_step - 1] = \
                            explored_area_log[e, ep_num, eval_g_step - 2] + \
                            g_reward[e].cpu().numpy()
                        explored_ratio_log[e, ep_num, eval_g_step - 1] = \
                            explored_ratio_log[e, ep_num, eval_g_step - 2] + \
                            exp_ratio[e].cpu().numpy()

                # Add samples to global policy storage
                g_rollouts.insert(global_input, g_rec_states, g_action,
                                  g_action_log_prob, g_value, g_reward,
                                  g_masks, global_orientation)

                # Sample long-term goal from global policy
                g_value, g_action, g_action_log_prob, g_rec_states = \
                    g_policy.act(
                        g_rollouts.obs[g_step + 1],
                        g_rollouts.rec_states[g_step + 1],
                        g_rollouts.masks[g_step + 1],
                        extras=g_rollouts.extras[g_step + 1],
                        deterministic=False
                    )
                cpu_actions = nn.Sigmoid()(g_action).cpu().numpy()
                global_goals = [[
                    int(action[0] * local_w),
                    int(action[1] * local_h)
                ] for action in cpu_actions]

                g_reward = 0
                g_masks = torch.ones(num_scenes).float().to(device)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Get short term goal
            planner_inputs = [{} for e in range(num_scenes)]
            for e, p_input in enumerate(planner_inputs):
                p_input['map_pred'] = local_map[e, 0, :, :].cpu().numpy()
                p_input['exp_pred'] = local_map[e, 1, :, :].cpu().numpy()
                p_input['pose_pred'] = planner_pose_inputs[e]
                p_input['goal'] = global_goals[e]

            output = envs.get_short_term_goal(planner_inputs)
            # ------------------------------------------------------------------

            ### TRAINING
            torch.set_grad_enabled(True)
            # ------------------------------------------------------------------
            # Train Neural SLAM Module
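            # The SLAM module is trained by replaying (last_obs, obs, pose)
            # transitions from the FIFO memory and supervising the predicted
            # projection, explored area and pose error against the stored
            # ground truth from the environment.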
            if args.train_slam and len(slam_memory) > args.slam_batch_size:
                for _ in range(args.slam_iterations):
                    inputs, outputs = slam_memory.sample(args.slam_batch_size)
                    b_obs_last, b_obs, b_poses = inputs
                    gt_fp_projs, gt_fp_explored, gt_pose_err = outputs

                    b_obs = b_obs.to(device)
                    b_obs_last = b_obs_last.to(device)
                    b_poses = b_poses.to(device)

                    gt_fp_projs = gt_fp_projs.to(device)
                    gt_fp_explored = gt_fp_explored.to(device)
                    gt_pose_err = gt_pose_err.to(device)

                    b_proj_pred, b_fp_exp_pred, _, _, b_pose_err_pred, _ = \
                        nslam_module(b_obs_last, b_obs, b_poses,
                                     None, None, None,
                                     build_maps=False)
                    loss = 0
                    if args.proj_loss_coeff > 0:
                        proj_loss = F.binary_cross_entropy(
                            b_proj_pred, gt_fp_projs)
                        costs.append(proj_loss.item())
                        loss += args.proj_loss_coeff * proj_loss

                    if args.exp_loss_coeff > 0:
                        exp_loss = F.binary_cross_entropy(
                            b_fp_exp_pred, gt_fp_explored)
                        exp_costs.append(exp_loss.item())
                        loss += args.exp_loss_coeff * exp_loss

                    if args.pose_loss_coeff > 0:
                        pose_loss = torch.nn.MSELoss()(b_pose_err_pred,
                                                       gt_pose_err)
                        pose_costs.append(args.pose_loss_coeff *
                                          pose_loss.item())
                        loss += args.pose_loss_coeff * pose_loss

                    if args.train_slam:
                        slam_optimizer.zero_grad()
                        loss.backward()
                        slam_optimizer.step()

                    del b_obs_last, b_obs, b_poses
                    del gt_fp_projs, gt_fp_explored, gt_pose_err
                    del b_proj_pred, b_fp_exp_pred, b_pose_err_pred

            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Local Policy
            if (l_step + 1) % args.local_policy_update_freq == 0 \
                    and args.train_local:
                local_optimizer.zero_grad()
                policy_loss.backward()
                local_optimizer.step()
                l_action_losses.append(policy_loss.item())
                policy_loss = 0
                local_rec_states = local_rec_states.detach_()
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Train Global Policy
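            # The global PPO update runs once a full rollout of
            # num_global_steps global transitions has been collected, i.e. at
            # the last local step of the last global step in the rollout.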
            if g_step % args.num_global_steps == args.num_global_steps - 1 \
                    and l_step == args.num_local_steps - 1:
                if args.train_global:
                    g_next_value = g_policy.get_value(
                        g_rollouts.obs[-1],
                        g_rollouts.rec_states[-1],
                        g_rollouts.masks[-1],
                        extras=g_rollouts.extras[-1]).detach()

                    g_rollouts.compute_returns(g_next_value, args.use_gae,
                                               args.gamma, args.tau)
                    g_value_loss, g_action_loss, g_dist_entropy = \
                        g_agent.update(g_rollouts)
                    g_value_losses.append(g_value_loss)
                    g_action_losses.append(g_action_loss)
                    g_dist_entropies.append(g_dist_entropy)
                g_rollouts.after_update()
            # ------------------------------------------------------------------

            # Finish Training
            torch.set_grad_enabled(False)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Logging
            if total_num_steps % args.log_interval == 0:
                end = time.time()
                time_elapsed = time.gmtime(end - start)
                log = " ".join([
                    "Time: {0:0=2d}d".format(time_elapsed.tm_mday - 1),
                    "{},".format(time.strftime("%Hh %Mm %Ss", time_elapsed)),
                    "num timesteps {},".format(total_num_steps *
                                               num_scenes),
                    "FPS {},".format(int(total_num_steps * num_scenes \
                                         / (end - start)))
                ])

                log += "\n\tRewards:"

                if len(g_episode_rewards) > 0:
                    log += " ".join([
                        " Global step mean/med rew:",
                        "{:.4f}/{:.4f},".format(np.mean(per_step_g_rewards),
                                                np.median(per_step_g_rewards)),
                        " Global eps mean/med/min/max eps rew:",
                        "{:.3f}/{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_episode_rewards),
                            np.median(g_episode_rewards),
                            np.min(g_episode_rewards),
                            np.max(g_episode_rewards))
                    ])

                log += "\n\tLosses:"

                if args.train_local and len(l_action_losses) > 0:
                    log += " ".join([
                        " Local Loss:",
                        "{:.3f},".format(np.mean(l_action_losses))
                    ])

                if args.train_global and len(g_value_losses) > 0:
                    log += " ".join([
                        " Global Loss value/action/dist:",
                        "{:.3f}/{:.3f}/{:.3f},".format(
                            np.mean(g_value_losses), np.mean(g_action_losses),
                            np.mean(g_dist_entropies))
                    ])

                if args.train_slam and len(costs) > 0:
                    log += " ".join([
                        " SLAM Loss proj/exp/pose:"
                        "{:.4f}/{:.4f}/{:.4f}".format(np.mean(costs),
                                                      np.mean(exp_costs),
                                                      np.mean(pose_costs))
                    ])

                print(log)
                logging.info(log)
            # ------------------------------------------------------------------

            # ------------------------------------------------------------------
            # Save best models
            if (total_num_steps * num_scenes) % args.save_interval < \
                    num_scenes:

                # Save Neural SLAM Model
                if len(costs) >= 1000 and np.mean(costs) < best_cost \
                        and not args.eval:
                    best_cost = np.mean(costs)
                    torch.save(nslam_module.state_dict(),
                               os.path.join(log_dir, "model_best.slam"))

                # Save Local Policy Model
                if len(l_action_losses) >= 100 and \
                        (np.mean(l_action_losses) <= best_local_loss) \
                        and not args.eval:
                    torch.save(l_policy.state_dict(),
                               os.path.join(log_dir, "model_best.local"))

                    best_local_loss = np.mean(l_action_losses)

                # Save Global Policy Model
                if len(g_episode_rewards) >= 100 and \
                        (np.mean(g_episode_rewards) >= best_g_reward) \
                        and not args.eval:
                    torch.save(g_policy.state_dict(),
                               os.path.join(log_dir, "model_best.global"))
                    best_g_reward = np.mean(g_episode_rewards)

            # Save periodic models
            if (total_num_steps * num_scenes) % args.save_periodic < \
                    num_scenes:
                step = total_num_steps * num_scenes
                if args.train_slam:
                    torch.save(
                        nslam_module.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.slam".format(step)))
                if args.train_local:
                    torch.save(
                        l_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.local".format(step)))
                if args.train_global:
                    torch.save(
                        g_policy.state_dict(),
                        os.path.join(dump_dir,
                                     "periodic_{}.global".format(step)))
            # ------------------------------------------------------------------

    # Print and save model performance numbers during evaluation
    if args.eval:
        logfile = open("{}/explored_area.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_area_log[e].shape[0]):
                logfile.write(str(explored_area_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        logfile = open("{}/explored_ratio.txt".format(dump_dir), "w+")
        for e in range(num_scenes):
            for i in range(explored_ratio_log[e].shape[0]):
                logfile.write(str(explored_ratio_log[e, i]) + "\n")
                logfile.flush()

        logfile.close()

        log = "Final Exp Area: \n"
        for i in range(explored_area_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_area_log[:, :, i]))

        log += "\nFinal Exp Ratio: \n"
        for i in range(explored_ratio_log.shape[2]):
            log += "{:.5f}, ".format(np.mean(explored_ratio_log[:, :, i]))

        print(log)
        logging.info(log)
def main():
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    torch.set_num_threads(1)

    with open(args.eval_env_seeds_file, 'r') as f:
        eval_env_seeds = json.load(f)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = [
        make_env(args.env_name, args.seed, i, args.log_dir, args.add_timestep)
        for i in range(args.num_processes)
    ]

    eval_dir = os.path.join(args.log_dir, "eval/")
    if not os.path.exists(eval_dir):
        os.makedirs(eval_dir)
    eval_env = [
        make_env(args.env_name,
                 args.seed,
                 0,
                 eval_dir,
                 args.add_timestep,
                 early_resets=True)
    ]
    eval_env = DummyVecEnv(eval_env)

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs, gamma=args.gamma)

    if len(envs.observation_space.shape) == 1:
        # Don't touch rewards for evaluation
        eval_env = VecNormalize(eval_env, ret=False)
        # set running filter to be the same
        eval_env.ob_rms = envs.ob_rms

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    actor_critic = Policy(obs_shape, envs.action_space, args.recurrent_policy)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
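        # Shift the stacked frames left by one observation and write the newest
        # observation into the last slot (frame stacking over num_stack frames).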
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                envs.ob_rms if hasattr(envs, 'ob_rms') else None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss))

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass

    validation_returns = evaluate_with_seeds(eval_env, actor_critic, args.cuda,
                                             eval_env_seeds)

    report_results([
        dict(name='validation_return',
             type='objective',
             value=np.mean(validation_returns))
    ])
Exemple #25
0
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle, Loader=yaml.FullLoader)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle, Loader=yaml.FullLoader)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    model_opt['time_scale'] = env_opt['time_scale']
    if model_opt['mode'] in ['baselinewtheta', 'phasewtheta']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']
        model_opt['theta_sz'] = env_opt['theta_sz']
    elif model_opt['mode'] in ['baseline_lowlevel', 'phase_lowlevel']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']

    # Check asserts
    assert (model_opt['mode'] in [
        'baseline', 'baseline_reverse', 'phasesimple', 'phasewstate',
        'baselinewtheta', 'phasewtheta', 'baseline_lowlevel', 'phase_lowlevel',
        'interpolate', 'cyclic', 'maze_baseline', 'maze_baseline_wphase'
    ])
    assert (args.algo in ['a2c', 'ppo', 'acktr'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    # Set seed - just make the seed the trial number
    seed = args.trial
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    num_updates = (int(optim_opt['num_frames']) // alg_opt['num_steps'] //
                   alg_opt['num_processes'])
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if len(args.resume) > 0:
        assert (os.path.isfile(os.path.join(logpath, args.resume)))
        ckpt = torch.load(os.path.join(logpath, 'ckpt.pth.tar'))
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if os.path.isdir(logpath):
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(
                        logpath), default=False):
                os.system('rm -rf ' + logpath)
            else:
                return
        os.system('mkdir -p ' + logpath)
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
        os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
        os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # If interpolate mode, choose states
    if options['model']['mode'] == 'phase_lowlevel' and options['env'][
            'theta_space_mode'] == 'pretrain_interp':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = random.choice(all_states)
        s2 = random.choice(all_states)
        fixed_states = [s1, s2]
    elif model_opt['mode'] == 'interpolate':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = all_states[env_opt['s1_ind']]
        s2 = all_states[env_opt['s2_ind']]
        fixed_states = [s1, s2]
    else:
        fixed_states = None

    # Create environments
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose,
                 fixed_states) for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Get theta_sz for models (if applicable)
    dummy_env.reset()
    if model_opt['mode'] == 'baseline_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.theta_sz
    elif model_opt['mode'] == 'phase_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.env.theta_sz
    if 'theta_sz' in model_opt:
        env_opt['theta_sz'] = model_opt['theta_sz']

    # Get observation shape
    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * env_opt['num_stack'], *obs_shape[1:])

    # Do vec normalize, but mask out what we don't want altered
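    # ignore_mask flags observation entries (e.g. the timestep and theta
    # dimensions) that ObservationFilter is presumably meant to leave out of
    # the running-mean/variance normalization.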
    if len(envs.observation_space.shape) == 1:
        ignore_mask = np.zeros(envs.observation_space.shape)
        if env_opt['add_timestep']:
            ignore_mask[-1] = 1
        if model_opt['mode'] in [
                'baselinewtheta', 'phasewtheta', 'baseline_lowlevel',
                'phase_lowlevel'
        ]:
            theta_sz = env_opt['theta_sz']
            if env_opt['add_timestep']:
                ignore_mask[-(theta_sz + 1):] = 1
            else:
                ignore_mask[-theta_sz:] = 1
        if args.finetune_baseline:
            ignore_mask = dummy_env.unwrapped._get_obs_mask()
            freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
            if env_opt['add_timestep']:
                ignore_mask = np.concatenate([ignore_mask, [1]])
                freeze_mask = np.concatenate([freeze_mask, [0]])
            ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=True,
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     freeze_mask=freeze_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])
        else:
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=env_opt['add_timestep'],
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()

    # Create the policy network
    actor_critic = Policy(obs_shape, envs.action_space, model_opt)
    if args.cuda:
        actor_critic.cuda()

    # Create the agent
    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'],
                               lr=optim_opt['lr'],
                               eps=optim_opt['eps'],
                               alpha=optim_opt['alpha'],
                               max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         alg_opt['clip_param'],
                         alg_opt['ppo_epoch'],
                         alg_opt['num_mini_batch'],
                         alg_opt['value_loss_coef'],
                         alg_opt['entropy_coef'],
                         lr=optim_opt['lr'],
                         eps=optim_opt['eps'],
                         max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'],
                               acktr=True)
    rollouts = RolloutStorage(alg_opt['num_steps'], alg_opt['num_processes'],
                              obs_shape, envs.action_space,
                              actor_critic.state_size)
    current_obs = torch.zeros(alg_opt['num_processes'], *obs_shape)

    # Update agent with loaded checkpoint
    if len(args.resume) > 0:
        # This should update both the policy network and the optimizer
        agent.load_state_dict(ckpt['agent'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    elif len(args.other_resume) > 0:
        ckpt = torch.load(args.other_resume)

        # This should update both the policy network
        agent.actor_critic.load_state_dict(ckpt['agent']['model'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    elif args.finetune_baseline:
        # Load the model based on the trial number
        ckpt_base = options['lowlevel']['ckpt']
        ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % args.trial
        ckpt = torch.load(ckpt_file)

        # Make "input mask" that tells the model which inputs were the same from before and should be copied
        oldinput_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()

        # This should update both the policy network
        agent.actor_critic.load_state_dict_special(ckpt['agent']['model'],
                                                   oldinput_mask)

        # Set ob_rms
        old_rms = ckpt['ob_rms']
        old_size = old_rms.mean.size
        if env_opt['add_timestep']:
            old_size -= 1

        # Only copy the pro state part of it
        envs.ob_rms.mean[:old_size] = old_rms.mean[:old_size]
        envs.ob_rms.var[:old_size] = old_rms.var[:old_size]

    # Inline define our helper function for updating obs
    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    # Reset our env and rollouts
    obs = envs.reset()
    update_current_obs(obs)
    rollouts.observations[0].copy_(current_obs)
    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
    final_rewards = torch.zeros([alg_opt['num_processes'], 1])

    # Update loop
    start = time.time()
    for j in range(start_update, num_updates):
        for step in range(alg_opt['num_steps']):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            #pdb.set_trace()
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

        # Update model and rollouts
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()
        rollouts.compute_returns(next_value, alg_opt['use_gae'],
                                 env_opt['gamma'], alg_opt['gae_tau'])
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # Add algo updates here
        alg_info = {}
        alg_info['value_loss'] = value_loss
        alg_info['action_loss'] = action_loss
        alg_info['dist_entropy'] = dist_entropy
        alg_logger.writerow(alg_info)
        alg_f.flush()

        # Save checkpoints
        total_num_steps = (j + 1) * alg_opt['num_processes'] * \
            alg_opt['num_steps']
        #save_interval = log_opt['save_interval'] * alg_opt['log_mult']
        save_interval = 100
        if j % save_interval == 0:
            # Save all of our important information
            save_checkpoint(logpath,
                            agent,
                            envs,
                            j,
                            total_num_steps,
                            args.save_every,
                            final=False)

        # Print log
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        if j % log_interval == 0:
            end = time.time()
            print(
                "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(options['logs']['exp_name'], j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss))

        # Do dashboard logging
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']
        if args.vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                dashboard.visdom_plot()
            except IOError:
                pass

    # Save final checkpoint
    save_checkpoint(logpath,
                    agent,
                    envs,
                    j,
                    total_num_steps,
                    args.save_every,
                    final=False)

    # Close logging file
    alg_f.close()
Exemple #26
0
def main():
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    torch.set_num_threads(1)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = [
        make_env(args.env_name, args.seed, i, args.log_dir, args.add_timestep)
        for i in range(args.num_processes)
    ]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs, gamma=args.gamma)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    if args.load_model is not None:
        actor_critic = torch.load(args.load_model)[0]
    else:
        actor_critic = Policy(obs_shape, envs.action_space,
                              args.recurrent_policy, args.hidden_size, args)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm,
                               pop_art=args.pop_art)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    obs = envs.reset()
    update_current_obs(obs, current_obs, obs_shape, args.num_stack)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()
    scale = 1.
    current_pdrr = [0., 0.]
    last_update = 0

    ### parameters for adaptive reward scaling ###
    t_stop = 0
    beta = .99
    R_prev = -1e9
    m_max = -1e9
    m_t = 0
    reverse = False

    last_scale_t = -1e9
    ###

    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)

            # reward *= args.reward_scaling

            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs, current_obs, obs_shape, args.num_stack)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        t = j // args.adaptive_interval
        if args.pop_art:
            value_loss, action_loss, dist_entropy = agent.pop_art_update(
                rollouts)
        else:
            if t - last_scale_t > 100:
                value_loss, action_loss, dist_entropy = agent.update(
                    rollouts, update_actor=True)
            else:
                value_loss, action_loss, dist_entropy = agent.update(
                    rollouts, update_actor=False)

        if agent.max_grad_norm < .5 and t - last_scale_t < 100:
            agent.max_grad_norm += 0.00001
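        # Adaptive reward scaling: R_t is the current mean return, m_t a running
        # average of it, and m_hat = m_t / (1 - beta**t) its bias-corrected
        # estimate. When m_hat fails to improve for args.tolerance checks, the
        # actor-critic outputs are rescaled (actor_critic.rescale) and the
        # optimizer state is reinitialized.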

        if j % args.adaptive_interval == 0 and j and t - last_scale_t > 100:
            t = j // args.adaptive_interval

            R_t = final_rewards.mean().item()
            R_ts.append(R_t)
            assert type(R_t) == float
            t_stop += 1
            m_t = beta * m_t + (1 - beta) * R_t
            m_hat = m_t / (1 - beta**t)
            print('m_hat :{}, t_stop: {}'.format(m_hat, t_stop))
            print('agent.max_grad_norm, ', agent.max_grad_norm)
            if m_hat > m_max:
                m_max = m_hat
                t_stop = 0
            if t_stop > args.tolerance:
                if reverse and m_max <= R_prev:
                    break
                elif reverse and m_max > R_prev:
                    agent.max_grad_norm = args.max_grad_norm_after
                    actor_critic.rescale(args.cdec)
                    scale *= args.cdec
                    agent.reinitialize()
                    last_scale_t = t
                elif not reverse and m_max <= R_prev:
                    agent.max_grad_norm = args.max_grad_norm_after
                    actor_critic.rescale(args.cdec)
                    scale *= args.cdec
                    agent.reinitialize()
                    reverse = True
                    last_scale_t = t
                else:
                    agent.max_grad_norm = args.max_grad_norm_after
                    actor_critic.rescale(args.cinc)
                    scale *= args.cinc
                    agent.reinitialize()
                    last_scale_t = t

                R_prev = m_max
                j = t_stop = m_t = 0
                m_max = -1e9

        # if j % args.log_interval == 0:
        # this is used for testing saturation
        # relus = actor_critic.base_forward(
        # rollouts.observations[:-1].view(-1, *rollouts.observations.size()[2:]))

        rollouts.after_update()

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps

            # relus = log_saturation(fname=args.saturation_log,
            # first=(j==0),
            # relus=[relu.cpu().detach().numpy() for relu in relus])

            # print("saturation", relus)
            # if j > 0:
            # current_pdrr = incremental_update(current_pdrr, relus)

            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}, scale {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss, scale))

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.plot_title,
                                  args.algo, args.num_frames)
            except IOError:
                pass
Exemple #27
0
def main():

    torch.set_num_threads(1)

    if args.vis:
        summary_writer = tf.summary.FileWriter(args.save_dir)

    envs = [make_env(i, args=args) for i in range(args.num_processes)]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1 and args.env_name not in [
            'OverCooked'
    ]:
        envs = VecNormalize(envs, gamma=args.gamma)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    def get_onehot(num_class, action):
        one_hot = np.zeros(num_class)
        one_hot[action] = 1
        one_hot = torch.from_numpy(one_hot).float()

        return one_hot

    if args.policy_type == 'shared_policy':

        actor_critic = Policy(obs_shape, envs.action_space,
                              args.recurrent_policy)

        if envs.action_space.__class__.__name__ == "Discrete":
            action_shape = 1
        else:
            action_shape = envs.action_space.shape[0]

        if args.cuda:
            actor_critic.cuda()

        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR(
                actor_critic,
                args.value_loss_coef,
                args.entropy_coef,
                lr=args.lr,
                eps=args.eps,
                alpha=args.alpha,
                max_grad_norm=args.max_grad_norm,
            )
        elif args.algo == 'ppo':
            agent = algo.PPO(
                actor_critic,
                args.clip_param,
                args.ppo_epoch,
                args.num_mini_batch,
                args.value_loss_coef,
                args.entropy_coef,
                lr=args.lr,
                eps=args.eps,
                max_grad_norm=args.max_grad_norm,
            )
        elif args.algo == 'acktr':
            agent = algo.A2C_ACKTR(
                actor_critic,
                args.value_loss_coef,
                args.entropy_coef,
                acktr=True,
            )

        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  obs_shape, envs.action_space,
                                  actor_critic.state_size)
        current_obs = torch.zeros(args.num_processes, *obs_shape)

        obs = envs.reset()

        update_current_obs(obs)

        rollouts.observations[0].copy_(current_obs)

        episode_reward_raw = 0.0
        final_reward_raw = 0.0

        if args.cuda:
            current_obs = current_obs.cuda()
            rollouts.cuda()

        # try to load checkpoint
        try:
            num_trained_frames = np.load(args.save_dir +
                                         '/num_trained_frames.npy')[0]
            try:
                actor_critic.load_state_dict(
                    torch.load(args.save_dir + '/trained_learner.pth'))
                print('Load learner checkpoint: succeeded')
            except Exception as e:
                print('Load learner checkpoint: failed')
        except Exception as e:
            num_trained_frames = 0
        print('Learner has been trained to step: ' + str(num_trained_frames))

        start = time.time()
        j = 0
        while True:
            if num_trained_frames > args.num_frames:
                break

            for step in range(args.num_steps):
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, states = actor_critic.act(
                        rollouts.observations[step],
                        rollouts.states[step],
                        rollouts.masks[step],
                    )
                cpu_actions = action.squeeze(1).cpu().numpy()

                # Observe reward and next obs
                obs, reward_raw, done, info = envs.step(cpu_actions)

                episode_reward_raw += reward_raw[0]
                if done[0]:
                    final_reward_raw = episode_reward_raw
                    episode_reward_raw = 0.0
                reward = np.sign(reward_raw)
                reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                         1)).float()
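                # np.sign bins the raw reward to {-1, 0, +1} (DQN-style reward
                # clipping); the unclipped return is still tracked in
                # episode_reward_raw for logging.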

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])

                if args.cuda:
                    masks = masks.cuda()

                if current_obs.dim() == 4:
                    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                else:
                    current_obs *= masks
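                # masks has shape [N, 1]; for image observations of shape
                # [N, C, H, W] the two unsqueeze calls broadcast it to
                # [N, 1, 1, 1] (equivalent to masks.view(-1, 1, 1, 1)).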

                update_current_obs(obs)
                rollouts.insert(current_obs, states, action, action_log_prob,
                                value, reward, masks)

            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.observations[-1],
                    rollouts.states[-1],
                    rollouts.masks[-1],
                ).detach()

            rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                     args.tau)

            value_loss, action_loss, dist_entropy = agent.update(rollouts)

            rollouts.after_update()

            num_trained_frames += (args.num_steps * args.num_processes)
            j += 1

            # save checkpoint
            if j % args.save_interval == 0 and args.save_dir != "":
                try:
                    np.save(
                        args.save_dir + '/num_trained_frames.npy',
                        np.array([num_trained_frames]),
                    )
                    actor_critic.save_model(save_path=args.save_dir)
                except Exception as e:
                    print("Save checkpoint failed")

            # print info
            if j % args.log_interval == 0:
                end = time.time()
                total_num_steps = (j + 1) * args.num_processes * args.num_steps
                print(
                    "[{}/{}], FPS {}, final_reward_raw {:.2f}, remaining {} hours"
                    .format(
                        num_trained_frames, args.num_frames,
                        int(num_trained_frames / (end - start)),
                        final_reward_raw, (end - start) / num_trained_frames *
                        (args.num_frames - num_trained_frames) / 60.0 / 60.0))

            # visualize results
            if args.vis and j % args.vis_interval == 0:
                # We use TensorBoard since it's better for comparing plots.
                summary = tf.Summary()
                summary.value.add(
                    tag='final_reward_raw',
                    simple_value=final_reward_raw,
                )
                summary.value.add(
                    tag='value_loss',
                    simple_value=value_loss,
                )
                summary.value.add(
                    tag='action_loss',
                    simple_value=action_loss,
                )
                summary.value.add(
                    tag='dist_entropy',
                    simple_value=dist_entropy,
                )
                summary_writer.add_summary(summary, num_trained_frames)
                summary_writer.flush()
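                # Sketch: tf.Summary/FileWriter is the TF1-style API; an assumed
                # TF2 equivalent would be roughly
                #   writer = tf.summary.create_file_writer(args.save_dir)
                #   with writer.as_default():
                #       tf.summary.scalar('final_reward_raw', final_reward_raw,
                #                         step=num_trained_frames)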

    elif args.policy_type == 'hierarchical_policy':
        num_subpolicy = args.num_subpolicy
        update_interval = args.hierarchy_interval

        while len(num_subpolicy) < args.num_hierarchy - 1:
            num_subpolicy.append(num_subpolicy[-1])
        while len(update_interval) < args.num_hierarchy - 1:
            update_interval.append(update_interval[-1])

        if args.num_hierarchy == 1:
            update_interval = [1]
            num_subpolicy = [envs.action_space.n]
            # print(envs.action_space.n)
            # print(stop)

        actor_critic = {}
        rollouts = {}
        actor_critic['top'] = EHRL_Policy(obs_shape,
                                          space.Discrete(num_subpolicy[-1]),
                                          np.zeros(1), 128,
                                          args.recurrent_policy, 'top')
        rollouts['top'] = EHRL_RolloutStorage(
            int(args.num_steps / update_interval[-1]),
            args.num_processes, obs_shape, space.Discrete(num_subpolicy[-1]),
            np.zeros(1), actor_critic['top'].state_size)

        for hie_id in range(args.num_hierarchy - 1):
            if hie_id > 0:
                actor_critic[str(hie_id)] = EHRL_Policy(
                    obs_shape, space.Discrete(num_subpolicy[hie_id - 1]),
                    np.zeros(num_subpolicy[hie_id]), 128,
                    args.recurrent_policy, str(hie_id))
                rollouts[str(hie_id)] = EHRL_RolloutStorage(
                    int(args.num_steps / update_interval[hie_id - 1]),
                    args.num_processes, obs_shape,
                    space.Discrete(num_subpolicy[hie_id - 1]),
                    np.zeros(num_subpolicy[hie_id]),
                    actor_critic[str(hie_id)].state_size)
            else:
                actor_critic[str(hie_id)] = EHRL_Policy(
                    obs_shape, envs.action_space,
                    np.zeros(num_subpolicy[hie_id]), 128,
                    args.recurrent_policy, str(hie_id))
                rollouts[str(hie_id)] = EHRL_RolloutStorage(
                    args.num_steps, args.num_processes, obs_shape,
                    envs.action_space, np.zeros(num_subpolicy[hie_id]),
                    actor_critic[str(hie_id)].state_size)

        if envs.action_space.__class__.__name__ == "Discrete":
            action_shape = 1
        else:
            action_shape = envs.action_space.shape[0]

        if args.cuda:
            for key in actor_critic:
                actor_critic[key].cuda()

        agent = {}
        for ac_key in actor_critic:
            if args.algo == 'a2c':
                agent[ac_key] = algo.A2C_ACKTR(
                    actor_critic[ac_key],
                    args.value_loss_coef,
                    args.entropy_coef,
                    lr=args.lr,
                    eps=args.eps,
                    alpha=args.alpha,
                    max_grad_norm=args.max_grad_norm,
                )
            elif args.algo == 'ppo':
                agent[ac_key] = algo.PPO(
                    actor_critic[ac_key],
                    args.clip_param,
                    args.ppo_epoch,
                    args.num_mini_batch,
                    args.value_loss_coef,
                    args.entropy_coef,
                    lr=args.lr,
                    eps=args.eps,
                    max_grad_norm=args.max_grad_norm,
                )
            elif args.algo == 'acktr':
                agent[ac_key] = algo.A2C_ACKTR(
                    actor_critic[ac_key],
                    args.value_loss_coef,
                    args.entropy_coef,
                    acktr=True,
                )

        current_obs = torch.zeros(args.num_processes, *obs_shape)

        obs = envs.reset()
        update_current_obs(obs)

        for obs_key in rollouts:
            rollouts[obs_key].observations[0].copy_(current_obs)

        episode_reward_raw = 0.0
        final_reward_raw = 0.0

        if args.cuda:
            current_obs = current_obs.cuda()
            for rol_key in rollouts:
                rollouts[rol_key].cuda()

        # try to load checkpoint
        try:
            num_trained_frames = np.load(args.save_dir +
                                         '/num_trained_frames.npy')[0]
            try:
                for save_key in actor_critic:
                    actor_critic[save_key].load_state_dict(
                        torch.load(args.save_dir + '/trained_learner_' +
                                   save_key + '.pth'))
                print('Load learner checkpoint: succeeded')
            except Exception as e:
                print('Load learner checkpoint: failed')
        except Exception as e:
            num_trained_frames = 0
        print('Learner has been trained to step: ' + str(num_trained_frames))

        start = time.time()
        j = 0
        onehot_mem = {}
        reward_mem = {}
        if args.num_hierarchy > 1:
            update_flag = np.zeros(args.num_hierarchy - 1, dtype=np.uint8)
        else:
            update_flag = np.zeros(1, dtype=np.uint8)
        step_count = 0

        value = {}
        next_value = {}
        action = {}
        action_log_prob = {}
        states = {}
        while True:
            if num_trained_frames > args.num_frames:
                break
            step_count = 0

            for step in range(args.num_steps):
                if step_count % update_interval[-1] == 0:
                    with torch.no_grad():
                        value['top'], action['top'], action_log_prob[
                            'top'], states['top'] = actor_critic['top'].act(
                                rollouts['top'].observations[update_flag[-1]],
                                rollouts['top'].one_hot[update_flag[-1]],
                                rollouts['top'].states[update_flag[-1]],
                                rollouts['top'].masks[update_flag[-1]],
                            )
                    update_flag[-1] += 1
                    onehot_mem[str(args.num_hierarchy - 1)] = get_onehot(
                        num_subpolicy[-1], action['top'])
                    onehot_mem[str(args.num_hierarchy)] = get_onehot(1, 0)
                if len(update_interval) > 1:
                    for interval_id in range(len(update_interval) - 1):
                        if step_count % update_interval[interval_id] == 0:
                            with torch.no_grad():
                                value[str(interval_id+1)], action[str(interval_id+1)], action_log_prob[str(interval_id+1)], states[str(interval_id+1)] = \
                                actor_critic[str(interval_id+1)].act(
                                    rollouts[str(interval_id+1)].observations[update_flag[interval_id]],
                                    rollouts[str(interval_id+1)].one_hot[update_flag[-1]],
                                    rollouts[str(interval_id+1)].states[update_flag[interval_id]],
                                    rollouts[str(interval_id+1)].masks[update_flag[interval_id]],
                                )
                            update_flag[interval_id] += 1
                            onehot_mem[str(interval_id + 1)] = get_onehot(
                                num_subpolicy[interval_id],
                                action[str(interval_id + 1)])
                # Sample actions
                if args.num_hierarchy > 1:
                    with torch.no_grad():
                        value['0'], action['0'], action_log_prob['0'], states[
                            '0'] = actor_critic['0'].act(
                                rollouts['0'].observations[step],
                                rollouts['0'].one_hot[step],
                                rollouts['0'].states[step],
                                rollouts['0'].masks[step],
                            )
                    cpu_actions = action['0'].squeeze(1).cpu().numpy()
                else:
                    cpu_actions = action['top'].squeeze(1).cpu().numpy()

                # Observe reward and next obs
                obs, reward_raw, done, info = envs.step(cpu_actions)

                for reward_id in range(args.num_hierarchy - 1):
                    try:
                        reward_mem[str(reward_id)] += [reward_raw[0]]
                    except Exception as e:
                        reward_mem[str(reward_id)] = [reward_raw[0]]

                episode_reward_raw += reward_raw[0]

                if done[0]:
                    final_reward_raw = episode_reward_raw
                    episode_reward_raw = 0.0

                reward = np.sign(reward_raw)
                reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                         1)).float()

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])

                if args.cuda:
                    masks = masks.cuda()

                if current_obs.dim() == 4:
                    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                else:
                    current_obs *= masks

                update_current_obs(obs)
                if args.num_hierarchy > 1:
                    rollouts['0'].insert(current_obs, states['0'], action['0'],
                                         onehot_mem['1'], action_log_prob['0'],
                                         value['0'], reward, masks)
                if step_count % update_interval[-1] == 0:
                    if args.num_hierarchy > 1:
                        reward_mean = np.mean(
                            np.array(reward_mem[str(args.num_hierarchy - 2)]))
                        reward_mean = torch.from_numpy(
                            np.ones(1) * reward_mean).float()
                        rollouts['top'].insert(
                            current_obs, states['top'], action['top'],
                            onehot_mem[str(args.num_hierarchy)],
                            action_log_prob['top'], value['top'], reward_mean,
                            masks)
                        reward_mem[str(args.num_hierarchy - 2)] = []
                    else:
                        rollouts['top'].insert(
                            current_obs, states['top'], action['top'],
                            onehot_mem[str(args.num_hierarchy)],
                            action_log_prob['top'], value['top'], reward,
                            masks)
                if len(update_interval) > 1:
                    for interval_id in range(len(update_interval) - 1):
                        if step_count % update_interval[
                                interval_id] == 0 or done[0]:
                            reward_mean = np.mean(
                                np.array(reward_mem[str(interval_id)]))
                            reward_mean = torch.from_numpy(
                                np.ones(1) * reward_mean).float()
                            rollouts[str(interval_id + 1)].insert(
                                current_obs, states[str(interval_id + 1)],
                                action[str(interval_id + 1)],
                                onehot_mem[str(interval_id + 2)],
                                action_log_prob[str(interval_id + 1)],
                                value[str(interval_id + 1)], reward_mean,
                                masks)
                            reward_mem[str(interval_id)] = []
                step_count += 1

            if args.num_hierarchy > 1:
                with torch.no_grad():
                    next_value['0'] = actor_critic['0'].get_value(
                        rollouts['0'].observations[-1],
                        rollouts['0'].one_hot[-1],
                        rollouts['0'].states[-1],
                        rollouts['0'].masks[-1],
                    ).detach()

                rollouts['0'].compute_returns(next_value['0'], args.use_gae,
                                              args.gamma, args.tau)

                value_loss, action_loss, dist_entropy = agent['0'].update(
                    rollouts['0'], add_onehot=True)

                rollouts['0'].after_update()

            with torch.no_grad():
                next_value['top'] = actor_critic['top'].get_value(
                    rollouts['top'].observations[-1],
                    rollouts['top'].one_hot[-1],
                    rollouts['top'].states[-1],
                    rollouts['top'].masks[-1],
                ).detach()

            rollouts['top'].compute_returns(next_value['top'], args.use_gae,
                                            args.gamma, args.tau)
            if args.num_hierarchy > 1:
                _, _, _ = agent['top'].update(rollouts['top'], add_onehot=True)
            else:
                value_loss, action_loss, dist_entropy = agent['top'].update(
                    rollouts['top'], add_onehot=True)
            rollouts['top'].after_update()
            update_flag[-1] = 0

            if len(update_interval) > 1:
                for interval_id in range(len(update_interval) - 1):
                    with torch.no_grad():
                        next_value[str(interval_id + 1)] = actor_critic[str(
                            interval_id + 1)].get_value(
                                rollouts[str(interval_id +
                                             1)].observations[-1],
                                rollouts[str(interval_id + 1)].one_hot[-1],
                                rollouts[str(interval_id + 1)].states[-1],
                                rollouts[str(interval_id + 1)].masks[-1],
                            ).detach()

                    rollouts[str(interval_id + 1)].compute_returns(
                        next_value[str(interval_id + 1)], args.use_gae,
                        args.gamma, args.tau)
                    _, _, _ = agent[str(interval_id + 1)].update(
                        rollouts[str(interval_id + 1)], add_onehot=True)
                    rollouts[str(interval_id + 1)].after_update()
                    update_flag[interval_id] = 0

            num_trained_frames += (args.num_steps * args.num_processes)
            j += 1

            # save checkpoint
            if j % args.save_interval == 0 and args.save_dir != "":
                try:
                    np.save(
                        args.save_dir + '/num_trained_frames.npy',
                        np.array([num_trained_frames]),
                    )
                    for key_store in actor_critic:
                        actor_critic[key_store].save_model(save_path=args.save_dir)
                except Exception as e:
                    print("Save checkpoint failed")

            # print info
            if j % args.log_interval == 0:
                end = time.time()
                total_num_steps = (j + 1) * args.num_processes * args.num_steps
                print(
                    "[{}/{}], FPS {}, final_reward_raw {:.2f}, remaining {} hours"
                    .format(
                        num_trained_frames, args.num_frames,
                        int(num_trained_frames / (end - start)),
                        final_reward_raw, (end - start) / num_trained_frames *
                        (args.num_frames - num_trained_frames) / 60.0 / 60.0))

            # visualize results
            if args.vis and j % args.vis_interval == 0:
                # We use TensorBoard since it's better for comparing plots.
                summary = tf.Summary()
                summary.value.add(
                    tag='final_reward_raw',
                    simple_value=final_reward_raw,
                )
                summary.value.add(
                    tag='value_loss',
                    simple_value=value_loss,
                )
                summary.value.add(
                    tag='action_loss',
                    simple_value=action_loss,
                )
                summary.value.add(
                    tag='dist_entropy',
                    simple_value=dist_entropy,
                )
                summary_writer.add_summary(summary, num_trained_frames)
                summary_writer.flush()
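
The hierarchical branch above trains the top-level policy on the mean of the raw rewards collected since its previous decision. A minimal sketch of that bookkeeping, with an illustrative name (reward_buffer) and no dependence on the surrounding classes:

import numpy as np
import torch

def interval_reward(reward_buffer):
    # Average the rewards accumulated since the last high-level decision, reset
    # the buffer, and return a float tensor shaped like the per-step rewards above.
    mean = float(np.mean(reward_buffer)) if reward_buffer else 0.0
    reward_buffer.clear()
    return torch.from_numpy(np.ones(1) * mean).float()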
Exemple #28
0
def main():

    print('Preparing parameters')

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # print('Initializing visdom')
    # if args.vis:
    #     from visdom import Visdom
    #     viz = Visdom(port=args.port)
    #     win = None

    print('Creating envs')
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False)

    print('Creating network')
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    print('Initializing PPO')
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)
    print('Memory')
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    # ===================== TB visualisation =================

    writer = SummaryWriter()
    last_index = 0

    print('Starting!')

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step], rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, action, action_log_prob, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        writer.add_scalar('Agents metrics/Policy loss', action_loss, j)
        writer.add_scalar('Agents metrics/Value loss', value_loss, j)
        writer.add_scalar('Agents metrics/Entropy loss', dist_entropy, j)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs.venv, 'ob_rms') and envs.venv.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if j % args.vis_interval == 0:
            try:

                # Sometimes monitor doesn't properly flush the outputs
                # win, tx, ty = visdom_plot(viz, win, args.log_dir, args.env_name, args.algo, args.num_frames)
                tx, ty = get_reward_log(args.log_dir)
                if tx is not None and ty is not None:
                    max_index = len(tx)
                    for ind_iter in range(last_index, max_index):
                        writer.add_scalar('Reward', ty[ind_iter], tx[ind_iter])
                    last_index = max_index

                # tx, ty = get_reward_log(viz, win, args.log_dir, args.env_name,
                #                   args.algo, args.num_frames)

                # if tx != None and ty != None:
                #     plt.cla()
                #     plt.plot(tx,ty)
                #     plt.pause(0.1)

                #     plt.show()

                # if(ty != None and tx != None):

                #     input(ty)
                #     writer.add_scalar('Reward', ty[-1], tx[-1])
                # if(tx != None and ty != None):
                #     plt.cla()
                #     plt.plot(tx, ty)
                #     plt.pause(0.1)
            except IOError:
                pass
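
The visualisation block above re-reads the full reward log on every visit and writes only the points appended since the last one (tracked by last_index). A compact sketch of that pattern, assuming get_reward_log returns parallel lists of timesteps and rewards and that writer is the SummaryWriter instance created above:

def flush_new_rewards(writer, tx, ty, last_index):
    # Write only the reward points added since the previous call and return the
    # new high-water mark.
    if tx is None or ty is None:
        return last_index
    for i in range(last_index, len(tx)):
        writer.add_scalar('Reward', ty[i], tx[i])
    return len(tx)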
Exemple #29
0
def main():
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    torch.set_num_threads(1)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = [
        make_env(args.env_name, args.seed, i, args.log_dir, args.add_timestep)
        for i in range(args.num_processes)
    ]

    if args.num_processes > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        envs = VecNormalize(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])

    actor_critic = Policy(obs_shape, envs.action_space, args.recurrent_policy)

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.cuda:
        actor_critic.cuda()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes, obs_shape,
                              envs.action_space, actor_critic.state_size)
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if args.num_stack > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([args.num_processes, 1])
    final_rewards = torch.zeros([args.num_processes, 1])

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    start = time.time()

    lmdb_idx = 0
    # Create the LMDB output directories if they do not exist yet.
    os.makedirs(os.path.join(args.lmdb_path, args.env_name), exist_ok=True)
    os.makedirs(os.path.join(args.lmdb_path, args.env_name, 'test'), exist_ok=True)

    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            # obs, reward, done, info = envs.step(cpu_actions)
            # This env wrapper also returns the unwrapped obs and reward.
            obs, reward, done, info, wr_obs, wr_reward = envs.step(cpu_actions)
            # sample images
            # img = np.squeeze(np.transpose(obs[3], (1, 2, 0)), 2)
            for img, rwd in zip(wr_obs, wr_reward):
                if rwd > 0:
                    lmdb_idx += 1
                    convert_to_lmdb(
                        img, rwd, os.path.join(args.lmdb_path, args.env_name),
                        lmdb_idx)

            # Evaluate unwrapped rewards
            # model = Model()
            # model.load(args.digit_checkpoint)
            # model.cuda()
            # accuracy = digit_eval(image, length_labels, digits_labels, model)
            # img.show()

            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                hasattr(envs, 'ob_rms') and envs.ob_rms or None
            ]

            torch.save(save_model,
                       os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            print(
                "Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss))
        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, args.log_dir, args.env_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
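
convert_to_lmdb above is project-specific. As an illustration only, a plausible writer for (image, reward) samples built on the lmdb and pickle packages could look like this:

import lmdb
import pickle

def write_sample(db_path, idx, img, reward, map_size=1 << 30):
    # Store one (image, reward) pair under an ASCII key; map_size caps the
    # database's maximum size in bytes.
    env = lmdb.open(db_path, map_size=map_size)
    with env.begin(write=True) as txn:
        txn.put(str(idx).encode('ascii'), pickle.dumps((img, reward)))
    env.close()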
Exemple #30
0
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    experiment_name = args.env_name + '-' + args.algo + '-' + datetime.datetime.now(
    ).strftime("%Y-%m-%d-%H-%M-%S-%f")
    log_dir, eval_log_dir, save_dir = setup_dirs(experiment_name, args.log_dir,
                                                 args.save_dir)

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         log_dir,
                         args.add_timestep,
                         device,
                         False,
                         frame_skip=args.frame_skip)

    if args.load_path:
        actor_critic, _ob_rms = torch.load(args.load_path)
        vec_norm = get_vec_normalize(envs)
        if vec_norm is not None:
            vec_norm.train()
            vec_norm.ob_rms = _ob_rms
        actor_critic.train()
    else:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              beta=args.beta_dist,
                              base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo.startswith('a2c'):
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               lr_schedule=args.lr_schedule,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo.startswith('ppo'):
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         lr_schedule=args.lr_schedule,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.algo.endswith('sil'):
        agent = algo.SIL(agent,
                         update_ratio=args.sil_update_ratio,
                         epochs=args.sil_epochs,
                         batch_size=args.sil_batch_size,
                         beta=args.sil_beta,
                         value_loss_coef=args.sil_value_loss_coef,
                         entropy_coef=args.sil_entropy_coef)
        replay = ReplayStorage(10000,
                               num_processes=args.num_processes,
                               gamma=args.gamma,
                               prio_alpha=args.sil_alpha,
                               obs_shape=envs.observation_space.shape,
                               action_space=envs.action_space,
                               recurrent_hidden_state_size=actor_critic.
                               recurrent_hidden_state_size,
                               device=device)
    else:
        replay = None
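    # The '*sil' algo variants wrap the base agent in self-imitation learning;
    # transitions stored in ReplayStorage during the rollout (prioritised via
    # prio_alpha) are handed to agent.update() below.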

    action_high = torch.from_numpy(envs.action_space.high).to(device)
    action_low = torch.from_numpy(envs.action_space.low).to(device)
    action_mid = 0.5 * (action_high + action_low)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    benchmark_rewards = deque(maxlen=10)

    start = time.time()
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                # sample actions
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            if args.clip_action and isinstance(envs.action_space,
                                               gym.spaces.Box):
                clipped_action = action.clone()
                if args.shift_action:
                    # FIXME: still experimental; so far this gives faster
                    # learning when clipping the Gaussian continuous output
                    # (vs. leaving it centred at 0 and unscaled).
                    clipped_action = 0.5 * clipped_action + action_mid
                clipped_action = torch.max(
                    torch.min(clipped_action, action_high), action_low)
            else:
                clipped_action = action

            # act in environment and observe
            obs, reward, done, infos = envs.step(clipped_action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    if 'rb' in info['episode']:
                        benchmark_rewards.append(info['episode']['rb'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)
            if replay is not None:
                replay.insert(rollouts.obs[step],
                              rollouts.recurrent_hidden_states[step], action,
                              reward, done)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        value_loss, action_loss, dist_entropy = agent.update(
            rollouts, j, replay)

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        train_eprew = np.mean(episode_rewards)
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} episodes: mean/med {:.1f}/{:.1f}, min/max reward {:.2f}/{:.2f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), train_eprew,
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss),
                end='')
            if len(benchmark_rewards):
                print(", benchmark {:.1f}/{:.1f}, {:.1f}/{:.1f}".format(
                    np.mean(benchmark_rewards), np.median(benchmark_rewards),
                    np.min(benchmark_rewards), np.max(benchmark_rewards)),
                      end='')
            print()

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            eval_envs = make_vec_envs(args.env_name,
                                      args.seed + args.num_processes,
                                      args.num_processes, args.gamma,
                                      eval_log_dir, args.add_timestep, device,
                                      True)

            vec_norm = get_vec_normalize(eval_envs)
            if vec_norm is not None:
                vec_norm.eval()
                vec_norm.ob_rms = get_vec_normalize(envs).ob_rms
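            # Evaluation reuses the training observation-normalisation
            # statistics (ob_rms) but freezes their updates via vec_norm.eval().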

            eval_episode_rewards = []

            obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args.num_processes, 1, device=device)

            while len(eval_episode_rewards) < 10:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                        obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                clipped_action = action
                if args.clip_action and isinstance(envs.action_space,
                                                   gym.spaces.Box):
                    if args.shift_action:
                        clipped_action = 0.5 * clipped_action + action_mid
                    clipped_action = torch.max(
                        torch.min(clipped_action, action_high), action_low)

                obs, reward, done, infos = eval_envs.step(clipped_action)

                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])
                for info in infos:
                    if 'episode' in info.keys():
                        eval_episode_rewards.append(info['episode']['r'])

            eval_envs.close()

            eval_eprew = np.mean(eval_episode_rewards)
            print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
                len(eval_episode_rewards), eval_eprew))

        if len(episode_rewards
               ) and j % args.save_interval == 0 and save_dir != "":
            # A really ugly way to save a model to CPU
            save_model = actor_critic
            if args.cuda:
                save_model = copy.deepcopy(actor_critic).cpu()

            save_model = [
                save_model,
                getattr(get_vec_normalize(envs), 'ob_rms', None)
            ]

            ep_rewstr = ("%d" % train_eprew).replace("-", "n")
            save_filename = os.path.join(
                save_dir, './checkpoint-%d-%s.pt' % (j, ep_rewstr))

            torch.save(save_model, save_filename)

        if args.vis and j % args.vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = visdom_plot(viz, win, log_dir, args.env_name, args.algo,
                                  args.num_frames)
            except IOError:
                pass
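
The clip_action/shift_action handling above recentres a roughly zero-centred Gaussian sample on the middle of the Box action range before clipping it to the bounds. Pulled out as a standalone helper (names are illustrative), the transform is simply:

import torch

def shift_and_clip(action, low, high):
    # Recentre the sample on the middle of the action range, then clip it to
    # the space's bounds (low/high are tensors built from envs.action_space).
    mid = 0.5 * (high + low)
    action = 0.5 * action + mid
    return torch.max(torch.min(action, high), low)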