Example #1
def initialise(variant):
    setup_logger("name-of-experiment", variant=variant)
    ptu.set_gpu_mode(True)
    log_dir = os.path.expanduser(variant["log_dir"])
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)
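# Hypothetical usage sketch: initialise() itself only reads variant["log_dir"];
# anything else in the dict is just forwarded to setup_logger as metadata.
# initialise({"log_dir": "~/logs/name-of-experiment"})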
Example #2
def experiment(variant):
    setup_logger("name-of-experiment", variant=variant)
    ptu.set_gpu_mode(True)
    log_dir = os.path.expanduser(variant["log_dir"])
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    # missing - set torch seed and num threads=1
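    # A sketch of the step the note above flags as missing (assumes the same
    # variant["seed"] key that make_vec_envs receives below):
    # torch.manual_seed(variant["seed"])
    # torch.cuda.manual_seed_all(variant["seed"])
    # torch.set_num_threads(1)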

    # expl_env = gym.make(variant["env_name"])
    expl_envs = make_vec_envs(
        variant["env_name"],
        variant["seed"],
        variant["num_processes"],
        variant["gamma"],
        variant["log_dir"],  # probably change this?
        ptu.device,
        False,
        pytorch=False,
    )
    # eval_env = gym.make(variant["env_name"])
    eval_envs = make_vec_envs(
        variant["env_name"],
        variant["seed"],
        variant["num_processes"],
        1,
        variant["log_dir"],
        ptu.device,
        False,
        pytorch=False,
    )
    obs_shape = expl_envs.observation_space.image.shape
    # if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:  # convert WxHxC into CxWxH
    #     expl_env = TransposeImage(expl_env, op=[2, 0, 1])
    #     eval_env = TransposeImage(eval_env, op=[2, 0, 1])
    # obs_shape = expl_env.observation_space.shape

    channels, obs_width, obs_height = obs_shape
    action_space = expl_envs.action_space

    base_kwargs = {"num_inputs": channels, "recurrent": variant["recurrent_policy"]}

    base = CNNBase(**base_kwargs)

    dist = create_output_distribution(action_space, base.output_size)

    eval_policy = LearnPlanPolicy(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=base,
            deterministic=True,
            dist=dist,
            num_processes=variant["num_processes"],
        ),
        num_processes=variant["num_processes"],
        vectorised=True,
    )
    expl_policy = LearnPlanPolicy(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=base,
            deterministic=False,
            dist=dist,
            num_processes=variant["num_processes"],
        ),
        num_processes=variant["num_processes"],
        vectorised=True,
    )

    # missing: at this stage, policy hasn't been sent to device, but happens later
    eval_path_collector = HierarchicalStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"][
            "num_eval_steps_per_epoch"
        ],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    expl_path_collector = HierarchicalStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    # added: created rollout(5,1,(4,84,84),Discrete(6),1), reset env and added obs to rollout[step]
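    # Roughly what that amounts to in the reference scripts (sketch only; per
    # the note above it appears to happen inside the step collector here):
    # rollouts = RolloutStorage(5, 1, (4, 84, 84), action_space, 1)
    # obs = expl_envs.reset()
    # rollouts.obs[0].copy_(obs)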

    trainer = A2CTrainer(actor_critic=expl_policy.learner, **variant["trainer_kwargs"])
    # missing: by this point, rollout back in sync.
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)
    # added: replay buffer is new
    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)
    # missing: device back in sync
    algorithm.train()
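# Usage sketch for experiment() above. The keys mirror what the function reads;
# the concrete values (and the trainer_kwargs contents, which depend on
# A2CTrainer's signature) are illustrative assumptions, not the original config.
example_variant = dict(
    env_name="BreakoutNoFrameskip-v4",  # assumed image-based env, per the (4, 84, 84) note above
    seed=0,
    num_processes=5,
    gamma=0.99,
    log_dir="/tmp/name-of-experiment",
    recurrent_policy=False,
    render=False,
    num_steps=5,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(),  # e.g. learning rate / loss coefficients for A2CTrainer
    algorithm_kwargs=dict(
        batch_size=32,
        max_path_length=1000,
        num_epochs=100,
        num_eval_steps_per_epoch=500,
        num_expl_steps_per_train_loop=500,
        num_trains_per_train_loop=100,
        num_train_loops_per_epoch=1,
        min_num_steps_before_training=0,
    ),
)
# experiment(example_variant)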
Example #3
def train_ppo_from_scratch(args):

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(2)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, True)

    actor_critic = Policy(  # 2-layer fully connected network
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={
            'recurrent': False,
            'hidden_size': 32
        })
    actor_critic.to(device)

    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    episode_reward_means = []
    episode_reward_times = []

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(agent.optimizer, j, num_updates,
                                         args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
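            # bad_masks flags transitions that ended because of a time limit
            # ('bad_transition'), so compute_returns(use_proper_time_limits=...)
            # can avoid bootstrapping through artificial episode boundaries.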
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

            episode_reward_means.append(np.mean(episode_rewards))
            episode_reward_times.append(total_num_steps)

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)

    print(episode_reward_means, episode_reward_times)
    return episode_reward_means, episode_reward_times
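# Usage sketch (hypothetical; assumes the same argparse-style get_args() helper
# these scripts use elsewhere, and that matplotlib is available for plotting):
# args = get_args()
# means, steps = train_ppo_from_scratch(args)
# plt.plot(steps, means)  # learning curve: mean episode reward vs. env steps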
Example #4
def main():
    args = get_args()

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    if config.cuda and torch.cuda.is_available() and config.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    logger, final_output_dir, tb_log_dir = create_logger(config,
                                                         args.cfg,
                                                         'train',
                                                         seed=config.seed)

    eval_log_dir = final_output_dir + "_eval"

    utils.cleanup_log_dir(final_output_dir)
    utils.cleanup_log_dir(eval_log_dir)

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    writer = SummaryWriter(tb_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:" + config.GPUS if config.cuda else "cpu")

    width = height = 84
    envs = make_vec_envs(config.env_name,
                         config.seed,
                         config.num_processes,
                         config.gamma,
                         final_output_dir,
                         device,
                         False,
                         width=width,
                         height=height,
                         ram_wrapper=False)
    # create agent
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent':
                              config.recurrent_policy,
                              'hidden_size':
                              config.hidden_size,
                              'feat_from_selfsup_attention':
                              config.feat_from_selfsup_attention,
                              'feat_add_selfsup_attention':
                              config.feat_add_selfsup_attention,
                              'feat_mul_selfsup_attention_mask':
                              config.feat_mul_selfsup_attention_mask,
                              'selfsup_attention_num_keypoints':
                              config.SELFSUP_ATTENTION.NUM_KEYPOINTS,
                              'selfsup_attention_gauss_std':
                              config.SELFSUP_ATTENTION.GAUSS_STD,
                              'selfsup_attention_fix':
                              config.selfsup_attention_fix,
                              'selfsup_attention_fix_keypointer':
                              config.selfsup_attention_fix_keypointer,
                              'selfsup_attention_pretrain':
                              config.selfsup_attention_pretrain,
                              'selfsup_attention_keyp_maps_pool':
                              config.selfsup_attention_keyp_maps_pool,
                              'selfsup_attention_image_feat_only':
                              config.selfsup_attention_image_feat_only,
                              'selfsup_attention_feat_masked':
                              config.selfsup_attention_feat_masked,
                              'selfsup_attention_feat_masked_residual':
                              config.selfsup_attention_feat_masked_residual,
                              'selfsup_attention_feat_load_pretrained':
                              config.selfsup_attention_feat_load_pretrained,
                              'use_layer_norm':
                              config.use_layer_norm,
                              'selfsup_attention_keyp_cls_agnostic':
                              config.SELFSUP_ATTENTION.KEYPOINTER_CLS_AGNOSTIC,
                              'selfsup_attention_feat_use_ln':
                              config.SELFSUP_ATTENTION.USE_LAYER_NORM,
                              'selfsup_attention_use_instance_norm':
                              config.SELFSUP_ATTENTION.USE_INSTANCE_NORM,
                              'feat_mul_selfsup_attention_mask_residual':
                              config.feat_mul_selfsup_attention_mask_residual,
                              'bottom_up_form_objects':
                              config.bottom_up_form_objects,
                              'bottom_up_form_num_of_objects':
                              config.bottom_up_form_num_of_objects,
                              'gaussian_std':
                              config.gaussian_std,
                              'train_selfsup_attention':
                              config.train_selfsup_attention,
                              'block_selfsup_attention_grad':
                              config.block_selfsup_attention_grad,
                              'sep_bg_fg_feat':
                              config.sep_bg_fg_feat,
                              'mask_threshold':
                              config.mask_threshold,
                              'fix_feature':
                              config.fix_feature
                          })

    # init / load parameter
    if config.MODEL_FILE:
        logger.info('=> loading model from {}'.format(config.MODEL_FILE))
        state_dict = torch.load(config.MODEL_FILE)

        state_dict = OrderedDict(
            (_k, _v) for _k, _v in state_dict.items() if 'dist' not in _k)

        actor_critic.load_state_dict(state_dict, strict=False)
    elif config.RESUME:
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
        if os.path.exists(checkpoint_file):
            logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            actor_critic.load_state_dict(checkpoint['state_dict'])

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

    actor_critic.to(device)

    if config.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            lr=config.lr,
            eps=config.eps,
            alpha=config.alpha,
            max_grad_norm=config.max_grad_norm,
            train_selfsup_attention=config.train_selfsup_attention)
    elif config.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         config.clip_param,
                         config.ppo_epoch,
                         config.num_mini_batch,
                         config.value_loss_coef,
                         config.entropy_coef,
                         lr=config.lr,
                         eps=config.eps,
                         max_grad_norm=config.max_grad_norm)
    elif config.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            acktr=True,
            train_selfsup_attention=config.train_selfsup_attention,
            max_grad_norm=config.max_grad_norm)

    # rollouts: environment
    rollouts = RolloutStorage(
        config.num_steps,
        config.num_processes,
        envs.observation_space.shape,
        envs.action_space,
        actor_critic.recurrent_hidden_state_size,
        keep_buffer=config.train_selfsup_attention,
        buffer_size=config.train_selfsup_attention_buffer_size)

    if config.RESUME:
        if os.path.exists(checkpoint_file):
            agent.optimizer.load_state_dict(checkpoint['optimizer'])
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        config.num_env_steps) // config.num_steps // config.num_processes
    best_perf = 0.0
    best_model = False
    print('num updates', num_updates, 'num steps', config.num_steps)

    for j in range(num_updates):

        if config.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if config.algo == "acktr" else config.lr)

        for step in range(config.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            recurrent_hidden_states, meta = recurrent_hidden_states

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            objects_locs = []
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            if objects_locs:
                objects_locs = torch.FloatTensor(objects_locs)
                objects_locs = objects_locs * 2 - 1  # -1, 1
            else:
                objects_locs = None
            rollouts.insert(obs,
                            recurrent_hidden_states,
                            action,
                            action_log_prob,
                            value,
                            reward,
                            masks,
                            bad_masks,
                            objects_loc=objects_locs)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1],
            ).detach()

        rollouts.compute_returns(next_value, config.use_gae, config.gamma,
                                 config.gae_lambda,
                                 config.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if config.train_selfsup_attention and j > 15:
            for _iter in range(config.num_steps // 5):
                frame_x, frame_y = rollouts.generate_pair_image()
                selfsup_attention_loss, selfsup_attention_output, image_b_keypoints_maps = \
                    agent.update_selfsup_attention(frame_x, frame_y, config.SELFSUP_ATTENTION)
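        # (The block above is an auxiliary self-supervised step: after the
        # first 15 policy updates (j > 15) it samples frame pairs from the
        # rollout buffer and trains the keypoint attention module
        # num_steps // 5 times per policy update.)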

        if j % config.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            end = time.time()
            msg = 'Updates {}, num timesteps {}, FPS {} \n' \
                  'Last {} training episodes: mean/median reward {:.1f}/{:.1f} ' \
                  'min/max reward {:.1f}/{:.1f} ' \
                  'dist entropy {:.1f}, value loss {:.1f}, action loss {:.1f}\n'. \
                format(j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       len(episode_rewards), np.mean(episode_rewards),
                       np.median(episode_rewards), np.min(episode_rewards),
                       np.max(episode_rewards), dist_entropy, value_loss,
                       action_loss)
            if config.train_selfsup_attention and j > 15:
                msg = msg + 'selfsup attention loss {:.5f}\n'.format(
                    selfsup_attention_loss)
            logger.info(msg)

        if (config.eval_interval is not None and len(episode_rewards) > 1
                and j % config.eval_interval == 0):
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            eval_mean_score, eval_max_score, eval_scores = evaluate(
                actor_critic,
                ob_rms,
                config.env_name,
                config.seed,
                config.num_processes,
                eval_log_dir,
                device,
                width=width,
                height=height)
            perf_indicator = eval_mean_score
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            # record test scores
            with open(os.path.join(final_output_dir, 'test_scores'),
                      'a+') as f:
                out_s = "TEST: {}, {}, {}, {}\n".format(
                    str(total_num_steps), str(eval_mean_score),
                    str(eval_max_score),
                    [str(_eval_scores) for _eval_scores in eval_scores])
                print(out_s, end="", file=f)
                logger.info(out_s)
            writer.add_scalar('data/mean_score', eval_mean_score,
                              total_num_steps)
            writer.add_scalar('data/max_score', eval_max_score,
                              total_num_steps)

            writer.add_scalars('test', {'mean_score': eval_mean_score},
                               total_num_steps)

            # save for every interval-th episode or for the last epoch
            if (j % config.save_interval == 0
                    or j == num_updates - 1) and config.save_dir != "":

                logger.info(
                    "=> saving checkpoint to {}".format(final_output_dir))
                epoch = j / config.save_interval
                save_checkpoint(
                    {
                        'epoch':
                        epoch + 1,
                        'model':
                        get_model_name(config),
                        'state_dict':
                        actor_critic.state_dict(),
                        'perf':
                        perf_indicator,
                        'optimizer':
                        agent.optimizer.state_dict(),
                        'ob_rms':
                        getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                    }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(actor_critic.state_dict(), final_model_state_file)

    # export_scalars_to_json needs results from add scalars
    writer.export_scalars_to_json(os.path.join(tb_log_dir, 'all_scalars.json'))
    writer.close()
Example #5
def main():
    all_episode_rewards = []  ### record, 6/29
    all_temp_rewards = []  ### record, 6/29
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        gail_train_loader = torch.utils.data.DataLoader(
            gail.ExpertDataset(file_name,
                               num_trajectories=4,
                               subsample_frequency=20),
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print('num_updates ', num_updates)
    print('num_steps ', args.num_steps)
    count = 0
    h5_path = './data/' + args.env_name
    if not os.path.exists(h5_path):
        os.makedirs(h5_path)
    h5_filename = h5_path + '/trajs_' + args.env_name + '_%05d.h5' % (count)
    data = {}
    data['states'] = []
    data['actions'] = []
    data['rewards'] = []
    data['done'] = []
    data['lengths'] = []

    episode_step = 0

    for j in range(num_updates):  ### num-steps

        temp_states = []
        temp_actions = []
        temp_rewards = []
        temp_done = []
        temp_lengths = []

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            if j == 0 and step == 0:
                print('obs ', type(rollouts.obs[step]),
                      rollouts.obs[step].shape)
                print('hidden_states ',
                      type(rollouts.recurrent_hidden_states[step]),
                      rollouts.recurrent_hidden_states[step].shape)
                print('action ', type(action), action.shape)
                print('action prob ', type(action_log_prob),
                      action_log_prob.shape)
                print('-' * 20)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            #print(infos)
            #print(reward)
            temp_states += [np.array(rollouts.obs[step].cpu())]
            temp_actions += [np.array(action.cpu())]
            #temp_rewards += [np.array(reward.cpu())]
            temp_rewards += [np.array([infos[0]['myrewards']])]  ### for halfcheetah the reward can't be used directly !! 6/29
            temp_done += [np.array(done)]

            if j == 0 and step == 0:
                print('obs ', type(obs), obs.shape)
                print('reward ', type(reward), reward.shape)
                print('done ', type(done), done.shape)
                print('infos ', len(infos))
                for k, v in infos[0].items():
                    print(k, v.shape)
                print()

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    all_episode_rewards += [info['episode']['r']]  ### record, 6/29

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        temp_lengths = len(temp_states)
        temp_states = np.concatenate(temp_states)
        temp_actions = np.concatenate(temp_actions)
        temp_rewards = np.concatenate(temp_rewards)
        temp_done = np.concatenate(temp_done)
        #print('temp_lengths',temp_lengths)
        #print('temp_states', temp_states.shape)
        #print('temp_actions', temp_actions.shape)
        #print('temp_rewards', temp_rewards.shape)
        if j > int(0.4 * num_updates):
            data['states'] += [temp_states]
            data['actions'] += [temp_actions]
            data['rewards'] += [temp_rewards]
            data['lengths'] += [temp_lengths]
            data['done'] += [temp_done]
            #print('temp_lengths',data['lengths'].shape)
            #print('temp_states', data['states'].shape)
            #print('temp_actions', data['actions'].shape)
            #print('temp_rewards', data['rewards'].shape)

            if args.save_expert and len(data['states']) >= 100:
                with h5py.File(h5_filename, 'w') as f:
                    f['states'] = np.array(data['states'])
                    f['actions'] = np.array(data['actions'])
                    f['rewards'] = np.array(data['rewards'])
                    f['done'] = np.array(data['done'])
                    f['lengths'] = np.array(data['lengths'])
                    #print('f_lengths',f['lengths'].shape)
                    #print('f_states', f['states'].shape)
                    #print('f_actions', f['actions'].shape)
                    #print('f_rewards', f['rewards'].shape)

                count += 1
                h5_filename = h5_path + '/trajs_' + args.env_name + '_%05d.h5' % (
                    count)
                data['states'] = []
                data['actions'] = []
                data['rewards'] = []
                data['done'] = []
                data['lengths'] = []

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])
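            # (GAIL: the discriminator's predicted rewards overwrite the stored
            # environment rewards above, so the policy update that follows
            # optimises the imitation reward rather than the task reward.)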

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + "_%d.pt" % (args.seed)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            #np.save(os.path.join(save_path, args.env_name+"_%d"%(args.seed)), all_episode_rewards)  ### save records, 6/29
            #print(temp_rewards)
            print("temp rewards size", temp_rewards.shape, "mean",
                  np.mean(temp_rewards), "min", np.min(temp_rewards), "max",
                  np.max(temp_rewards))
            all_temp_rewards += [temp_rewards]
            np.savez(os.path.join(save_path,
                                  args.env_name + "_%d" % (args.seed)),
                     episode=all_episode_rewards,
                     timestep=all_temp_rewards)

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
    '''data['states'] = np.array(data['states'])
Example #6
def record_trajectories():
    args = get_args()
    print(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         1,
                         args.gamma,
                         log_dir,
                         device,
                         True,
                         training=False)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={
            'recurrent': args.recurrent_policy,
            'env': args.env_name
        },
        activation=activation,
    )
    actor_critic.to(device)

    # Load from previous model
    if args.load_model_name:
        loaddata = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))
        state = loaddata[0]
        try:
            obs_rms, ret_rms = loaddata[1:]
            # Feed it into the env
            envs.obs_rms = None
            envs.ret_rms = None
        except:
            print("Couldnt load obsrms")
            obs_rms = ret_rms = None
        try:
            actor_critic.load_state_dict(state)
        except:
            actor_critic = state
    else:
        raise NotImplementedError

    # Record trajectories
    actions = []
    rewards = []
    observations = []
    episode_starts = []

    for eps in range(args.num_episodes):
        obs = envs.reset()
        # Init variables for storing
        episode_starts.append(True)
        reward = 0
        while True:
            # Take action
            act = actor_critic.act(obs, None, None, None)[1]
            next_state, rew, done, info = envs.step(act)
            #print(obs.shape, act.shape, rew.shape, done)
            reward += rew
            # Add the current observation and act
            observations.append(obs.data.cpu().numpy()[0])  # [C, H, W]
            actions.append(act.data.cpu().numpy()[0])  # [A]
            rewards.append(rew[0, 0].data.cpu().numpy())
            if done[0]:
                break
            episode_starts.append(False)
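            # "+ 0" below makes a fresh copy of the observation tensor instead
            # of keeping a reference (an assumption about intent, not stated in
            # the original).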
            obs = next_state + 0
        print("Total reward: {}".format(reward[0, 0].data.cpu().numpy()))

    # Save these values
    save_trajectories_images(observations, actions, rewards, episode_starts)
Example #7
File: main.py Project: wisdomdeng/tppo
def main():
    args = get_args()
    use_ppo = args.algo == 'ppo'
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={'recurrent': args.recurrent_policy,
                     'share_parameter': args.share_parameter})
    actor_critic.to(device)

    return_distributions = False
    if args.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo_rb':
        agent = algo.PPO_RB(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            args.rb_alpha,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm
        )
    elif args.algo == 'tr_ppo':
        agent = algo.TR_PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            ppo_clip_param=args.ppo_clip_param
        )
        return_distributions = True
    elif args.algo == 'tr_ppo_rb':
        agent = algo.TR_PPO_RB(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            args.rb_alpha,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            ppo_clip_param=args.ppo_clip_param
        )
        return_distributions = True

    if not return_distributions:
        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  envs.observation_space.shape, envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)
    else:
        if actor_critic.dist_name == 'DiagGaussian':
            rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                      envs.observation_space.shape, envs.action_space,
                                      actor_critic.recurrent_hidden_state_size,
                                      distribution_param_dim=envs.action_space.shape[0]*2
                                      )
        elif actor_critic.dist_name == 'Bernoulli' or actor_critic.dist_name == 'Categorical':
            rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                      envs.observation_space.shape, envs.action_space,
                                      actor_critic.recurrent_hidden_state_size,
                                      distribution_param_dim=envs.action_space.n
                                      )
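    # distribution_param_dim reserves room in the rollout for the old policy's
    # distribution parameters (e.g. mean and scale per action dimension for
    # DiagGaussian, one logit per action for Bernoulli/Categorical), which the
    # trust-region variants (TR_PPO / TR_PPO_RB) presumably read back when
    # comparing old and new policies during their updates.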
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    prev_mean_reward = None
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, parameters = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step], return_distribution=True)


            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks, parameters)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts, use_ppo)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        mean_rewards = np.mean(episode_rewards)
        if (prev_mean_reward is not None) and (mean_rewards < prev_mean_reward) and \
           (use_ppo == False) and args.revert_to_ppo and j > 3:
            use_ppo = True
            print('Revert Back to PPO Training')
            # args.lr = 3e-4
        prev_mean_reward = mean_rewards
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
Example #8
def main():
    #wandb.run = config.tensorboard.run
    wandb.init(settings=wandb.Settings(start_method="fork"),
               project='growspaceenv_baselines',
               entity='growspace')
    #torch.manual_seed(config.seed)
    #torch.cuda.manual_seed_all(config.seed)

    if config.cuda and torch.cuda.is_available() and config.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(config.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if config.cuda else "cpu")

    envs = make_vec_envs(config.env_name, config.seed, config.num_processes,
                         config.gamma, config.log_dir, device, False,
                         config.custom_gym)

    if "Mnist" in config.env_name:
        base = 'Mnist'
    else:
        base = None

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base,
                          base_kwargs={'recurrent': config.recurrent_policy})
    actor_critic.to(device)

    if config.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               config.value_loss_coef,
                               config.entropy_coef,
                               lr=config.lr,
                               eps=config.eps,
                               alpha=config.alpha,
                               max_grad_norm=config.max_grad_norm)
    elif config.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         config.clip_param,
                         config.ppo_epoch,
                         config.num_mini_batch,
                         config.value_loss_coef,
                         config.entropy_coef,
                         lr=config.lr,
                         eps=config.eps,
                         max_grad_norm=config.max_grad_norm,
                         optimizer=config.optimizer,
                         momentum=config.momentum)
    elif config.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               config.value_loss_coef,
                               config.entropy_coef,
                               acktr=True)

    if config.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            config.gail_experts_dir,
            "trajs_{}.pt".format(config.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > config.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=config.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(config.num_steps, config.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = []
    episode_length = []
    episode_branches = []
    episode_branch1 = []
    episode_branch2 = []
    episode_light_width = []
    episode_light_move = []
    episode_success = []
    episode_plantpixel = []

    start = time.time()
    num_updates = int(
        config.num_env_steps) // config.num_steps // config.num_processes
    x = 0
    action_space_type = envs.action_space

    for j in range(num_updates):

        if isinstance(action_space_type, Discrete):
            action_dist = np.zeros(envs.action_space.n)

        total_num_steps = (j + 1) * config.num_processes * config.num_steps

        if config.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if config.algo == "acktr" else config.lr)
        #new_branches = []
        for step in range(config.num_steps):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            if isinstance(action_space_type, Discrete):
                action_dist[action] += 1

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    episode_length.append(info['episode']['l'])
                    wandb.log({"Episode_Reward": info['episode']['r']},
                              step=total_num_steps)

                if 'new_branches' in info.keys():
                    episode_branches.append(info['new_branches'])

                if 'new_b1' in info.keys():
                    episode_branch1.append(info['new_b1'])

                if 'new_b2' in info.keys():
                    episode_branch2.append(info['new_b2'])

                if 'light_width' in info.keys():
                    episode_light_width.append(info['light_width'])

                if 'light_move' in info.keys():
                    episode_light_move.append(info['light_move'])

                if 'success' in info.keys():
                    episode_success.append(info['success'])

                if 'plant_pixel' in info.keys():
                    episode_plantpixel.append(info['plant_pixel'])

                if j == x:
                    if 'img' in info.keys():
                        img = info['img']
                        path = './hittiyas/growspaceenv_braselines/scripts/imgs/'
                        cv2.imwrite(
                            os.path.join(path, 'step' + str(step) + '.png'),
                            img)
                    x += 1000

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if config.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = config.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(config.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], config.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, config.use_gae, config.gamma,
                                 config.gae_lambda,
                                 config.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % config.save_interval == 0
                or j == num_updates - 1) and config.save_dir != "":
            save_path = os.path.join(config.save_dir, config.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, config.env_name + ".pt"))

        if j % config.log_interval == 0 and len(episode_rewards) > 1:

            if isinstance(action_space_type, Discrete):
                np_hist = np.histogram(np.arange(action_dist.shape[0]),
                                       weights=action_dist)
                wandb.log(
                    {
                        "Discrete Actions":
                        wandb.Histogram(np_histogram=np_hist)
                    },
                    step=total_num_steps)
            wandb.log({"Reward Min": np.min(episode_rewards)},
                      step=total_num_steps)
            wandb.log({"Summed Reward": np.sum(episode_rewards)},
                      step=total_num_steps)
            wandb.log({"Reward Mean": np.mean(episode_rewards)},
                      step=total_num_steps)
            wandb.log({"Reward Max": np.max(episode_rewards)},
                      step=total_num_steps)
            wandb.log(
                {"Number of Mean New Branches": np.mean(episode_branches)},
                step=total_num_steps)
            wandb.log({"Number of Max New Branches": np.max(episode_branches)},
                      step=total_num_steps)
            wandb.log({"Number of Min New Branches": np.min(episode_branches)},
                      step=total_num_steps)
            wandb.log(
                {
                    "Number of Mean New Branches of Plant 1":
                    np.mean(episode_branch1)
                },
                step=total_num_steps)
            wandb.log(
                {
                    "Number of Mean New Branches of Plant 2":
                    np.mean(episode_branch2)
                },
                step=total_num_steps)
            wandb.log(
                {
                    "Number of Total Displacement of Light":
                    np.sum(episode_light_move)
                },
                step=total_num_steps)
            wandb.log({"Mean Light Displacement": episode_light_move},
                      step=total_num_steps)
            wandb.log({"Mean Light Width": episode_light_width},
                      step=total_num_steps)
            wandb.log(
                {
                    "Number of Steps in Episode with Tree is as close as possible":
                    np.sum(episode_success)
                },
                step=total_num_steps)
            wandb.log({"Entropy": dist_entropy}, step=total_num_steps)
            wandb.log(
                {
                    "Displacement of Light Position":
                    wandb.Histogram(episode_light_move)
                },
                step=total_num_steps)
            wandb.log(
                {
                    "Displacement of Beam Width":
                    wandb.Histogram(episode_light_width)
                },
                step=total_num_steps)
            wandb.log({"Mean Plant Pixel": np.mean(episode_plantpixel)},
                      step=total_num_steps)
            wandb.log({"Summed Plant Pixel": np.sum(episode_plantpixel)},
                      step=total_num_steps)
            wandb.log(
                {"Plant Pixel Histogram": wandb.Histogram(episode_plantpixel)},
                step=total_num_steps)

            episode_rewards.clear()
            episode_length.clear()
            episode_branches.clear()
            episode_branch2.clear()
            episode_branch1.clear()
            episode_light_move.clear()
            episode_light_width.clear()
            episode_success.clear()
            episode_plantpixel.clear()

        if (config.eval_interval is not None and len(episode_rewards) > 1
                and j % config.eval_interval == 0):
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            evaluate(actor_critic, ob_rms, config.env_name, config.seed,
                     config.num_processes, eval_log_dir, device,
                     config.custom_gym)

    ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
    evaluate(actor_critic,
             ob_rms,
             config.env_name,
             config.seed,
             config.num_processes,
             eval_log_dir,
             device,
             config.custom_gym,
             gif=True)
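
The loop above hands return computation to rollouts.compute_returns(next_value, config.use_gae, config.gamma, config.gae_lambda, config.use_proper_time_limits). For reference, a minimal sketch of the GAE path is shown below; the standalone function name and the tensor shapes are assumptions for illustration, not the storage class's actual internals.

import torch

def compute_gae_returns(rewards, values, masks, next_value, gamma, gae_lambda):
    # Sketch only. Assumed shapes: rewards, values and masks are [T, N, 1],
    # next_value is [N, 1]; masks[t] is 0.0 where the episode ended at step t.
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # [T + 1, N, 1]
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(rewards.size(0))):
        # TD error; bootstrap only through non-terminal transitions
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * gae_lambda * masks[t] * gae
        returns[t] = gae + values[t]
    return returns

Setting gae_lambda to 0 recovers one-step TD targets, while 1 recovers bootstrapped Monte-Carlo returns.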
Example #9
def main():
    args = get_args()
    import random
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    logdir = args.env_name + '_' + args.algo + '_num_arms_' + str(
        args.num_processes) + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    if args.use_privacy:
        logdir = logdir + '_privacy'
    elif args.use_noisygrad:
        logdir = logdir + '_noisygrad'
    elif args.use_pcgrad:
        logdir = logdir + '_pcgrad'
    elif args.use_testgrad:
        logdir = logdir + '_testgrad'
    elif args.use_median_grad:
        logdir = logdir + '_mediangrad'
    logdir = os.path.join('runs', logdir)
    logdir = os.path.join(os.path.expanduser(args.log_dir), logdir)
    utils.cleanup_log_dir(logdir)

    # Ugly but simple logging
    log_dict = {
        'task_steps': args.task_steps,
        'grad_noise_ratio': args.grad_noise_ratio,
        'max_task_grad_norm': args.max_task_grad_norm,
        'use_noisygrad': args.use_noisygrad,
        'use_pcgrad': args.use_pcgrad,
        'use_testgrad': args.use_testgrad,
        'use_testgrad_median': args.use_testgrad_median,
        'testgrad_quantile': args.testgrad_quantile,
        'median_grad': args.use_median_grad,
        'use_meanvargrad': args.use_meanvargrad,
        'meanvar_beta': args.meanvar_beta,
        'no_special_grad_for_critic': args.no_special_grad_for_critic,
        'use_privacy': args.use_privacy,
        'seed': args.seed,
        'recurrent': args.recurrent_policy,
        'obs_recurrent': args.obs_recurrent,
        'cmd': ' '.join(sys.argv[1:])
    }
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        log_dict[eval_disp_name] = []

    summary_writer = SummaryWriter()
    # Re-use the hyper-parameter dict built above rather than repeating it;
    # the per-environment evaluation lists added to log_dict are skipped
    # because add_hparams only accepts scalar values.
    summary_writer.add_hparams(
        {k: v for k, v in log_dict.items() if k not in EVAL_ENVS}, {})

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    print('making envs...')
    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         steps=args.task_steps,
                         free_exploration=args.free_exploration,
                         recurrent=args.recurrent_policy,
                         obs_recurrent=args.obs_recurrent,
                         multi_task=True)

    val_envs = make_vec_envs(args.val_env_name,
                             args.seed,
                             args.num_processes,
                             args.gamma,
                             args.log_dir,
                             device,
                             False,
                             steps=args.task_steps,
                             free_exploration=args.free_exploration,
                             recurrent=args.recurrent_policy,
                             obs_recurrent=args.obs_recurrent,
                             multi_task=True)

    eval_envs_dic = {}
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        eval_envs_dic[eval_disp_name] = make_vec_envs(
            eval_env_name[0],
            args.seed,
            args.num_processes,
            None,
            logdir,
            device,
            True,
            steps=args.task_steps,
            recurrent=args.recurrent_policy,
            obs_recurrent=args.obs_recurrent,
            multi_task=True,
            free_exploration=args.free_exploration)
    prev_eval_r = {}
    print('done')
    if args.hard_attn:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base=MLPHardAttnBase,
                              base_kwargs={
                                  'recurrent':
                                  args.recurrent_policy or args.obs_recurrent
                              })
    else:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base=MLPAttnBase,
                              base_kwargs={
                                  'recurrent':
                                  args.recurrent_policy or args.obs_recurrent
                              })
    actor_critic.to(device)

    if (args.continue_from_epoch > 0) and args.save_dir != "":
        save_path = os.path.join(args.save_dir, args.algo)
        actor_critic_, loaded_obs_rms_ = torch.load(
            os.path.join(
                save_path, args.env_name +
                "-epoch-{}.pt".format(args.continue_from_epoch)))
        actor_critic.load_state_dict(actor_critic_.state_dict())

    if args.algo != 'ppo':
        raise ValueError("only PPO is supported")
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     num_tasks=args.num_processes,
                     attention_policy=False,
                     max_grad_norm=args.max_grad_norm,
                     weight_decay=args.weight_decay)
    val_agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.val_lr,
                         eps=args.eps,
                         num_tasks=args.num_processes,
                         attention_policy=True,
                         max_grad_norm=args.max_grad_norm,
                         weight_decay=args.weight_decay)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    val_rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  val_envs.observation_space.shape,
                                  val_envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    val_obs = val_envs.reset()
    val_rollouts.obs[0].copy_(val_obs)
    val_rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    save_copy = True
    for j in range(args.continue_from_epoch,
                   args.continue_from_epoch + num_updates):

        # policy rollouts
        for step in range(args.num_steps):
            # Sample actions
            actor_critic.eval()
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            actor_critic.train()

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    for k, v in info['episode'].items():
                        summary_writer.add_scalar(
                            f'training/{k}', v,
                            j * args.num_processes * args.num_steps +
                            args.num_processes * step)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        actor_critic.eval()
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
        actor_critic.train()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if save_copy:
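            # Snapshot the policy weights and both optimizers so training can
            # be rolled back if the next evaluation shows no improvement on
            # the validation tasks (see the revert logic further down).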
            prev_weights = copy.deepcopy(actor_critic.state_dict())
            prev_opt_state = copy.deepcopy(agent.optimizer.state_dict())
            prev_val_opt_state = copy.deepcopy(
                val_agent.optimizer.state_dict())
            save_copy = False

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # validation rollouts
        for val_iter in range(args.val_agent_steps):
            for step in range(args.num_steps):
                # Sample actions
                actor_critic.eval()
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        val_rollouts.obs[step],
                        val_rollouts.recurrent_hidden_states[step],
                        val_rollouts.masks[step])
                actor_critic.train()

                # Observe reward and next obs
                obs, reward, done, infos = val_envs.step(action)

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos])
                val_rollouts.insert(obs, recurrent_hidden_states, action,
                                    action_log_prob, value, reward, masks,
                                    bad_masks)

            actor_critic.eval()
            with torch.no_grad():
                next_value = actor_critic.get_value(
                    val_rollouts.obs[-1],
                    val_rollouts.recurrent_hidden_states[-1],
                    val_rollouts.masks[-1]).detach()
            actor_critic.train()

            val_rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                         args.gae_lambda,
                                         args.use_proper_time_limits)

            val_value_loss, val_action_loss, val_dist_entropy = val_agent.update(
                val_rollouts)
            val_rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ], os.path.join(save_path,
                            args.env_name + "-epoch-{}.pt".format(j)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n entropy {:.2f}, value loss {:.3f}, action loss {:.3f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
        revert = False
        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            actor_critic.eval()
            obs_rms = utils.get_vec_normalize(envs).obs_rms
            eval_r = {}
            printout = f'Seed {args.seed} Iter {j} '
            for eval_disp_name, eval_env_name in EVAL_ENVS.items():
                eval_r[eval_disp_name] = evaluate(
                    actor_critic,
                    obs_rms,
                    eval_envs_dic,
                    eval_disp_name,
                    args.seed,
                    args.num_processes,
                    eval_env_name[1],
                    logdir,
                    device,
                    steps=args.task_steps,
                    recurrent=args.recurrent_policy,
                    obs_recurrent=args.obs_recurrent,
                    multi_task=True,
                    free_exploration=args.free_exploration)
                if eval_disp_name in prev_eval_r:
                    diff = np.array(eval_r[eval_disp_name]) - np.array(
                        prev_eval_r[eval_disp_name])
                    if eval_disp_name == 'many_arms':
                        if np.sum(diff > 0) - np.sum(
                                diff < 0) < args.val_improvement_threshold:
                            print('no update')
                            revert = True

                summary_writer.add_scalar(f'eval/{eval_disp_name}',
                                          np.mean(eval_r[eval_disp_name]),
                                          (j + 1) * args.num_processes *
                                          args.num_steps)
                log_dict[eval_disp_name].append([
                    (j + 1) * args.num_processes * args.num_steps,
                    eval_r[eval_disp_name]
                ])
                printout += eval_disp_name + ' ' + str(
                    np.mean(eval_r[eval_disp_name])) + ' '
            # summary_writer.add_scalars('eval_combined', eval_r, (j+1) * args.num_processes * args.num_steps)
            if revert:
                actor_critic.load_state_dict(prev_weights)
                agent.optimizer.load_state_dict(prev_opt_state)
                val_agent.optimizer.load_state_dict(prev_val_opt_state)
            else:
                print(printout)
                prev_eval_r = eval_r.copy()
            save_copy = True
            actor_critic.train()

    save_obj(log_dict, os.path.join(logdir, 'log_dict.pkl'))
    envs.close()
    val_envs.close()
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        eval_envs_dic[eval_disp_name].close()
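
The evaluation block above implements a snapshot-and-revert scheme: weights and optimizer states are deep-copied once per evaluation window, and restored whenever the 'many_arms' score fails to improve by at least val_improvement_threshold. Detached from the training specifics, the pattern is roughly the sketch below; the function and dictionary keys are illustrative, not part of the project's API.

import copy

def snapshot_or_rollback(model, optimizers, snapshot, improved):
    # Sketch of the snapshot/rollback pattern used above. `snapshot` is None
    # or the dict returned by the previous call; `improved` is the outcome of
    # the validation comparison.
    if snapshot is not None and not improved:
        model.load_state_dict(snapshot["model"])
        for opt, state in zip(optimizers, snapshot["optimizers"]):
            opt.load_state_dict(state)
    # Take a fresh snapshot for the next evaluation window.
    return {
        "model": copy.deepcopy(model.state_dict()),
        "optimizers": [copy.deepcopy(opt.state_dict()) for opt in optimizers],
    }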
Example #10
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy},
                          dimh=args.dimh)
    actor_critic.to(device)

    exp_name = "%s_%s_seed%d_dimh%d_" % (args.env_name, args.algo, args.seed,
                                         args.dimh)
    if args.gail:
        exp_name += '_gail_'

    if args.split:
        exp_name += 'splitevery' + str(args.split_every)
        if args.random_split:
            exp_name += '_rsplit'
    else:
        exp_name += 'baseline'

    writer = SummaryWriter('./runs/' + exp_name)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        gail_train_loader = torch.utils.data.DataLoader(
            gail.ExpertDataset(file_name,
                               num_trajectories=4,
                               subsample_frequency=20),
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print(num_updates)
    stats = {
        'seed': args.seed,
        'experiment': exp_name,
        'env': args.env_name,
        'dimh': args.dimh,
        'split every': args.split_every,
        'random split': args.random_split,
        'steps': [],
        'mean reward': [],
        'actor neurons': [],
        'critic neurons': [],
    }
    save_dir = './experiment_results/%s/' % args.env_name
    stats_save_path = save_dir + exp_name
    check_path(save_dir)
    print('start')
    count = -1
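    # NOTE: the number of updates is hard-coded below and overrides the value
    # computed above from num_env_steps.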
    num_updates = 488 * 2
    meanreward = []
    for j in range(num_updates):
        #if j % 50 == 0:
        #    print('STEP', j)
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            count += 1
            if j % 488 == 0:
                count = 0
            total = 488 * 2
            if args.split:
                utils.update_linear_schedule(
                    agent.optimizer, count, total,
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)
            else:
                utils.update_linear_schedule(
                    agent.optimizer, j, num_updates,
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        # splitting
        if args.split and (j + 1) % args.split_every == 0 and j < 200:
            print("[INFO] split on iteration %d..." % j)
            agent.split(rollouts, args.random_split)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))
        meanreward.append(np.mean(episode_rewards))
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()

            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n entropy {:.2f}, value loss {:.3f}, action loss {:.3f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards),
                        np.min(episode_rewards), np.max(episode_rewards),
                        dist_entropy, value_loss, action_loss))
            stats['mean reward'].append(np.mean(episode_rewards))
            stats['steps'].append(j)
            if args.split:
                a, c = agent.actor_critic.get_num_params()
                stats['actor neurons'].append(a)
                stats['critic neurons'].append(c)

        if j % 10 == 0:
            print("[INFO] saving to ", stats_save_path)
            np.save(stats_save_path, stats)

        if j % 5 == 0:
            s = (j + 1) * args.num_processes * args.num_steps
            if args.split:
                a, c = agent.actor_critic.get_num_params()
                writer.add_scalar('A neurons', a, s)
                writer.add_scalar('C neurons', c, s)
            writer.add_scalar('mean reward', np.mean(episode_rewards), s)
            writer.add_scalar('entropy loss', dist_entropy, s)
            writer.add_scalar('value loss', value_loss, s)
            writer.add_scalar('action loss', action_loss, s)

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)

    writer.close()
    import pickle
    pickle.dump(meanreward, open(stats_save_path + '.pkl', 'wb'))
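
The schedule juggling above ultimately goes through utils.update_linear_schedule. A sketch of that helper, consistent with how it is usually written in pytorch-a2c-ppo-acktr-style code (the project's own utility may differ in detail):

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    # Anneal the learning rate linearly from initial_lr to 0 over the horizon.
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

For ACKTR the examples pass agent.optimizer.lr as the initial value, presumably because the KFAC optimizer exposes its learning rate as an attribute.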
Example #11
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    save_name = '%s_%s' % (args.env_name, args.algo)
    if args.postfix != '':
        save_name += ('_' + args.postfix)

    logger_filename = os.path.join(log_dir, save_name)
    logger = utils.create_logger(logger_filename)

    torch.set_num_threads(1)
    device = torch.device("cuda:%d" % args.gpu if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         4,
                         obs_type="grid" if args.grid else "image",
                         skip_frames=args.num_skip_frames)

    if args.load_dir is not None:
        actor_critic, ob_rms = \
                torch.load(os.path.join(args.load_dir), map_location=lambda storage, loc: storage)
        vec_norm = utils.get_vec_normalize(envs)
        if vec_norm is not None:
            vec_norm.ob_rms = ob_rms
        print("load pretrained...")
    else:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base="grid" if args.grid else None,
                              base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        gail_train_loader = torch.utils.data.DataLoader(
            gail.ExpertDataset(file_name,
                               num_trajectories=4,
                               subsample_frequency=20),
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    lines = deque(maxlen=10)
    start = time.time()
    kk = 0
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    # learning_start = 1000
    learning_start = 0
    best_reward = -100
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        explore = exploration_rate(j - learning_start, 'exp')
        # print(j)
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # if j < learning_start:
            #     action[0, 0] = random.randint(0, envs.action_space.n - 1)
            # elif random.uniform(0, 1) < explore:
            #     action[0, 0] = random.randint(0, envs.action_space.n - 1)
            # else:
            #     pass

            # Observe reward and next obs
            # action[0, 0] = 1
            # envs.take_turns()
            obs, reward, done, infos = envs.step(action)
            # print(obs)

            # im = Image.fromarray(obs[0].reshape(224 * 4, -1).cpu().numpy().astype(np.uint8))
            # im.save("samples/%d.png" % kk)
            # kk += 1
            # info = infos[0]
            # if len(info) > 0:
            #     print(info)
            # print(done)
            # print(infos)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                if 'sent' in info.keys():
                    lines.append(info['sent'])

            # kk += 1
            # print(action.shape)
            # print(obs.shape)
            # print(done.shape)
            # if done[0]:
            #     print(time.time() - start)
            #     print(kk)
            #     exit()

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "" \
                and np.mean(episode_rewards) > best_reward:
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            best_reward = np.mean(episode_rewards)
            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, save_name + ".pt"))

        # print(episode_rewards)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            if j < learning_start:
                logger.info("random action")
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            logger.info(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.2f}, value loss {:.3f}, action loss {:.3f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

            logger.info(
                ' lines sent: mean/median lines {:.1f}/{:.1f}, min/max lines {:.1f}/{:.1f}\n'
                .format(np.mean(lines), np.median(lines), np.min(lines),
                        np.max(lines)))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
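
bad_masks marks steps whose 'done' came from a TimeLimit truncation ('bad_transition' in the step info), and use_proper_time_limits tells compute_returns not to treat such cut-offs as genuine terminations. A hedged sketch of how that correction typically enters the GAE recursion (modelled on the upstream storage class, not this project's exact code):

import torch

def gae_with_time_limits(rewards, values, masks, bad_masks, next_value,
                         gamma, gae_lambda):
    # Sketch. Same assumed shapes as the GAE sketch earlier; bad_masks[t] is
    # 0.0 when the episode ended at step t only because the time limit hit.
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(rewards.size(0))):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * gae_lambda * masks[t] * gae
        # A truncated episode says nothing about the advantage at the cut-off,
        # so the accumulated GAE is zeroed there instead of being propagated.
        gae = gae * bad_masks[t]
        returns[t] = gae + values[t]
    return returns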
Example #12
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    if not args.debug:
        wandb.init(project="ops")

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    # render_func = envs.venv.venv.get_images_single
    flip, flip1 = False, False

    def make_agent(is_leaf=True):
        ## AGENT CONSTRUCTION:
        ## Modularize this and allow for cascading (obs dim for child policy should be cat of obs and parents output)
        actor_critic = OpsPolicy(
            envs.observation_space.shape,
            envs.action_space if is_leaf else gym.spaces.Discrete(2),
            is_leaf=is_leaf,
            base_kwargs=dict(recurrent=True,
                             partial_obs=args.partial_obs,
                             gate_input=args.gate_input))

        actor_critic.to(device)
        # wandb.watch(actor_critic.base)

        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR(actor_critic,
                                   args.value_loss_coef,
                                   args.entropy_coef,
                                   args.pred_loss_coef,
                                   lr=args.lr,
                                   eps=args.eps,
                                   alpha=args.alpha,
                                   max_grad_norm=args.max_grad_norm)
        elif args.algo == 'ppo':
            agent = algo.PPO(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.pred_loss_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm)
        elif args.algo == 'acktr':
            agent = algo.A2C_ACKTR(actor_critic,
                                   args.value_loss_coef,
                                   args.entropy_coef,
                                   acktr=True)

        rollouts = RolloutStorage(args.num_steps,
                                  args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  actor_critic.recurrent_hidden_state_size,
                                  info_size=2 if is_leaf else 0)

        return actor_critic, agent, rollouts

    root = make_agent(is_leaf=False)
    leaf = make_agent(is_leaf=True)
    actor_critic, agent, rollouts = list(zip(root, leaf))

    obs = envs.reset()
    for r in rollouts:
        r.obs[0].copy_(obs)
        r.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    img_trajs = [[]]

    def act(i, step, **kwargs):
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = actor_critic[
                i].act(rollouts[i].obs[step],
                       rollouts[i].recurrent_hidden_states[step],
                       rollouts[i].masks[step], **kwargs)

            return value, action, action_log_prob, recurrent_hidden_states

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly for every agent
            # (a bare generator expression here would never be evaluated)
            for _agent in agent:
                utils.update_linear_schedule(
                    _agent.optimizer, j, num_updates,
                    _agent.optimizer.lr if args.algo == "acktr" else args.lr)

        diags = dict()
        for step in range(args.num_steps):
            # Sample actions
            value1, action1, action_log_prob1, recurrent_hidden_states1 = act(
                0, step)

            if np.random.random() > 0.9:
                print(action1.cpu().numpy().tolist())

            # TODO make sure the last index of actions is the right thing to do
            last_action = 1 + rollouts[1].actions[step - 1]

            # import pdb; pdb.set_trace()
            value2, action2, action_log_prob2, recurrent_hidden_states2 = act(
                1, step, info=torch.cat([action1, last_action], dim=1))

            action = action2
            recurrent_hidden_states = recurrent_hidden_states2

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(
                torch.cat([action1, action2], dim=-1))

            for k in infos[0]:
                if k != 'episode':
                    diags[k] = diags.get(k, [])
                    diags[k].append(np.array([info[k] for info in infos]))

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            bonus = action1.float() * 0
            bonus[action1 > 0] = args.bonus1
            bonus[action1 == 0] = args.bonus1 * 4
            int_rew = bonus * -1
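            # int_rew is the negated gate cost: bonus1 when the root picks a
            # non-zero gate action, 4 * bonus1 when it picks 0. It is added
            # only to the root policy's reward stream in rollouts[0].insert.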

            # import pdb; pdb.set_trace()

            # print(reward, int_rew)
            # reward

            rollouts[0].insert(obs,
                               recurrent_hidden_states,
                               action1,
                               action_log_prob1,
                               value1,
                               reward + int_rew,
                               masks,
                               bad_masks,
                               infos=None)

            rollouts[1].insert(obs,
                               recurrent_hidden_states,
                               action2,
                               action_log_prob2,
                               value2,
                               reward,
                               masks,
                               bad_masks,
                               infos=torch.cat([action1, last_action], dim=1))

            # if j % 100 == 0 or flip:
            #     flip = True

            #     if masks[0].item() < 1 or flip1:
            #         flip1 = True

            #         if len(img_trajs) > 5:
            #             mean = lambda x: sum(x)/(len(x)+1)
            #             flat = lambda x: [xxx for xx in x for xxx in xx]
            #             norm = lambda x: (x-x.min()) / (x-x.min()).max()

            #             i1 = [[ii[0] for ii in traj if ii[1] == 0] for traj in img_trajs]
            #             i2 = [[ii[0] for ii in traj if ii[1] == 1] for traj in img_trajs]

            #             i1, i2 = flat(i1), flat(i2)
            #             i1, i2 = mean(i1) * 255.0, mean(i2) * 255.0

            #             # i1 = mean([mean(l) for l in i1])*255.0
            #             # i2 = mean([mean(l) for l in i2])*255.0
            #             # import pdb; pdb.set_trace()

            #             img = np.concatenate([norm(ii) for ii in (i1, i2) if not isinstance(ii, float)], axis=1)

            #             wandb.log({"%s" % j: [wandb.Image(img, caption="capture - predict")]})

            #             img_trajs = [[]]
            #             flip, flip1 = False, False

            #         if masks[0].item() < 1 and len(img_trajs[-1]) > 0:
            #             img_trajs.append([])

            #         imgs = render_func('rgb_array')
            #         img_trajs[-1].append((imgs, action1[0].item()))

        def update(i, info=None):
            with torch.no_grad():
                next_value = actor_critic[i].get_value(
                    rollouts[i].obs[-1],
                    rollouts[i].recurrent_hidden_states[-1],
                    rollouts[i].masks[-1],
                    info=info).detach()

            rollouts[i].compute_returns(next_value, args.use_gae, args.gamma,
                                        args.gae_lambda,
                                        args.use_proper_time_limits)

            value_loss, action_loss, dist_entropy, pred_err = agent[i].update(
                rollouts[i], pred_loss=i != 0)

            rollouts[i].after_update()

            return value_loss, action_loss, dist_entropy, pred_err

        if j % 2 == 0 or True:
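            # The trailing "or True" disables the alternating schedule, so the
            # root policy (index 0) is updated on every iteration.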
            value_loss1, action_loss1, dist_entropy1, pred_err1 = update(0)
        if (j % 2) == 1 or True:
            _, action1, _, _ = actor_critic[0].act(
                rollouts[0].obs[-1], rollouts[0].recurrent_hidden_states[-1],
                rollouts[0].masks[-1])

            value_loss2, action_loss2, dist_entropy2, pred_err2 = update(
                1,
                info=torch.cat([action1, rollouts[1].actions[-1] + 1], dim=-1))

        # value_loss, action_loss, dist_entropy = list(zip((update(i) for i in range(len(agent)))))

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()

            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"
                .format(
                    j,
                    total_num_steps,
                    int(total_num_steps / (end - start)),
                    len(episode_rewards),
                    np.mean(episode_rewards),
                    np.median(episode_rewards),
                    np.min(episode_rewards),
                    np.max(episode_rewards),
                ))

            if not args.debug:
                wandb.log(
                    dict(
                        median_reward=np.median(episode_rewards),
                        mean_reward=np.mean(episode_rewards),
                        min_reward=np.min(episode_rewards),
                        max_reward=np.max(episode_rewards),
                    ))

            if j % 5 == 0:

                data = rollouts[0].obs[:-1]
                gate = rollouts[0].actions.byte().squeeze()

                capt = data[1 - gate]
                # x1, v1, ang1, angv1 = make_histograms(capt.numpy())
                pred = data[gate]
                # x2, v2, ang2, angv2 = make_histograms(pred.numpy())

                #########################
                # if j%50 == 0 and not args.debug:
                #     from sklearn.manifold import TSNE
                #     from matplotlib import pyplot as plt
                #     from matplotlib import cm

                #     comb = torch.cat([capt, pred], dim=0)

                #     xx = TSNE(n_components=2).fit_transform(comb)
                #     cc = np.array([0]*capt.shape[0] + [9]*pred.shape[0])

                #     plt.scatter(xx[:, 0], xx[:, 1],
                #         c=cc, cmap=plt.cm.get_cmap("jet", 10),
                #         alpha=0.9, s=50)
                #     wandb.log({
                #         "tsne %s" % j: plt,
                #     })
                # plt.colorbar(ticks=range(2))
                # plt.clim(-0.5, 9.5)
                #########################
                # import pdb; pdb.set_trace()

                if not args.debug:
                    # wandb_lunarlander(capt, pred)
                    logging.wandb_minigrid(capt, pred, gate, diags)

            if not args.debug:
                if (j % 2) == 0 or True:
                    wandb.log(
                        dict(
                            ent1=dist_entropy1,
                            val1=value_loss1,
                            aloss1=action_loss1,
                        ))
                    print("ent1 {:.4f}, val1 {:.4f}, loss1 {:.4f}\n".format(
                        dist_entropy1, value_loss1, action_loss1))
                if (j % 2) == 1 or True:
                    wandb.log(
                        dict(ent2=dist_entropy2,
                             val2=value_loss2,
                             aloss2=action_loss2,
                             prederr2=pred_err2))
                    print(
                        "ent2 {:.4f}, val2 {:.4f}, loss2 {:.4f}, prederr2 {:.4f}\n"
                        .format(dist_entropy2, value_loss2, action_loss2,
                                pred_err2))

                wandb.log(
                    dict(mean_gt=rollouts[0].actions.float().mean().item()))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
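
All of these examples call an evaluate(...) helper that is defined elsewhere in their projects. As a point of reference, the sketch below follows the conventional a2c-ppo-acktr evaluation loop: build fresh eval environments, copy the training VecNormalize statistics into them, and roll the policy deterministically for a handful of episodes. The seeding offset, episode count, the deterministic flag and the reuse of make_vec_envs/utils are assumptions, not the verbatim helpers from these projects.

import numpy as np
import torch

def evaluate(actor_critic, ob_rms, env_name, seed, num_processes,
             eval_log_dir, device, num_episodes=10):
    eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes,
                              None, eval_log_dir, device, True)
    vec_norm = utils.get_vec_normalize(eval_envs)
    if vec_norm is not None:
        vec_norm.eval()            # freeze the running statistics
        vec_norm.ob_rms = ob_rms   # reuse the stats gathered during training

    eval_episode_rewards = []
    obs = eval_envs.reset()
    recurrent_hidden_states = torch.zeros(
        num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    masks = torch.zeros(num_processes, 1, device=device)

    while len(eval_episode_rewards) < num_episodes:
        with torch.no_grad():
            _, action, _, recurrent_hidden_states = actor_critic.act(
                obs, recurrent_hidden_states, masks, deterministic=True)
        obs, _, done, infos = eval_envs.step(action)
        masks = torch.tensor([[0.0] if done_ else [1.0] for done_ in done],
                             dtype=torch.float32, device=device)
        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])

    eval_envs.close()
    print(" Evaluation using {} episodes: mean reward {:.5f}\n".format(
        len(eval_episode_rewards), np.mean(eval_episode_rewards)))
    return eval_episode_rewards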
Example #13
File: main.py  Project: voot-t/ril_co
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    print("Env obs space shape")
    print(envs.observation_space.shape)
    clip_action = False 
    a_high = None 
    a_low = None 

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True)

    method_type = "RL"
    method_name = args.algo.upper()
    hypers = "ec%0.5f" % args.entropy_coef
    exp_name = "%s-%s_s%d" % (method_name, hypers, args.seed)

    if args.gail:
        if args.is_atari:   # Do not have GAIL implementation for Atari games
            raise NotImplementedError   

        from a2c_ppo_acktr.algo import ail, ril, vild
        if args.il_algo.upper() == "AIL":
            discr = ail.AIL(
                envs.observation_space, envs.action_space, device, args)
        elif args.il_algo.upper() == "AIRL":
            discr = ail.AIRL(
                envs.observation_space, envs.action_space, device, args)
        elif args.il_algo.upper() == "FAIRL":
            discr = ail.FAIRL(
                envs.observation_space, envs.action_space, device, args)
        elif args.il_algo.upper() == "VILD":
            discr = vild.VILD(
                envs.observation_space, envs.action_space, device, args)
        elif args.il_algo.upper() == "RIL_CO":
            discr = ril.RIL_CO(
                envs.observation_space, envs.action_space, device, args)
        elif args.il_algo.upper() == "RIL":
            discr = ril.RIL(
                envs.observation_space, envs.action_space, device, args)

        # method name. 
        method_type = "IL"
        method_name = args.algo.upper() + "_" + args.il_algo.upper() 

        # demonstrations specification name. 
        if args.noise_type == "policy":     # noisy policy (non-expert policy snapshots)
            traj_name = "np%0.1f" % args.noise_prior 
        elif args.noise_type == "action":   # noisy action (add Gaussian noise to actions)
            traj_name = "na%0.1f" % args.noise_prior             
        if args.traj_deterministic:
            traj_name += "_det"
        else:
            traj_name += "_sto" 

        # hyper-parameters in file name. 
        hypers += "_gp%0.3f" % args.gp_lambda    
        
        if "AIL" in args.il_algo.upper():
            hypers += "_%s_sat%d" % (args.ail_loss_type, args.ail_saturate)
            
        if "VILD" in args.il_algo.upper():
            hypers += "_%s" % (args.ail_loss_type)
            if args.ail_saturate != 1:
                hypers += "_sat%d" % (args.ail_saturate)
                            
        if "RIL" in args.il_algo.upper() :
            hypers += "_%s_sat%d" % (args.ail_loss_type, args.ail_saturate) 

        if args.reward_std: hypers += "_rs"

        exp_name = "%s-%s-%s_s%d" % (traj_name, method_name, hypers, args.seed)

    # set directory path for result text files. 
    result_path = "./results_%s/%s/%s/%s-%s" % (method_type, method_name, args.env_name, args.env_name, exp_name)
    pathlib.Path("./results_%s/%s/%s" % (method_type, method_name, args.env_name)).mkdir(parents=True, exist_ok=True) 
    print("Running %s" % (colored(method_name, p_color)))
    print("%s result will be saved at %s" % (colored(method_name, p_color), colored(result_path, p_color)))

    # set directory path for policy model files. 
    model_name = "%s-%s" % (args.env_name, exp_name)
    save_path = os.path.join(args.save_dir, method_name, args.env_name)
    pathlib.Path(save_path).mkdir(parents=True, exist_ok=True) 

    if args.eval_interval is not None:     
        eval_envs = make_vec_envs(args.env_name, args.seed, 1,
                              None, eval_log_dir, device, True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    tt_g, tt_d = 0, 0
    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print("Update iterations: %d" % num_updates)  #  ~~ 15000
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = 1 #args.gail_epoch
            # if j < 10:
            #     gail_epoch = 100  # Warm up
            
            obfilt = utils.get_vec_normalize(envs)._obfilt    
            t0 = time.time()

            for _ in range(gail_epoch):
                discr.update(rollouts, obfilt)

            tt_d += time.time() - t0

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        t0 = time.time()
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        tt_g += time.time() - t0

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, model_name + ("T%d.pt" % total_num_steps)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards)))

        # if (args.eval_interval is not None and len(episode_rewards) > 1
        #         and j % args.eval_interval == 0):
        if args.eval_interval is not None and j % args.eval_interval == 0:
            ob_rms = None if args.is_atari else utils.get_vec_normalize(envs).ob_rms  
            eval_episode_rewards = evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                                        args.num_processes, eval_log_dir, device, eval_envs, clip_action, a_low, a_high)

            ## Save results as text file            
            result_text = t_format("Step %8d " % (total_num_steps), 0) 
            result_text += t_format("(g%2.1f+d%2.1f)s" % (tt_g, tt_d), 1) 

            c_reward_list = rollouts.rewards.to(device_cpu).detach().numpy()

            ## Statistics of discriminator rewards during training
            result_text += " | [D] " + t_format("min: %.2f" % np.amin(c_reward_list), 0.5) + t_format(" max: %.2f" % np.amax(c_reward_list), 0.5)

            ## Environment reward in test trajectories from learned policies. 
            result_text += " | [R_te] "
            result_text += t_format("min: %.2f" % np.min(eval_episode_rewards) , 1) + t_format("max: %.2f" % np.max(eval_episode_rewards), 1) \
                + t_format("Avg: %.2f (%.2f)" % (np.mean(eval_episode_rewards), np.std(eval_episode_rewards)/np.sqrt(len(eval_episode_rewards))), 2)
    
            if args.il_algo.upper() == "VILD":
                ## check estimated noise
                estimated_worker_noise = discr.worker_net.get_worker_cov().to(device_cpu).detach().numpy().squeeze()
                if envs.action_space.shape[0] > 1:
                    estimated_worker_noise = estimated_worker_noise.mean(axis=0)  #average across action dim
                result_text += " | w_noise: %s" % (np.array2string(estimated_worker_noise, formatter={'float_kind':lambda x: "%.5f" % x}).replace('\n', '') )
                    
            print(result_text)
            with open(result_path + ".txt", 'a') as f:
                print(result_text, file=f) 
            tt_g, tt_d = 0, 0
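Note: all of the training loops in these examples anneal the learning rate through utils.update_linear_schedule, whose definition is not shown. A minimal sketch consistent with how it is called here (optimizer, current update index, total number of updates, initial learning rate) is:

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    # Anneal the learning rate linearly from initial_lr down to 0 over training.
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr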
Example #14
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    elif args.algo == 'random':
        agent = RandomAgent(envs.action_space, args.env_name, device)
    else:
        raise NotImplementedError

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        if args.use_linear_lr_decay and args.algo != 'random':
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
                if args.algo == 'random':
                    value, recurrent_hidden_states = torch.rand(
                        [1]).to(device), torch.rand([1]).to(device)
                    action, action_log_prob, dist_entropy = agent.act()

            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            #if np.isnan(obs.cpu().numpy()).any():
            #    obs = np.nan_to_num(obs.cpu().numpy())
            #    obs = torch.from_numpy(obs).float().to(device)
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if args.algo != 'random':
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        else:
            value_loss, action_loss = 0.0, 0.0

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + "_{}_.pt".format(j)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "{} Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(args.env_name, j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            #if len(envs.observation_space.shape) == 1:
            #    ob_rms = utils.get_vec_normalize(envs).ob_rms
            #    print('ob_rms', ob_rms)
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
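Example #14 above depends on a RandomAgent baseline whose class is not included in the snippet: it is built from the action space, environment name and device, its act() returns an action batch with a log-probability and an entropy, and agent.update is never called in the 'random' branch. A hypothetical stand-in for a discrete action space, assuming a num_processes argument, could look like:

import math
import torch

class RandomAgent:
    # Hypothetical stand-in; the real class is defined elsewhere in the project.
    def __init__(self, action_space, env_name, device, num_processes=1):
        self.action_space = action_space
        self.env_name = env_name
        self.device = device
        self.num_processes = num_processes

    def act(self):
        # One uniform-random discrete action per process, plus its log-prob and entropy.
        n = self.action_space.n
        action = torch.randint(n, (self.num_processes, 1), device=self.device)
        action_log_prob = torch.full((self.num_processes, 1), -math.log(n), device=self.device)
        dist_entropy = torch.tensor(math.log(n), device=self.device)
        return action, action_log_prob, dist_entropy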
Example #15
File: main.py Project: azarafrooz/corgail
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # coinrun environments need to be treated differently.
    coinrun_envs = {
        'CoinRun': 'standard',
        'CoinRun-Platforms': 'platform',
        'Random-Mazes': 'maze'
    }

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         coin_run_level=args.num_levels,
                         difficulty=args.high_difficulty,
                         coin_run_seed=args.seed)
    if args.env_name in coinrun_envs.keys():
        observation_space_shape = (3, 64, 64)
        args.save_dir = args.save_dir + "/NUM_LEVELS_{}".format(
            args.num_levels)  # record the number of levels in the save path

    else:
        observation_space_shape = envs.observation_space.shape

    # trained model name
    if args.continue_ppo_training:
        actor_critic, _ = torch.load(os.path.join(args.check_point,
                                                  args.env_name + ".pt"),
                                     map_location=torch.device(device))
    elif args.cor_gail:
        embed_size = args.embed_size
        actor_critic = Policy(observation_space_shape,
                              envs.action_space,
                              hidden_size=args.hidden_size,
                              embed_size=embed_size,
                              base_kwargs={'recurrent': args.recurrent_policy})
        actor_critic.to(device)
        correlator = Correlator(observation_space_shape,
                                envs.action_space,
                                hidden_dim=args.hidden_size,
                                embed_dim=embed_size,
                                lr=args.lr,
                                device=device)

        correlator.to(device)
        embeds = torch.zeros(1, embed_size)
    else:
        embed_size = 0
        actor_critic = Policy(observation_space_shape,
                              envs.action_space,
                              hidden_size=args.hidden_size,
                              base_kwargs={'recurrent': args.recurrent_policy})
        actor_critic.to(device)
        embeds = None

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm,
                         use_clipped_value_loss=True,
                         ftrl_mode=args.cor_gail or args.no_regret_gail,
                         correlated_mode=args.cor_gail)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail or args.no_regret_gail or args.cor_gail:
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(
            file_name, num_trajectories=50,
            subsample_frequency=1)  # if subsample_frequency is changed, grad_pen may need adjustment
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)
        if args.gail:
            discr = gail.Discriminator(observation_space_shape,
                                       envs.action_space,
                                       device=device)
        if args.no_regret_gail or args.cor_gail:
            queue = deque(
                maxlen=args.queue_size
            )  # Strategy Queues: Each element of a queue is a dicr strategy
            agent_queue = deque(
                maxlen=args.queue_size
            )  # Strategy Queues: Each element of a queue is an agent strategy
            pruning_frequency = 1
        if args.no_regret_gail:
            discr = regret_gail.NoRegretDiscriminator(observation_space_shape,
                                                      envs.action_space,
                                                      device=device)
        if args.cor_gail:
            discr = cor_gail.CorDiscriminator(observation_space_shape,
                                              envs.action_space,
                                              hidden_size=args.hidden_size,
                                              embed_size=embed_size,
                                              device=device)
        discr.to(device)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              observation_space_shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              embed_size)

    obs = envs.reset()

    rollouts.obs[0].copy_(obs)
    if args.cor_gail:
        rollouts.embeds[0].copy_(embeds)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions # Roll-out
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step], rollouts.embeds[step])

            obs, reward, done, infos = envs.step(action.to('cpu'))
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)
            # Sample mediating/correlating actions # Correlated Roll-out
            if args.cor_gail:
                embeds, embeds_log_prob, mean = correlator.act(
                    rollouts.obs[step], rollouts.actions[step])
                rollouts.insert_embedding(embeds, embeds_log_prob)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1], rollouts.embeds[-1]).detach()

        if args.gail or args.no_regret_gail or args.cor_gail:
            if args.env_name not in {'CoinRun', 'Random-Mazes'}:
                if j >= 10:
                    envs.venv.eval()

            gail_epoch = args.gail_epoch
            if args.gail:
                if j < 10:
                    gail_epoch = 100  # Warm up

                # no need for gail epoch or warm up in the no-regret case and cor_gail.
            for _ in range(gail_epoch):
                if utils.get_vec_normalize(envs):
                    obfilt = utils.get_vec_normalize(envs)._obfilt
                else:
                    obfilt = None

                if args.gail:
                    discr.update(gail_train_loader, rollouts, obfilt)

                if args.no_regret_gail or args.cor_gail:
                    last_strategy = discr.update(gail_train_loader, rollouts,
                                                 queue, args.max_grad_norm,
                                                 obfilt, j)

            for step in range(args.num_steps):
                if args.gail:
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.actions[step], args.gamma,
                        rollouts.masks[step])
                if args.no_regret_gail:
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.actions[step], args.gamma,
                        rollouts.masks[step], queue)
                if args.cor_gail:
                    rollouts.rewards[
                        step], correlator_reward = discr.predict_reward(
                            rollouts.obs[step], rollouts.actions[step],
                            rollouts.embeds[step], args.gamma,
                            rollouts.masks[step], queue)

                    rollouts.correlated_reward[step] = correlator_reward

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if args.gail:
            value_loss, action_loss, dist_entropy = agent.update(rollouts, j)

        elif args.no_regret_gail or args.cor_gail:
            value_loss, action_loss, dist_entropy, agent_gains, agent_strategy = \
                agent.mixed_update(rollouts, agent_queue, j)

        if args.cor_gail:
            correlator.update(rollouts, agent_gains, args.max_grad_norm)

        if args.no_regret_gail or args.cor_gail:
            queue, _ = utils.queue_update(queue, pruning_frequency,
                                          args.queue_size, j, last_strategy)
            agent_queue, pruning_frequency = utils.queue_update(
                agent_queue, pruning_frequency, args.queue_size, j,
                agent_strategy)

        rollouts.after_update()
        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            if not args.cor_gail:
                torch.save([
                    actor_critic,
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], os.path.join(save_path, args.env_name + ".pt"))

            else:
                print("saving models in {}".format(
                    os.path.join(save_path, args.env_name)))
                torch.save(
                    correlator.state_dict(),
                    os.path.join(save_path, args.env_name + "correlator.pt"))
                torch.save([
                    actor_critic.state_dict(),
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], os.path.join(save_path, args.env_name + "actor.pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f},"
                " value loss/action loss {:.1f}/{}".format(
                    j, total_num_steps, int(total_num_steps / (end - start)),
                    len(episode_rewards), np.mean(episode_rewards),
                    np.median(episode_rewards), value_loss, action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
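In the cor_gail branch of example #15 the actor and correlator are saved as state dicts rather than whole modules. A minimal re-loading sketch, assuming the networks are first rebuilt with the same constructor arguments used for training, would be:

actor_state, ob_rms = torch.load(os.path.join(save_path, args.env_name + "actor.pt"),
                                 map_location=device)
actor_critic.load_state_dict(actor_state)
correlator.load_state_dict(
    torch.load(os.path.join(save_path, args.env_name + "correlator.pt"),
               map_location=device))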
Example #16
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    receipts = StorageReceipt()
    make_env = lambda tasks: MiniWoBGraphEnvironment(
        base_url=os.environ.get("BASE_URL", f"file://{MINIWOB_HTML}/"),
        levels=tasks,
        level_tracker=LevelTracker(tasks),
        wait_ms=500,
    )

    task = args.env_name
    if args.env_name == "PongNoFrameskip-v4":
        args.env_name = "clickbutton"
        task = "miniwob/click-button.html"
    if task == "levels":
        tasks = MINIWOB_CHALLENGES
    else:
        tasks = [[task]]
    print("Selected tasks:", tasks)
    NUM_ACTIONS = 1
    envs = make_vec_envs(
        [make_env(tasks[i % len(tasks)]) for i in range(args.num_processes)],
        receipts)

    if os.path.exists("./datadir/autoencoder.pt"):
        dom_autoencoder = torch.load("./datadir/autoencoder.pt")
        dom_encoder = dom_autoencoder.encoder
        for param in dom_encoder.parameters():
            param.requires_grad = False
    else:
        print("No dom encoder")
        dom_encoder = None
    actor_critic = Policy(
        envs.observation_space.shape,
        gym.spaces.Discrete(NUM_ACTIONS),  # envs.action_space,
        base=GNNBase,
        base_kwargs={
            "dom_encoder": dom_encoder,
            "recurrent": args.recurrent_policy
        },
    )
    actor_critic.dist = NodeObjective()
    actor_critic.to(device)

    if args.algo == "a2c":
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm,
        )
    elif args.algo == "ppo":
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
        )
    elif args.algo == "acktr":
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(envs.observation_space.shape[0], 100,
                                   device)

        rr = ReplayRepository("/code/miniwob-plusplus-demos/*turk/*")
        ds = rr.get_dataset()
        print("GAIL Replay Dataset", ds)
        gail_train_loader = torch_geometric.data.DataLoader(
            ds, batch_size=args.gail_batch_size, shuffle=True, drop_last=True)

    from tensorboardX import SummaryWriter
    import datetime

    ts_str = datetime.datetime.fromtimestamp(
        time.time()).strftime("%Y-%m-%d_%H-%M-%S")
    tensorboard_writer = SummaryWriter(
        log_dir=os.path.join("/tmp/log", ts_str))

    rollouts = ReceiptRolloutStorage(
        args.num_steps,
        args.num_processes,
        (1, ),  # envs.observation_space.shape,
        envs.action_space,
        actor_critic.recurrent_hidden_state_size,
        receipts,
    )

    # resume from last save
    if args.save_dir != "":
        save_path = os.path.join(args.save_dir, args.algo)
        try:
            os.makedirs(save_path)
        except OSError:
            pass

        model_path = os.path.join(save_path, args.env_name + ".pt")
        if False and os.path.exists(model_path):
            print("Loadng previous model:", model_path)
            actor_critic = torch.load(model_path)
            actor_critic.train()

    obs = envs.reset()
    rollouts.obs[0].copy_(torch.tensor(obs))
    rollouts.to(device)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print("Iterations:", num_updates, args.num_steps)
    for j in range(num_updates):
        episode_rewards = deque(maxlen=args.num_steps * args.num_processes)
        if j and last_action_time + 5 < time.time():
            # task likely timed out
            print("Reseting tasks")
            obs = envs.reset()
            rollouts.obs[0].copy_(torch.tensor(obs))
            rollouts.recurrent_hidden_states[0].copy_(
                torch.zeros_like(rollouts.recurrent_hidden_states[0]))
            rollouts.masks[0].copy_(torch.zeros_like(rollouts.masks[0]))

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer,
                j,
                num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr,
            )

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    receipts.redeem(rollouts.obs[step]),
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                )

            # Observe reward and next obs
            last_action_time = time.time()
            obs, reward, done, infos = envs.step(action)

            for e, i in enumerate(infos):
                if i.get("real_action") is not None:
                    action[e] = i["real_action"]
                if i.get("bad_transition"):
                    action[e] = torch.zeros_like(action[e])

            for info in infos:
                if "episode" in info.keys():
                    episode_rewards.append(info["episode"]["r"])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if "bad_transition" in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(
                torch.tensor(obs),
                recurrent_hidden_states,
                action,
                action_log_prob,
                value,
                torch.tensor(reward).unsqueeze(1),
                masks,
                bad_masks,
            )

        with torch.no_grad():
            next_value = actor_critic.get_value(
                receipts.redeem(rollouts.obs[-1]),
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1],
            ).detach()

        if args.gail:
            # if j >= 10:
            #    envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                obsfilt = lambda x, update: x  # utils.get_vec_normalize(envs)._obfilt
                gl = discr.update(gail_train_loader, rollouts, obsfilt)
            print("Gail loss:", gl)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    receipts.redeem(rollouts.obs[step]),
                    rollouts.actions[step],
                    args.gamma,
                    rollouts.masks[step],
                )

        rollouts.compute_returns(
            next_value,
            args.use_gae,
            args.gamma,
            args.gae_lambda,
            args.use_proper_time_limits,
        )

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        obs_shape = rollouts.obs.size()[2:]
        obs = rollouts.obs[:-1].view(-1, *obs_shape)
        obs = obs[torch.randint(0, obs.size(0), (1, 32))]

        rollouts.after_update()

        receipts.prune(rollouts.obs)

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            model_path = os.path.join(save_path, args.env_name + ".pt")
            torch.save(actor_critic, model_path)
            print("Saved model:", model_path)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(
                    j,
                    total_num_steps,
                    int(total_num_steps / (end - start)),
                    len(episode_rewards),
                    np.mean(episode_rewards),
                    np.median(episode_rewards),
                    np.min(episode_rewards),
                    np.max(episode_rewards),
                    dist_entropy,
                    value_loss,
                    action_loss,
                ))

            from pprint import pprint

            pprint(LevelTracker.global_scoreboard)

            # tensorboard_writer.add_histogram(
            #    "task_ranks", torch.tensor(predictor._difficulty_rank), total_num_steps
            # )
            tensorboard_writer.add_histogram("value", value, total_num_steps)
            tensorboard_writer.add_histogram("x", actor_critic.base.last_x,
                                             total_num_steps)
            tensorboard_writer.add_histogram("query",
                                             actor_critic.base.last_query,
                                             total_num_steps)
            tensorboard_writer.add_histogram("inputs_at",
                                             actor_critic.base.last_inputs_at,
                                             total_num_steps)

            tensorboard_writer.add_scalar("mean_reward",
                                          np.mean(episode_rewards),
                                          total_num_steps)
            tensorboard_writer.add_scalar("median_reward",
                                          np.median(episode_rewards),
                                          total_num_steps)
            tensorboard_writer.add_scalar("min_reward",
                                          np.min(episode_rewards),
                                          total_num_steps)
            tensorboard_writer.add_scalar("max_reward",
                                          np.max(episode_rewards),
                                          total_num_steps)
            tensorboard_writer.add_scalar("dist_entropy", dist_entropy,
                                          total_num_steps)
            tensorboard_writer.add_scalar("value_loss", value_loss,
                                          total_num_steps)
            tensorboard_writer.add_scalar("action_loss", action_loss,
                                          total_num_steps)

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(
                actor_critic,
                ob_rms,
                args.env_name,
                args.seed,
                args.num_processes,
                eval_log_dir,
                device,
            )
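A recurring detail in all of these loops is the pair of mask tensors built after each env.step: masks is 0.0 wherever the episode just ended (so recurrent state and bootstrapped returns are cut at episode boundaries), while bad_masks is 0.0 only for terminations flagged as 'bad_transition' (time-limit cutoffs), which compute_returns uses when use_proper_time_limits is set so that a timeout is not treated as a true environment termination. For a toy batch of four environments:

done = [False, True, True, False]
infos = [{}, {'episode': {'r': 1.0}}, {'bad_transition': True, 'episode': {'r': 0.5}}, {}]
masks = torch.FloatTensor([[0.0] if d else [1.0] for d in done])
bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info else [1.0] for info in infos])
# masks     -> [[1.], [0.], [0.], [1.]]  (episode boundaries)
# bad_masks -> [[1.], [1.], [0.], [1.]]  (only the time-limit termination is flagged)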
Example #17
    def __init__(self,
                 env_def,
                 processes=1,
                 dir='.',
                 version=0,
                 lr=2e-4,
                 architecture='base',
                 dropout=0,
                 reconstruct=None,
                 r_weight=.05):
        self.env_def = env_def
        self.num_processes = processes  #cpu processes
        self.lr = lr
        self.version = version
        self.save_dir = dir + '/trained_models/'

        #Setup
        pathlib.Path(self.save_dir).mkdir(parents=True, exist_ok=True)
        if (self.num_mini_batch > processes):
            self.num_mini_batch = processes

        self.writer = SummaryWriter()
        self.total_steps = 0

        #State
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        if not self.no_cuda and torch.cuda.is_available(
        ) and self.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        utils.cleanup_log_dir(self.log_dir)
        utils.cleanup_log_dir(self.eval_log_dir)

        torch.set_num_threads(1)

        self.level_path = None
        self.envs = None
        self.num_envs = -1
        self.set_envs(num_envs=1)

        if (version > 0):
            self.actor_critic = self.load(path, version)
        else:
            self.actor_critic = Policy(
                self.envs.observation_space.shape,
                self.envs.action_space,
                base_kwargs={
                    'recurrent': self.recurrent_policy,
                    'shapes': list(reversed(self.env_def.model_shape)),
                    'dropout': dropout
                },
                model=architecture)
        self.actor_critic.to(self.device)

        #Reconstruction
        self.reconstruct = reconstruct is not None
        if (self.reconstruct):
            #layers = self.envs.observation_space.shape[0]
            #shapes = list(self.env_def.model_shape)
            #self.r_model = Decoder(layers, shapes=shapes).to(self.device)
            reconstruct.to(self.device)
            self.r_model = lambda x: reconstruct.adapter(reconstruct(x))
            #self.r_model = lambda x: reconstruct.adapter(reconstruct(x)).clamp(min=1e-6).log()
            #self.r_loss = nn.L1Loss() #nn.NLLLoss() #nn.MSELoss()
            self.r_loss = lambda pred, true: -r_weight * (true * torch.log(
                pred.clamp(min=1e-7, max=1 - 1e-7))).sum(dim=1).mean()
            self.r_optimizer = reconstruct.optimizer  #optim.Adam(reconstruct.parameters(), lr = .0001)

        if self.algo == 'a2c':
            self.agent = A2C_ACKTR(self.actor_critic,
                                   self.value_loss_coef,
                                   self.entropy_coef,
                                   lr=self.lr,
                                   eps=self.eps,
                                   alpha=self.alpha,
                                   max_grad_norm=self.max_grad_norm)
        elif self.algo == 'ppo':
            self.agent = PPO(self.actor_critic,
                             self.clip_param,
                             self.ppo_epoch,
                             self.num_mini_batch,
                             self.value_loss_coef,
                             self.entropy_coef,
                             lr=self.lr,
                             eps=self.eps,
                             max_grad_norm=self.max_grad_norm,
                             use_clipped_value_loss=False)
        elif self.algo == 'acktr':
            self.agent = algo.A2C_ACKTR(self.actor_critic,
                                        self.value_loss_coef,
                                        self.entropy_coef,
                                        acktr=True)

        self.gail = False
        self.gail_experts_dir = './gail_experts'
        if self.gail:
            assert len(self.envs.observation_space.shape) == 1
            self.gail_discr = gail.Discriminator(
                self.envs.observation_space.shape[0] +
                self.envs.action_space.shape[0], 100, self.device)
            file_name = os.path.join(
                self.gail_experts_dir,
                "trajs_{}.pt".format(env_name.split('-')[0].lower()))

            self.gail_train_loader = torch.utils.data.DataLoader(
                gail.ExpertDataset(file_name,
                                   num_trajectories=4,
                                   subsample_frequency=20),
                batch_size=self.gail_batch_size,
                shuffle=True,
                drop_last=True)

        self.rollouts = RolloutStorage(
            self.num_steps, self.num_processes,
            self.envs.observation_space.shape, self.envs.action_space,
            self.actor_critic.recurrent_hidden_state_size)
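Example #17's __init__ reads a number of attributes that the snippet itself never sets (self.seed, self.algo, self.num_steps, self.num_mini_batch, the PPO/A2C hyper-parameters, self.log_dir, self.eval_log_dir, self.device, and the path and env_name names used in the load and GAIL branches); they are presumably class-level defaults or arguments handled elsewhere in the project. A purely hypothetical set of class defaults that would let the constructor run might look like:

class AgentDefaults:
    # Hypothetical values only; the real project defines these elsewhere.
    seed = 0
    algo = 'ppo'
    num_steps = 128
    num_mini_batch = 4
    clip_param = 0.2
    ppo_epoch = 4
    value_loss_coef = 0.5
    entropy_coef = 0.01
    eps = 1e-5
    alpha = 0.99
    max_grad_norm = 0.5
    gail_batch_size = 128
    recurrent_policy = False
    no_cuda = False
    cuda_deterministic = False
    log_dir = './logs/'
    eval_log_dir = './logs_eval/'
    device = 'cuda:0'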
Example #18
    def run(self):
        args = self.args
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        print("CUDA is available: ", torch.cuda.is_available())
        if args.cuda:
            print("CUDA enabled")
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        else:
            if args.cuda_deterministic:
                print("Warning CUDA is requested but is not available")
            else:
                print("CUDA disabled")

        log_dir = os.path.expanduser(args.log_dir)
        eval_log_dir = log_dir + "_eval"
        utils.cleanup_log_dir(log_dir)
        utils.cleanup_log_dir(eval_log_dir)
        print("get_num_thread", torch.get_num_threads())

        device = torch.device("cuda:0" if args.cuda else "cpu")

        envs = make_vec_envs(args.env_name, self.config_parameters, args.seed,
                             args.num_processes, args.gamma, args.log_dir,
                             device, False)

        actor_critic = create_IAM_model(envs, args, self.config_parameters)
        actor_critic.to(device)

        if args.algo == 'a2c':
            agent = algo.A2C_ACKTR(actor_critic,
                                   args.value_loss_coef,
                                   args.entropy_coef,
                                   lr=args.lr,
                                   eps=args.eps,
                                   alpha=args.alpha,
                                   max_grad_norm=args.max_grad_norm)
        # This algorithm should be used for the reproduction project.
        elif args.algo == 'ppo':
            agent = algo.PPO(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm)
        elif args.algo == 'acktr':
            agent = algo.A2C_ACKTR(actor_critic,
                                   args.value_loss_coef,
                                   args.entropy_coef,
                                   acktr=True)

        if args.gail:
            assert len(envs.observation_space.shape) == 1
            discr = gail.Discriminator(
                envs.observation_space.shape[0] + envs.action_space.shape[0],
                100, device)
            file_name = os.path.join(
                args.gail_experts_dir,
                "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

            expert_dataset = gail.ExpertDataset(file_name,
                                                num_trajectories=4,
                                                subsample_frequency=20)
            drop_last = len(expert_dataset) > args.gail_batch_size
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=drop_last)

        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)

        obs = envs.reset()
        rollouts.obs[0].copy_(obs)
        rollouts.to(device)
        # Always return the average of the last 100 steps. This means the average is sampled.
        episode_rewards = deque(maxlen=100)

        start = time.time()
        num_updates = int(
            args.num_env_steps) // args.num_steps // args.num_processes
        for j in range(num_updates):

            if args.use_linear_lr_decay:
                # decrease learning rate linearly
                utils.update_linear_schedule(
                    agent.optimizer, j, num_updates,
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)

            for step in range(args.num_steps):
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

                # Observe reward and next obs
                obs, reward, done, infos = envs.step(action)

                for info in infos:
                    if 'episode' in info.keys():
                        episode_rewards.append(info['episode']['r'])

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos])
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_prob, value, reward, masks,
                                bad_masks)

            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                    rollouts.masks[-1]).detach()

            if args.gail:
                if j >= 10:
                    envs.venv.eval()

                gail_epoch = args.gail_epoch
                if j < 10:
                    gail_epoch = 100  # Warm up
                for _ in range(gail_epoch):
                    discr.update(gail_train_loader, rollouts,
                                 utils.get_vec_normalize(envs)._obfilt)

                for step in range(args.num_steps):
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.actions[step], args.gamma,
                        rollouts.masks[step])

            rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                     args.gae_lambda,
                                     args.use_proper_time_limits)

            value_loss, action_loss, dist_entropy = agent.update(rollouts)

            rollouts.after_update()

            # save for every interval-th episode or for the last epoch
            if (j % args.save_interval == 0
                    or j == num_updates - 1) and args.save_dir != "":
                save_path = os.path.join(args.save_dir, args.algo)
                try:
                    os.makedirs(save_path)
                except OSError:
                    pass

                torch.save([
                    actor_critic,
                    getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
                ], os.path.join(save_path, self.model_file_name))

            if j % args.log_interval == 0 and len(episode_rewards) > 1:
                total_num_steps = (j + 1) * args.num_processes * args.num_steps
                end = time.time()
                elapsed_time = end - start
                data = [
                    j,  # Updates
                    total_num_steps,  # timesteps
                    int(total_num_steps / elapsed_time),  # FPS
                    len(episode_rewards),  # Only useful for print statement
                    np.mean(episode_rewards),  # mean of rewards
                    np.median(episode_rewards),  # median of rewards
                    np.min(episode_rewards),  # min rewards
                    np.max(episode_rewards),  # max rewards
                    dist_entropy,
                    value_loss,
                    action_loss,
                    elapsed_time
                ]
                output = ''.join([str(x) + ',' for x in data])
                self.data_saver.append(output)
                print(
                    f"Updates {data[0]}, num timesteps {data[1]}, FPS {data[2]}, elapsed time {int(data[11])} sec. Last {data[3]} training episodes: mean/median reward {data[4]:.2f}/{data[5]:.2f}, min/max reward {data[6]:.1f}/{data[7]:.1f}",
                    end="\r")

            if (args.eval_interval is not None and len(episode_rewards) > 1
                    and j % args.eval_interval == 0):
                obs_rms = utils.get_vec_normalize(envs).obs_rms
                evaluate(actor_critic, obs_rms, args.env_name, args.seed,
                         args.num_processes, eval_log_dir, device)
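Two small notes on example #18: the log row is built with ''.join([str(x) + ',' for x in data]), which leaves a trailing comma; if the file is later read as CSV, the standard csv module writes the same row more cleanly (an alternative sketch, not what the example does, with a hypothetical file name):

import csv

with open('train_log.csv', 'a', newline='') as f:
    csv.writer(f).writerow(data)

Also, this example reads getattr(utils.get_vec_normalize(envs), 'obs_rms', None) while the earlier examples use 'ob_rms'; the attribute name appears to differ between versions of the vector-normalization wrapper, so the two are not interchangeable across checkpoints.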
Example #19
                break

    wandb.config.update(args)
else:
    wandb = None

torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

log_dir = os.path.expanduser(args.log_dir)
eval_log_dir = log_dir + "_eval"
utils.cleanup_log_dir(log_dir)
utils.cleanup_log_dir(eval_log_dir)

torch.set_num_threads(1)
device = torch.device("cuda:0" if args.cuda else "cpu")

envs = make_vec_envs(
    env_name=args.env_name,
    seed=args.seed,
    num_processes=args.num_processes,
    gamma=args.gamma,
    log_dir=args.log_dir,
    device=device,
    allow_early_resets=False,
    custom_gym=args.custom_gym,
    navi=args.navi,

Example #20
def main(base=IAMBase, num_frame_stack=None):

    seed = 1
    env_name = "Warehouse-v0"
    num_processes = 32
    log_dir = './logs/'
    eval_interval = None
    log_interval = 10
    use_linear_lr_decay = False
    use_proper_time_limits = False
    save_dir = './trained_models/'
    use_cuda = True

    # PPO
    gamma = 0.99  # reward discount factor
    clip_param = 0.1  #0.2
    ppo_epoch = 3  #4
    num_mini_batch = 32
    value_loss_coef = 1  #0.5
    entropy_coef = 0.01
    lr = 2.5e-4  #7e-4
    eps = 1e-5
    max_grad_norm = float('inf')
    use_gae = True
    gae_lambda = 0.95
    num_steps = 8  #5

    # Store
    num_env_steps = 4e6
    save_interval = 100

    # IAM
    dset = [
        0, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
        66, 67, 68, 69, 70, 71, 72
    ]

    #gym.envs.register(env_name, entry_point="environments.warehouse.warehouse:Warehouse",
    #                    kwargs={'seed': seed, 'parameters': {"num_frames": 1}})

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    log_dir = os.path.expanduser(log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if use_cuda else "cpu")

    envs = make_vec_envs(env_name,
                         seed,
                         num_processes,
                         gamma,
                         log_dir,
                         device,
                         False,
                         num_frame_stack=num_frame_stack)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base=base,
                          base_kwargs=({
                              'dset': dset
                          } if base == IAMBase else {}))
    actor_critic.to(device)

    agent = algo.PPO(actor_critic,
                     clip_param,
                     ppo_epoch,
                     num_mini_batch,
                     value_loss_coef,
                     entropy_coef,
                     lr=lr,
                     eps=eps,
                     max_grad_norm=max_grad_norm)

    rollouts = RolloutStorage(num_steps, num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(num_env_steps) // num_steps // num_processes
    for j in range(num_updates):

        if use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(agent.optimizer, j, num_updates, lr)

        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda,
                                 use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % save_interval == 0 or j == num_updates - 1) and save_dir != "":
            save_path = os.path.join(save_dir, 'PPO')
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ], os.path.join(save_path, env_name + ".pt"))

        if j % log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * num_processes * num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (eval_interval is not None and len(episode_rewards) > 1
                and j % eval_interval == 0):
            obs_rms = utils.get_vec_normalize(envs).obs_rms
            evaluate(actor_critic, obs_rms, env_name, seed, num_processes,
                     eval_log_dir, device)
예제 #21
0
def main():
    args = get_args()

    # Record trajectories
    if args.record_trajectories:
        record_trajectories()
        return

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)

    # Optionally use a tanh output activation for CarRacing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent': args.recurrent_policy,
                              'env': args.env_name
                          },
                          activation=activation)
    actor_critic.to(device)
    # Load from previous model
    if args.load_model_name:
        state = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))[0]
        try:
            actor_critic.load_state_dict(state)
        except Exception:
            # the checkpoint stored the full model rather than a state dict
            actor_critic = state

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        if len(envs.observation_space.shape) == 1:
            discr = gail.Discriminator(
                envs.observation_space.shape[0] + envs.action_space.shape[0],
                100, device)
            file_name = os.path.join(
                args.gail_experts_dir,
                "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

            expert_dataset = gail.ExpertDataset(file_name,
                                                num_trajectories=3,
                                                subsample_frequency=1)
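            # Hold out one expert trajectory (start=3) for validation.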
            expert_dataset_test = gail.ExpertDataset(file_name,
                                                     num_trajectories=1,
                                                     start=3,
                                                     subsample_frequency=1)
            drop_last = len(expert_dataset) > args.gail_batch_size
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=drop_last)
            gail_test_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset_test,
                batch_size=args.gail_batch_size,
                shuffle=False,
                drop_last=False)
            print(len(expert_dataset), len(expert_dataset_test))
        else:
            # observation shape has 3 dimensions => it's an image
            assert len(envs.observation_space.shape) == 3
            discr = gail.CNNDiscriminator(envs.observation_space.shape,
                                          envs.action_space, 100, device)
            file_name = os.path.join(args.gail_experts_dir, 'expert_data.pkl')

            expert_dataset = gail.ExpertImageDataset(file_name, train=True)
            test_dataset = gail.ExpertImageDataset(file_name, train=False)
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=len(expert_dataset) > args.gail_batch_size,
            )
            gail_test_loader = torch.utils.data.DataLoader(
                dataset=test_dataset,
                batch_size=args.gail_batch_size,
                shuffle=False,
                drop_last=len(test_dataset) > args.gail_batch_size,
            )
            print('Dataloader size', len(gail_train_loader))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    # num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes
    # For the behaviour-cloning loop below, num_steps is reused as the number of training epochs.
    num_updates = args.num_steps
    print(num_updates)

    # count the number of times validation loss increases
    val_loss_increase = 0
    prev_val_action = np.inf
    best_val_loss = np.inf

    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                try:
                    envs.venv.eval()
                except Exception:
                    # the vectorized env may not expose eval(); ignore if so
                    pass

            gail_epoch = args.gail_epoch
            # if j < 10:
            #     gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                # Discriminator updates are disabled here; the policy is trained
                # by behaviour cloning on the expert data further below.
                # discr.update(gail_train_loader, rollouts, None)
                pass

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        #value_loss, action_loss, dist_entropy = agent.update(rollouts)
        value_loss = 0
        dist_entropy = 0
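        # Behaviour-cloning update: fit the policy to expert (state, action) pairs for one pass over the expert loader.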
        for data in gail_train_loader:
            expert_states, expert_actions = data
            expert_states = Variable(expert_states).to(device)
            expert_actions = Variable(expert_actions).to(device)
            loss = agent.update_bc(expert_states, expert_actions)
            action_loss = loss.data.cpu().numpy()
        print("Epoch: {}, Loss: {}".format(j, action_loss))

        with torch.no_grad():
            cnt = 0
            val_action_loss = 0
            for data in gail_test_loader:
                expert_states, expert_actions = data
                expert_states = Variable(expert_states).to(device)
                expert_actions = Variable(expert_actions).to(device)
                loss = agent.get_action_loss(expert_states, expert_actions)
                val_action_loss += loss.data.cpu().numpy()
                cnt += 1
            val_action_loss /= cnt
            print("Val Loss: {}".format(val_action_loss))

        #rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":

            if val_action_loss < best_val_loss:
                val_loss_increase = 0
                best_val_loss = val_action_loss
                save_path = os.path.join(args.save_dir, args.model_name)
                try:
                    os.makedirs(save_path)
                except OSError:
                    pass

                torch.save([
                    actor_critic.state_dict(),
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None),
                    getattr(utils.get_vec_normalize(envs), 'ret_rms', None)
                ],
                           os.path.join(
                               save_path,
                               args.model_name + "_{}.pt".format(args.seed)))
            elif val_action_loss > prev_val_action:
                val_loss_increase += 1
                if val_loss_increase == 10:
                    print("Val loss increasing too much, breaking here...")
                    break
            elif val_action_loss < prev_val_action:
                val_loss_increase = 0

            # Update prev val action
            prev_val_action = val_action_loss

        # log interval
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
예제 #22
0
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:" + str(args.cuda_id) if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    ########## file related
    filename = args.env_name + "_" + args.algo + "_n" + str(args.max_episodes)
    if args.attack:
        filename += "_" + args.type + "_" + args.aim
        filename += "_s" + str(args.stepsize) + "_m" + str(
            args.maxiter) + "_r" + str(args.radius) + "_f" + str(args.frac)
    if args.run >= 0:
        filename += "_run" + str(args.run)

    logger = get_log(args.logdir + filename + "_" + current_time)
    logger.info(args)

    rew_file = open(args.resdir + filename + ".txt", "w")

    if args.compute:
        radius_file = open(
            args.resdir + filename + "_radius" + "_s" + str(args.stepsize) +
            "_m" + str(args.maxiter) + "_th" + str(args.dist_thres) + ".txt",
            "w")
    if args.type == "targ" or args.type == "fgsm":
        targ_file = open(args.resdir + filename + "_targ.txt", "w")

    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
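    ########## attacker setup: construct the attacker matching args.type (wb, bb, rand, semirand, targ, fgsm)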

    if args.type == "wb":
        attack_net = WbAttacker(agent,
                                envs,
                                int(args.frac * num_updates),
                                num_updates,
                                args,
                                device=device)
    if args.type == "bb":
        attack_net = BbAttacker(agent,
                                envs,
                                int(args.frac * num_updates),
                                num_updates,
                                args,
                                device=device)
    elif args.type == "rand":
        attack_net = RandAttacker(envs,
                                  radius=args.radius,
                                  frac=args.frac,
                                  maxat=int(args.frac * num_updates),
                                  device=device)
    elif args.type == "semirand":
        attack_net = WbAttacker(agent,
                                envs,
                                int(args.frac * num_updates),
                                num_updates,
                                args,
                                device,
                                rand_select=True)
    elif args.type == "targ":
        if isinstance(envs.action_space, Discrete):
            action_dim = envs.action_space.n
            target_policy = action_dim - 1
        elif isinstance(envs.action_space, Box):
            action_dim = envs.action_space.shape[0]
            target_policy = torch.zeros(action_dim)
#            target_policy[-1] = 1
        print("target policy is", target_policy)
        attack_net = TargAttacker(agent,
                                  envs,
                                  int(args.frac * num_updates),
                                  num_updates,
                                  target_policy,
                                  args,
                                  device=device)
    elif args.type == "fgsm":
        if isinstance(envs.action_space, Discrete):
            action_dim = envs.action_space.n
            target_policy = action_dim - 1
        elif isinstance(envs.action_space, Box):
            action_dim = envs.action_space.shape[0]
            target_policy = torch.zeros(action_dim)

        def targ_policy(obs):
            return target_policy

        attack_net = FGSMAttacker(envs,
                                  agent,
                                  targ_policy,
                                  radius=args.radius,
                                  frac=args.frac,
                                  maxat=int(args.frac * num_updates),
                                  device=device)
#    if args.aim == "obs" or aim == "hybrid":
#        obs_space = gym.make(args.env_name).observation_space
#        attack_net.set_obs_range(obs_space.low, obs_space.high)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    episode = 0

    start = time.time()

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            if args.type == "fgsm":
                #                print("before", rollouts.obs[step])
                rollouts.obs[step] = attack_net.attack(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step]).clone()
#                print("after", rollouts.obs[step])
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            if args.type == "targ" or args.type == "fgsm":
                if isinstance(envs.action_space, Discrete):
                    num_target = (
                        action == target_policy).nonzero()[:, 0].size()[0]
                    targ_file.write(
                        str(num_target / args.num_processes) + "\n")
                    print("percentage of target:",
                          num_target / args.num_processes)
                elif isinstance(envs.action_space, Box):
                    target_action = target_policy.repeat(action.size()[0], 1)
                    targ_file.write(
                        str(
                            torch.norm(action - target_action).item() /
                            args.num_processes) + "\n")
                    print("percentage of target:",
                          torch.sum(action).item() / args.num_processes)
            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action.cpu())
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    #                    rew_file.write("episode: {}, total reward: {}\n".format(episode, info['episode']['r']))
                    episode += 1

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])
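        # Apply the chosen attack to the collected rollout (rewards, observations, actions, or a hybrid) before computing returns.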

        if args.attack and args.type != "fgsm":
            if args.aim == "reward":
                logger.info(rollouts.rewards.flatten())
                rollouts.rewards = attack_net.attack_r_general(
                    rollouts, next_value).clone().detach()
                logger.info("after attack")
                logger.info(rollouts.rewards.flatten())
            elif args.aim == "obs":
                origin = rollouts.obs.clone()
                rollouts.obs = attack_net.attack_s_general(
                    rollouts, next_value).clone().detach()
                logger.info(origin)
                logger.info("after")
                logger.info(rollouts.obs)
            elif args.aim == "action":
                origin = torch.flatten(rollouts.actions).clone()
                rollouts.actions = attack_net.attack_a_general(
                    rollouts, next_value).clone().detach()
                logger.info("attack value")
                logger.info(torch.flatten(rollouts.actions) - origin)
            elif args.aim == "hybrid":
                res_aim, attack = attack_net.attack_hybrid(
                    rollouts, next_value, args.radius_s, args.radius_a,
                    args.radius_r)
                print("attack ", res_aim)
                if res_aim == "obs":
                    origin = rollouts.obs.clone()
                    rollouts.obs = attack.clone().detach()
                    logger.info(origin)
                    logger.info("attack obs")
                    logger.info(rollouts.obs)
                elif res_aim == "action":
                    origin = torch.flatten(rollouts.actions).clone()
                    rollouts.actions = attack.clone().detach()
                    logger.info("attack action")
                    logger.info(torch.flatten(rollouts.actions) - origin)
                elif res_aim == "reward":
                    logger.info(rollouts.rewards.flatten())
                    rollouts.rewards = attack.clone().detach()
                    logger.info("attack reward")
                    logger.info(rollouts.rewards.flatten())
        if args.compute:
            stable_radius = attack_net.compute_radius(rollouts, next_value)
            print("stable radius:", stable_radius)
            radius_file.write("update: {}, radius: {}\n".format(
                j, np.round(stable_radius, decimals=3)))
        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if args.attack and args.type == "bb":
            attack_net.learning(rollouts)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) >= 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            rew_file.write("updates: {}, mean reward: {}\n".format(
                j, np.mean(episode_rewards)))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)


#        if episode > args.max_episodes:
#            print("reach episodes limit")
#            break

    if args.attack:
        logger.info("total attacks: {}\n".format(attack_net.attack_num))
        print("total attacks: {}\n".format(attack_net.attack_num))

    rew_file.close()
    if args.compute:
        radius_file.close()
    if args.type == "targ" or args.type == "fgsm":
        targ_file.close()
예제 #23
0
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir + args.env_name)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    log_dir2 = os.path.expanduser(args.log_dir2 + args.env_name2)
    eval_log_dir2 = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir2)
    utils.cleanup_log_dir(eval_log_dir2)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    import json
    file_path = "config.json"
    with open(file_path, 'r') as f:
        setup_json = json.load(f)
    env_conf = setup_json["Default"]
    for i in setup_json.keys():
        if i in args.env_name:
            env_conf = setup_json[i]


    # game 1
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, env_conf, False)
    # game 2
    envs2 = make_vec_envs(args.env_name2, args.seed, args.num_processes,
                          args.gamma, args.log_dir2, device, env_conf, False)

    save_model, ob_rms = torch.load('./trained_models/PongNoFrameskip-v4.pt')
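    # NOTE: the pretrained Pong checkpoint is loaded here but never applied; the load_state_dict calls below are commented out.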

    from a2c_ppo_acktr.cnn import CNNBase

    a = CNNBase(envs.observation_space.shape[0], recurrent=False)
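    # The same CNNBase instance is passed to both policies, so both games share the same base network weights.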

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        #(obs_shape[0], ** base_kwargs)
        base=a,
        #base_kwargs={'recurrent': args.recurrent_policy}
    )
    #actor_critic.load_state_dict(save_model.state_dict())
    actor_critic.to(device)

    actor_critic2 = Policy(envs2.observation_space.shape,
                           envs2.action_space,
                           base=a)
    #base_kwargs={'recurrent': args.recurrent_policy})
    #actor_critic2.load_state_dict(save_model.state_dict())
    actor_critic2.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               actor_critic2,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    rollouts2 = RolloutStorage(args.num_steps, args.num_processes,
                               envs2.observation_space.shape,
                               envs2.action_space,
                               actor_critic2.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    obs2 = envs2.reset()
    rollouts2.obs[0].copy_(obs2)
    rollouts2.to(device)

    episode_rewards = deque(maxlen=10)
    episode_rewards2 = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):
        # if args.use_linear_lr_decay:
        #     # decrease learning rate linearly
        #     utils.update_linear_schedule(
        #         agent.optimizer, j, num_updates,
        #         agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, _ = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
                value2, action2, action_log_prob2, recurrent_hidden_states2, _ = actor_critic2.act(
                    rollouts2.obs[step],
                    rollouts2.recurrent_hidden_states[step],
                    rollouts2.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            obs2, reward2, done2, infos2 = envs2.step(action2)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
            for info2 in infos2:
                if 'episode' in info2.keys():
                    episode_rewards2.append(info2['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            masks2 = torch.FloatTensor([[0.0] if done_ else [1.0]
                                        for done_ in done2])
            bad_masks2 = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info2.keys() else [1.0]
                 for info2 in infos2])

            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)
            rollouts2.insert(obs2, recurrent_hidden_states2, action2,
                             action_log_prob2, value2, reward2, masks2,
                             bad_masks2)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
            next_value2 = actor_critic2.get_value(
                rollouts2.obs[-1], rollouts2.recurrent_hidden_states[-1],
                rollouts2.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        rollouts2.compute_returns(next_value2, args.use_gae, args.gamma,
                                  args.gae_lambda, args.use_proper_time_limits)
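        # A single joint update optimizes both actor-critics on their respective rollouts.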

        value_loss, action_loss, dist_entropy, value_loss2, action_loss2, dist_entropy2 = agent.update(
            rollouts, rollouts2)

        rollouts.after_update()
        rollouts2.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))
            torch.save([
                actor_critic2,
                getattr(utils.get_vec_normalize(envs2), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name2 + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards2), np.mean(episode_rewards2),
                        np.median(episode_rewards2), np.min(episode_rewards2),
                        np.max(episode_rewards2), dist_entropy2, value_loss2,
                        action_loss2))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)

            ob_rms2 = utils.get_vec_normalize(envs2).ob_rms
            evaluate(actor_critic2, ob_rms2, args.env_name2, args.seed,
                     args.num_processes, eval_log_dir2, device)
예제 #24
0
def main():
    args = get_args()
    trace_size = args.trace_size
    toke = tokenizer(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    tobs = torch.zeros((args.num_processes, trace_size), dtype=torch.long)
    #print (tobs.dtype)
    rollouts.obs[0].copy_(obs)
    rollouts.tobs[0].copy_(tobs)

    rollouts.to(device)

    episode_rewards = deque(maxlen=args.num_processes)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    save_path = os.path.join(args.save_dir, args.algo)

    if args.load:
        # the checkpoint stores [actor_critic, ob_rms]; restore the policy weights
        saved_model = torch.load(
            os.path.join(save_path, args.env_name + ".pt"), map_location=device)[0]
        actor_critic.load_state_dict(saved_model.state_dict())
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):

            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.tobs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            tobs = []
            envs.render()
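            # Build a fixed-length token sequence (tobs) for each environment from its execution trace.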
            for info in infos:
                if 'episode' in info.keys():
                    #print ("episode ", info['episode'])
                    episode_rewards.append(info['episode']['r'])
                trace = info['trace'][0:trace_size]
                trace = [x[2] for x in trace]
                word_to_ix = toke.tokenize(trace)
                seq = prepare_sequence(trace, word_to_ix)
                if len(seq) < trace_size:
                    seq = torch.zeros((trace_size), dtype=torch.long)
                seq = seq[:trace_size]
                #print (seq.dtype)
                tobs.append(seq)
            tobs = torch.stack(tobs)
            #print (tobs)
            #print (tobs.size())
            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, tobs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.tobs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        #"""
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))
            pickle.dump(toke.word_to_ix, open("save.p", "wb"))

        #"""

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            writer.add_scalar(
                'mean reward',
                np.mean(episode_rewards),
                total_num_steps,
            )
            writer.add_scalar(
                'median reward',
                np.median(episode_rewards),
                total_num_steps,
            )
            writer.add_scalar(
                'max reward',
                np.max(episode_rewards),
                total_num_steps,
            )

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
예제 #25
0
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    #envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
    #                    args.gamma, args.log_dir, device, False)

    envs = make_parallel_env(args.env_name, args.num_processes, args.seed, True)

    '''
    actor_critic = Policy(
        envs.observation_space[0].shape,
        envs.action_space[0],
        agent_num=args.agent_num, 
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)
    '''
    actor_critic = []
    for i in range(args.agent_num):
        ac = Policy(
            envs.observation_space[0].shape,
            envs.action_space[0],
            agent_num=args.agent_num, 
            agent_i = i,
            base_kwargs={'recurrent': args.recurrent_policy})
        ac.to(device)
        actor_critic.append(ac)

    
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        '''
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
        '''
        agent = []
        for i in range(args.agent_num):
            agent.append(algo.PPO(
                actor_critic[i],
                i,
                args.clip_param,
                args.ppo_epoch,
                args.num_mini_batch,
                args.value_loss_coef,
                args.entropy_coef,
                lr=args.lr,
                eps=args.eps,
                max_grad_norm=args.max_grad_norm,
                model_dir = args.model_dir))
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir, "trajs_{}.pt".format(
                args.env_name.split('-')[0].lower()))
        
        expert_dataset = gail.ExpertDataset(
            file_name, num_trajectories=4, subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)
    '''   
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space[0].shape, envs.action_space[0],
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(torch.tensor(obs[:,0,:]))
    rollouts.to(device)
    '''

    rollouts = []
    for i in range(args.agent_num):
        rollout = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space[0].shape, envs.action_space[0],
                              actor_critic[i].recurrent_hidden_state_size,
                              args.agent_num, i)
        rollouts.append(rollout)

    obs = envs.reset()
    # pdb.set_trace()
    
    for i in range(args.agent_num):
        rollouts[i].share_obs[0].copy_(torch.tensor(obs.reshape(args.num_processes, -1)))
        rollouts[i].obs[0].copy_(torch.tensor(obs[:,i,:]))
        rollouts[i].to(device)
        
    episode_rewards = deque(maxlen=10)

    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print(num_updates)
    for j in range(num_updates):
        #pdb.set_trace()
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            for i in range(args.agent_num):
                utils.update_linear_schedule(agent[i].optimizer, j, num_updates, agent[i].optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            value_list, action_list, action_log_prob_list, recurrent_hidden_states_list = [], [], [], []
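            # Query each agent's policy on the shared observation and its own observation for this step.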
            with torch.no_grad():
                for i in range(args.agent_num):
                    #pdb.set_trace()
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic[i].act(
                        rollouts[i].share_obs[step],
                        rollouts[i].obs[step], rollouts[i].recurrent_hidden_states[step],
                        rollouts[i].masks[step])
                    # import pdb; pdb.set_trace()
                    value_list.append(value)
                    action_list.append(action)
                    action_log_prob_list.append(action_log_prob)
                    recurrent_hidden_states_list.append(recurrent_hidden_states)
            # Convert each agent's discrete action to a one-hot vector, then observe reward and next obs
            action = []
            for i in range(args.num_processes):
                one_env_action = []
                for k in range(args.agent_num):
                    one_hot_action = np.zeros(envs.action_space[0].n)
                    one_hot_action[action_list[k][i]] = 1
                    one_env_action.append(one_hot_action)
                action.append(one_env_action)
            #start = time.time()
            #pdb.set_trace()            
            obs, reward, done, infos = envs.step(action)
            # print(obs[0][0])
            # pdb.set_trace()
            #end = time.time()
            #print("step time: ", end-start)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            '''
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done[0]])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos[0]])
            '''
            masks = torch.ones(args.num_processes, 1)
            bad_masks = torch.ones(args.num_processes, 1)
            '''
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)
            '''
            #import pdb; pdb.set_trace()
            for i in range(args.agent_num):
                rollouts[i].insert(torch.tensor(obs.reshape(args.num_processes, -1)), torch.tensor(obs[:,i,:]), 
                            recurrent_hidden_states, action_list[i],
                            action_log_prob_list[i], value_list[i], torch.tensor(reward[:, i].reshape(-1,1)), masks, bad_masks)
        #import pdb; pdb.set_trace()
        with torch.no_grad():
            next_value_list = []
            for i in range(args.agent_num):
                next_value = actor_critic[i].get_value(
                    rollouts[i].share_obs[-1],
                    rollouts[i].obs[-1], rollouts[i].recurrent_hidden_states[-1],
                    rollouts[i].masks[-1]).detach()
                next_value_list.append(next_value)

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])
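        # Compute returns separately for each agent's rollout storage.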
        for i in range(args.agent_num):
            rollouts[i].compute_returns(next_value_list[i], args.use_gae, args.gamma,
                                    args.gae_lambda, args.use_proper_time_limits)

        #import pdb; pdb.set_trace()
        for i in range(args.agent_num):
            value_loss, action_loss, dist_entropy = agent[i].update(rollouts[i])
            if (i == 0):
                print("value loss: " + str(value_loss))
        # print(value_loss)
            # pdb.set_trace()

        #rollouts.after_update()
        obs = envs.reset()
        # pdb.set_trace()
        for i in range(args.agent_num):
            rollouts[i].share_obs[0].copy_(torch.tensor(obs.reshape(args.num_processes, -1)))
            rollouts[i].obs[0].copy_(torch.tensor(obs[:,i,:]))
            rollouts[i].to(device)

        # save for every interval-th episode or for the last epoch     
        #pdb.set_trace()   
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            if not os.path.exists(save_path + args.model_dir):
                os.makedirs(save_path + args.model_dir)
            for i in range(args.agent_num):
                torch.save([
                    actor_critic[i],
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], save_path + args.model_dir + '/agent_%i' % (i+1) + ".pt")
        '''
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
        '''
예제 #26
0
def learn(env, max_timesteps, timesteps_per_batch, clip_param):
    ppo_epoch = 5
    num_step = timesteps_per_batch
    save_interval = 100
    seed = 1000
    batch_size = 64

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    log_dir = os.path.expanduser('/tmp/gym/')
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda")

    envs = make_vec_envs(env, seed, 8, 0.95, log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': False})
    actor_critic.to(device)

    agent = algo.PPO(actor_critic,
                     clip_param,
                     ppo_epoch,
                     batch_size,
                     0.5,
                     0.01,
                     lr=0.00025,
                     eps=1e-05,
                     max_grad_norm=0.5)

    rollouts = RolloutStorage(num_step, 8, envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(torch.tensor(obs))
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(max_timesteps) // num_step // 8
    for j in range(num_updates):

        # decrease learning rate linearly
        utils.update_linear_schedule(agent.optimizer, j, num_updates, 0.00025)

        for step in range(num_step):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, True, 0.99, 0.95, False)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if j % save_interval == 0 or j == num_updates - 1:
            save_path = os.path.join("./trained_models/", 'ppo')
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, 'UniversalPolicy' + ".pt"))

        if len(episode_rewards) > 1:  # log every update
            total_num_steps = (j + 1) * 8 * num_step
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
        if (eval_interval is not None and len(episode_rewards) > 1
                and j % eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, env, seed, 8, eval_log_dir, device)
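# --- Illustrative usage (not part of the original example) ---
# A minimal, hypothetical driver for the learn() routine above; the env id and
# hyperparameter values below are placeholders, not taken from the source.
if __name__ == "__main__":
    learn(env="PongNoFrameskip-v4",   # any id accepted by make_vec_envs
          max_timesteps=10_000_000,
          timesteps_per_batch=128,
          clip_param=0.2)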
Example #27
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.env_name.startswith("lab_"):
        gym_name, flow_json = make_lab_env(args.env_name)

        args.env_name = gym_name

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir, "trajs_{}.pt".format(
                args.env_name.split('-')[0].lower()))
        
        expert_dataset = gail.ExpertDataset(
            file_name, num_trajectories=4, subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)
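            # The environment rewards collected in this rollout are then replaced
            # by the discriminator's learned reward before returns are computed.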

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: "
                "mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
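# --- Illustrative sketch (not part of the original example) ---
# rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda, ...) computes
# GAE-style returns. A simplified sketch of that recursion, ignoring the
# proper-time-limit / bad_masks handling, looks roughly like this:
#
#   value_preds[num_steps] = next_value
#   gae = 0
#   for step in reversed(range(num_steps)):
#       delta = (rewards[step]
#                + gamma * value_preds[step + 1] * masks[step + 1]
#                - value_preds[step])
#       gae = delta + gamma * gae_lambda * masks[step + 1] * gae
#       returns[step] = gae + value_preds[step]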
Example #28
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    args_dir, logs_dir, models_dir, samples_dir = get_all_save_paths(
        args, 'pretrain', combine_action=args.combine_action)
    eval_log_dir = logs_dir + "_eval"
    utils.cleanup_log_dir(logs_dir)
    utils.cleanup_log_dir(eval_log_dir)

    _, _, intrinsic_models_dir, _ = get_all_save_paths(args,
                                                       'learn_reward',
                                                       load_only=True)
    if args.load_iter != 'final':
        intrinsic_model_file_name = os.path.join(
            intrinsic_models_dir,
            args.env_name + '_{}.pt'.format(args.load_iter))
    else:
        intrinsic_model_file_name = os.path.join(
            intrinsic_models_dir, args.env_name + '.pt')
    intrinsic_arg_file_name = os.path.join(args_dir, 'command.txt')

    # save args to arg_file
    with open(intrinsic_arg_file_name, 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, logs_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)
    else:
        raise NotImplementedError

    if args.use_intrinsic:
        obs_shape = envs.observation_space.shape
        if len(obs_shape) == 3:
            action_dim = envs.action_space.n
        elif len(obs_shape) == 1:
            action_dim = envs.action_space.shape[0]

        if 'NoFrameskip' in args.env_name:
            file_name = os.path.join(
                args.experts_dir, "trajs_ppo_{}.pt".format(
                    args.env_name.split('-')[0].replace('NoFrameskip',
                                                        '').lower()))
        else:
            file_name = os.path.join(
                args.experts_dir,
                "trajs_ppo_{}.pt".format(args.env_name.split('-')[0].lower()))

        rff = RewardForwardFilter(args.gamma)
        intrinsic_rms = RunningMeanStd(shape=())

        if args.intrinsic_module == 'icm':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            inverse_model, forward_dynamics_model, encoder = torch.load(
                intrinsic_model_file_name)
            icm = IntrinsicCuriosityModule(
                envs, device, inverse_model, forward_dynamics_model,
                inverse_lr=args.intrinsic_lr, forward_lr=args.intrinsic_lr)

        if args.intrinsic_module == 'vae':
            print('Loading pretrained intrinsic module: %s' %
                  intrinsic_model_file_name)
            vae = torch.load(intrinsic_model_file_name)
            icm = GenerativeIntrinsicRewardModule(
                envs, device, vae, lr=args.intrinsic_lr)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            obs, reward, done, infos = envs.step(action)
            next_obs = obs

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, next_obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.use_intrinsic:
            for step in range(args.num_steps):
                state = rollouts.obs[step]
                action = rollouts.actions[step]
                next_state = rollouts.next_obs[step]
                if args.intrinsic_module == 'icm':
                    state = encoder(state)
                    next_state = encoder(next_state)
                with torch.no_grad():
                    rollouts.rewards[
                        step], pred_next_state = icm.calculate_intrinsic_reward(
                            state, action, next_state, args.lambda_true_action)
            if args.standardize == 'True':
                buf_rews = rollouts.rewards.cpu().numpy()
                intrinsic_rffs = np.array(
                    [rff.update(rew) for rew in buf_rews.T])
                rffs_mean, rffs_std, rffs_count = mpi_moments(
                    intrinsic_rffs.ravel())
                intrinsic_rms.update_from_moments(rffs_mean, rffs_std**2,
                                                  rffs_count)
                mean = intrinsic_rms.mean
                std = np.asarray(np.sqrt(intrinsic_rms.var))
                rollouts.rewards = rollouts.rewards / torch.from_numpy(std).to(
                    device)

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(models_dir, args.algo)
            policy_file_name = os.path.join(save_path, args.env_name + '.pt')

            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], policy_file_name)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "{} Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(args.env_name, j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
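# --- Illustrative sketch (not part of the original example) ---
# The intrinsic-reward standardization above relies on RewardForwardFilter and
# RunningMeanStd. A common RewardForwardFilter implementation (an assumption
# here, following RND-style code) keeps a per-env discounted running sum of
# rewards; the running std of those sums is what rescales rollouts.rewards:
#
#   class RewardForwardFilter:
#       def __init__(self, gamma):
#           self.rewems = None
#           self.gamma = gamma
#
#       def update(self, rews):
#           if self.rewems is None:
#               self.rewems = rews
#           else:
#               self.rewems = self.rewems * self.gamma + rews
#           return self.rewems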
Example #29
def main():

    args = get_args()
    writer = SummaryWriter(os.path.join('logs', args.save_name))
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(
        basic_env.BasicFlatDiscreteEnv,
        args.seed,
        args.num_processes,
        args.gamma,
        args.log_dir,
        device,
        False,
        task='lift',
        gripper_type='RobotiqThreeFingerDexterousGripper',
        robot='Panda',
        controller='JOINT_TORQUE' if args.vel else 'JOINT_POSITION',
        horizon=1000,
        reward_shaping=True)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base=Surreal,
        # base=OpenAI,
        # base=MLP_ATTN,
        base_kwargs={
            'recurrent':
            args.recurrent_policy,
            # 'dims': basic_env.BasicFlatEnv().modality_dims
            'config':
            dict(act='relu' if args.relu else 'tanh', rec=args.rec, fc=args.fc)
        })
    print(actor_critic)
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=100)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    best_reward = 0
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)
            writer.add_scalar('lr', agent.optimizer.param_groups[0]['lr'])

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        end = time.time()
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if len(episode_rewards) > 1:
            writer.add_scalar('loss/value', value_loss, total_num_steps)
            writer.add_scalar('loss/policy', action_loss, total_num_steps)
            writer.add_scalar('experiment/num_updates', j, total_num_steps)
            writer.add_scalar('experiment/FPS',
                              int(total_num_steps / (end - start)),
                              total_num_steps)
            writer.add_scalar('experiment/EPISODE MEAN',
                              np.mean(episode_rewards), total_num_steps)
            writer.add_scalar('experiment/EPISODE MEDIAN',
                              np.median(episode_rewards), total_num_steps)
            writer.add_scalar('experiment/EPISODE MIN',
                              np.min(episode_rewards), total_num_steps)
            writer.add_scalar('experiment/EPISODE MAX',
                              np.max(episode_rewards), total_num_steps)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if len(episode_rewards) > 1 and args.save_dir != "":
            rew = np.mean(episode_rewards)
            if rew > best_reward:
                best_reward = rew
                print('saved with best reward', rew)

                save_path = os.path.join(args.save_dir, args.algo)
                try:
                    os.makedirs(save_path)
                except OSError:
                    pass

                torch.save([
                    actor_critic,
                    getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
                ], os.path.join(save_path, args.save_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            obs_rms = utils.get_vec_normalize(envs).obs_rms
            evaluate(actor_critic, obs_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)

    writer.close()
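# --- Illustrative usage (not part of the original example) ---
# A hypothetical way to restore the best-reward checkpoint saved above for later
# evaluation; eval_envs and the surrounding setup are placeholders.
#   actor_critic, obs_rms = torch.load(os.path.join(save_path, args.save_name + ".pt"))
#   vec_norm = utils.get_vec_normalize(eval_envs)
#   if vec_norm is not None:
#       vec_norm.eval()
#       vec_norm.obs_rms = obs_rms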
Example #30
def main():
    args = get_args()

    if args.state:
        assert args.algo == "ppo"

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    if args.load_dir:

        actor_critic, obsrms = torch.load(args.load_dir)
        vec_norm = utils.get_vec_normalize(envs)
        if vec_norm is not None:
            vec_norm.train()
            vec_norm.ob_rms = obsrms
        actor_critic.base.deterministic = args.deterministic
        actor_critic.base.humanoid = args.env_name.startswith("SH")

    else:
        if args.state:
            actor_critic = StatePolicy(envs.observation_space.shape,
                                       envs.action_space,
                                       base_kwargs={
                                           'recurrent':
                                           args.recurrent_policy,
                                           'deterministic':
                                           args.deterministic,
                                           'hidden_size':
                                           args.code_size,
                                           'humanoid':
                                           args.env_name.startswith("SH")
                                       })
        else:
            actor_critic = Policy(
                envs.observation_space.shape,
                envs.action_space,
                base_kwargs={'recurrent': args.recurrent_policy},
            )
        actor_critic.to(device)

    if args.state:
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm,
                         beta=args.beta,
                         beta_end=args.beta_end,
                         state=True)
    elif args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    # RolloutStorage is a set of pre-allocated tensors used as a circular buffer
    # of transitions for one update.

    if args.state:

        rollouts = RolloutStorage(args.num_steps,
                                  args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  actor_critic.recurrent_hidden_state_size,
                                  code_size=args.code_size)

        mis = []

    else:
        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)

    # Populate the first observation in rollouts
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # episode_rewards keeps the returns of the most recent finished episodes
    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):
        agent.ratio = j / num_updates
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                if args.state:
                    value, action, action_log_prob, recurrent_hidden_states, eps, code = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])


                else:
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            if args.state:
                rollouts.insert(obs,
                                recurrent_hidden_states,
                                action,
                                action_log_prob,
                                value,
                                reward,
                                masks,
                                bad_masks,
                                eps=eps,
                                code=code)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_prob, value, reward, masks,
                                bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy, mi_loss = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + "_" + str(j) + ".pt"))

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()

            if args.state:
                print("DKL loss " + str(mi_loss))
                mis.append(mi_loss)

            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms

            if args.env_name.startswith("SH"):
                masses = [
                    0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.15, 1.25, 1.35, 1.45, 1.55
                ]
                damps = [
                    0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.15, 1.25, 1.35, 1.45, 1.55
                ]
                means = np.zeros((len(masses), len(damps)))
                stds = np.zeros((len(masses), len(damps)))
                for m_i in range(len(masses)):
                    for d_i in range(len(damps)):
                        m = masses[m_i]
                        d = damps[d_i]
                        u, s = evaluate(
                            actor_critic, ob_rms,
                            'OracleSHTest' + str(m) + "_" + str(d) + '-v0',
                            args.seed, args.num_processes, eval_log_dir,
                            device)
                        means[m_i, d_i] = u
                        stds[m_i, d_i] = s
                a, _ = args.load_dir.split(".")
                a = a.split("_")[-1]
                with open("sh_means_" + str(a) + ".npz", "wb") as f:
                    np.save(f, means)
                with open("sh_stds_" + str(a) + ".npz", "wb") as f:
                    np.save(f, stds)
            elif args.env_name.startswith("Oracle"):
                fs = [
                    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
                    18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
                    33, 34, 35, 36, 37, 38, 39, 40
                ]
                ls = [
                    0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
                    0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.00, 1.05, 1.10,
                    1.15, 1.20, 1.25, 1.30, 1.35, 1.40, 1.45, 1.50, 1.55, 1.60,
                    1.65, 1.70
                ]
                a, _ = args.load_dir.split(".")
                a = a.split("_")[-1]

                means = np.zeros((len(fs), len(ls)))
                stds = np.zeros((len(fs), len(ls)))
                for f_i in range(len(fs)):
                    for l_i in range(len(ls)):
                        f = fs[f_i]
                        l = ls[l_i]
                        u, s = evaluate(
                            actor_critic, ob_rms, 'OracleCartpoleTest' +
                            str(f) + "_" + str(l) + '-v0', args.seed,
                            args.num_processes, eval_log_dir, device)
                        means[f_i, l_i] = u
                        stds[f_i, l_i] = s
                with open("cp_means" + str(a) + ".npz", "wb") as f:
                    np.save(f, means)
                with open("cp_stds" + str(a) + ".npz", "wb") as f:
                    np.save(f, stds)
            elif args.env_name.startswith("HC"):
                ds = [
                    50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600,
                    650, 700, 750, 800, 850, 900, 950, 1000, 1050, 1100, 1150,
                    1200, 1250, 1300, 1350, 1400, 1450, 1500, 1550, 1600, 1650,
                    1700, 1750, 1800, 1850, 1900, 1950, 2000
                ]
                us = np.zeros(len(ds), dtype=float)
                ss = np.zeros(len(ds), dtype=float)
                for i in range(len(ds)):
                    d = ds[i]
                    u, s = evaluate(actor_critic, ob_rms,
                                    "OracleHalfCheetahTest_" + str(d) + "-v0",
                                    args.seed, args.num_processes,
                                    eval_log_dir, device)
                    us[i] = u
                    ss[i] = s
                a, _ = args.load_dir.split(".")
                a = a.split("_")[-1]
                with open("hc_means" + str(a) + ".npz", "wb") as f:
                    np.save(f, us)
                with open("hc_stds" + str(a) + ".npz", "wb") as f:
                    np.save(f, ss)
            assert False, "Evaluation Ended"
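# --- Note (not part of the original example) ---
# Informal summary of the mask convention used throughout these examples, based
# on how the scripts feed RolloutStorage:
#   masks[t + 1]     = 0.0 when the episode ended at step t, which stops value
#                      bootstrapping (and recurrent state) across the boundary;
#   bad_masks[t + 1] = 0.0 when the episode ended only because of a time limit
#                      (the 'bad_transition' info flag); with
#                      use_proper_time_limits=True the return at that step falls
#                      back to the value prediction instead of being treated as
#                      a true terminal state.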