Example #1
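A rollout-collection worker: it copies weights from a shared model, steps a single Gym environment for num_steps, tracks recent episode rewards, and returns the filled RolloutStorage (apparently one process's job in a parallel training setup).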
def job(rank, args, device, shared_model):
    episode_rewards = deque(maxlen=10)
    envs = gym.make(args.env_name)
    envs.seed(args.seed + rank)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)
    actor_critic.load_state_dict(shared_model.state_dict())

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    obs = envs.reset()
    obs = torch.from_numpy(obs)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    acc_r = 0
    done = [False]

    for step in range(args.num_steps):
        if done[0]:
            episode_rewards.append(acc_r)
            obs = envs.reset()
            obs = torch.from_numpy(obs)
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            acc_r = 0

        # Sample actions
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                rollouts.masks[step])

        # Observe reward and next obs
        target_action = action.numpy()[0]
        obs, reward, done, infos = envs.step(target_action)

        acc_r += reward
        obs = torch.from_numpy(obs).float().to(device)
        reward = torch.from_numpy(np.array([reward])).unsqueeze(dim=1).float()
        done = [done]

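        # masks: 0.0 where the episode just ended, 1.0 otherwise. bad_masks stays
        # at 1.0, i.e. no time-limit ("bad") terminations are flagged in this worker.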
        masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                   for done_ in done])
        bad_masks = torch.FloatTensor([[1.0] for done_ in done])
        rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob,
                        value, reward, masks, bad_masks)
    print(rank, np.mean(episode_rewards))
    return rollouts
Example #2
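A full training entry point driven by a config object: it builds vectorized environments, a Policy with many self-supervised-attention options, an A2C/ACKTR or PPO agent, and a RolloutStorage, then alternates rollout collection, policy updates, optional self-supervised-attention updates, TensorBoard logging, periodic evaluation, and checkpointing.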
def main():
    args = get_args()

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    if config.cuda and torch.cuda.is_available() and config.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    logger, final_output_dir, tb_log_dir = create_logger(config,
                                                         args.cfg,
                                                         'train',
                                                         seed=config.seed)

    eval_log_dir = final_output_dir + "_eval"

    utils.cleanup_log_dir(final_output_dir)
    utils.cleanup_log_dir(eval_log_dir)

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    writer = SummaryWriter(tb_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:" + config.GPUS if config.cuda else "cpu")

    width = height = 84
    envs = make_vec_envs(config.env_name,
                         config.seed,
                         config.num_processes,
                         config.gamma,
                         final_output_dir,
                         device,
                         False,
                         width=width,
                         height=height,
                         ram_wrapper=False)
    # create agent
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent':
                              config.recurrent_policy,
                              'hidden_size':
                              config.hidden_size,
                              'feat_from_selfsup_attention':
                              config.feat_from_selfsup_attention,
                              'feat_add_selfsup_attention':
                              config.feat_add_selfsup_attention,
                              'feat_mul_selfsup_attention_mask':
                              config.feat_mul_selfsup_attention_mask,
                              'selfsup_attention_num_keypoints':
                              config.SELFSUP_ATTENTION.NUM_KEYPOINTS,
                              'selfsup_attention_gauss_std':
                              config.SELFSUP_ATTENTION.GAUSS_STD,
                              'selfsup_attention_fix':
                              config.selfsup_attention_fix,
                              'selfsup_attention_fix_keypointer':
                              config.selfsup_attention_fix_keypointer,
                              'selfsup_attention_pretrain':
                              config.selfsup_attention_pretrain,
                              'selfsup_attention_keyp_maps_pool':
                              config.selfsup_attention_keyp_maps_pool,
                              'selfsup_attention_image_feat_only':
                              config.selfsup_attention_image_feat_only,
                              'selfsup_attention_feat_masked':
                              config.selfsup_attention_feat_masked,
                              'selfsup_attention_feat_masked_residual':
                              config.selfsup_attention_feat_masked_residual,
                              'selfsup_attention_feat_load_pretrained':
                              config.selfsup_attention_feat_load_pretrained,
                              'use_layer_norm':
                              config.use_layer_norm,
                              'selfsup_attention_keyp_cls_agnostic':
                              config.SELFSUP_ATTENTION.KEYPOINTER_CLS_AGNOSTIC,
                              'selfsup_attention_feat_use_ln':
                              config.SELFSUP_ATTENTION.USE_LAYER_NORM,
                              'selfsup_attention_use_instance_norm':
                              config.SELFSUP_ATTENTION.USE_INSTANCE_NORM,
                              'feat_mul_selfsup_attention_mask_residual':
                              config.feat_mul_selfsup_attention_mask_residual,
                              'bottom_up_form_objects':
                              config.bottom_up_form_objects,
                              'bottom_up_form_num_of_objects':
                              config.bottom_up_form_num_of_objects,
                              'gaussian_std':
                              config.gaussian_std,
                              'train_selfsup_attention':
                              config.train_selfsup_attention,
                              'block_selfsup_attention_grad':
                              config.block_selfsup_attention_grad,
                              'sep_bg_fg_feat':
                              config.sep_bg_fg_feat,
                              'mask_threshold':
                              config.mask_threshold,
                              'fix_feature':
                              config.fix_feature
                          })

    # init / load parameter
    if config.MODEL_FILE:
        logger.info('=> loading model from {}'.format(config.MODEL_FILE))
        state_dict = torch.load(config.MODEL_FILE)

        state_dict = OrderedDict(
            (_k, _v) for _k, _v in state_dict.items() if 'dist' not in _k)

        actor_critic.load_state_dict(state_dict, strict=False)
    elif config.RESUME:
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
        if os.path.exists(checkpoint_file):
            logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
            checkpoint = torch.load(checkpoint_file)
            actor_critic.load_state_dict(checkpoint['state_dict'])

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

    actor_critic.to(device)

    if config.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            lr=config.lr,
            eps=config.eps,
            alpha=config.alpha,
            max_grad_norm=config.max_grad_norm,
            train_selfsup_attention=config.train_selfsup_attention)
    elif config.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         config.clip_param,
                         config.ppo_epoch,
                         config.num_mini_batch,
                         config.value_loss_coef,
                         config.entropy_coef,
                         lr=config.lr,
                         eps=config.eps,
                         max_grad_norm=config.max_grad_norm)
    elif config.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic,
            config.value_loss_coef,
            config.entropy_coef,
            acktr=True,
            train_selfsup_attention=config.train_selfsup_attention,
            max_grad_norm=config.max_grad_norm)

    # rollouts: environment
    rollouts = RolloutStorage(
        config.num_steps,
        config.num_processes,
        envs.observation_space.shape,
        envs.action_space,
        actor_critic.recurrent_hidden_state_size,
        keep_buffer=config.train_selfsup_attention,
        buffer_size=config.train_selfsup_attention_buffer_size)

    if config.RESUME:
        if os.path.exists(checkpoint_file):
            agent.optimizer.load_state_dict(checkpoint['optimizer'])
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        config.num_env_steps) // config.num_steps // config.num_processes
    best_perf = 0.0
    best_model = False
    print('num updates', num_updates, 'num steps', config.num_steps)

    for j in range(num_updates):

        if config.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if config.algo == "acktr" else config.lr)

        for step in range(config.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            recurrent_hidden_states, meta = recurrent_hidden_states

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            objects_locs = []
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            if objects_locs:
                objects_locs = torch.FloatTensor(objects_locs)
                objects_locs = objects_locs * 2 - 1  # -1, 1
            else:
                objects_locs = None
            rollouts.insert(obs,
                            recurrent_hidden_states,
                            action,
                            action_log_prob,
                            value,
                            reward,
                            masks,
                            bad_masks,
                            objects_loc=objects_locs)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1],
            ).detach()

        rollouts.compute_returns(next_value, config.use_gae, config.gamma,
                                 config.gae_lambda,
                                 config.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if config.train_selfsup_attention and j > 15:
            for _iter in range(config.num_steps // 5):
                frame_x, frame_y = rollouts.generate_pair_image()
                selfsup_attention_loss, selfsup_attention_output, image_b_keypoints_maps = \
                    agent.update_selfsup_attention(frame_x, frame_y, config.SELFSUP_ATTENTION)

        if j % config.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            end = time.time()
            msg = 'Updates {}, num timesteps {}, FPS {} \n' \
                  'Last {} training episodes: mean/median reward {:.1f}/{:.1f} ' \
                  'min/max reward {:.1f}/{:.1f} ' \
                  'dist entropy {:.1f}, value loss {:.1f}, action loss {:.1f}\n'. \
                format(j, total_num_steps,
                       int(total_num_steps / (end - start)),
                       len(episode_rewards), np.mean(episode_rewards),
                       np.median(episode_rewards), np.min(episode_rewards),
                       np.max(episode_rewards), dist_entropy, value_loss,
                       action_loss)
            if config.train_selfsup_attention and j > 15:
                msg = msg + 'selfsup attention loss {:.5f}\n'.format(
                    selfsup_attention_loss)
            logger.info(msg)

        if (config.eval_interval is not None and len(episode_rewards) > 1
                and j % config.eval_interval == 0):
            total_num_steps = (j + 1) * config.num_processes * config.num_steps
            ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            eval_mean_score, eval_max_score, eval_scores = evaluate(
                actor_critic,
                ob_rms,
                config.env_name,
                config.seed,
                config.num_processes,
                eval_log_dir,
                device,
                width=width,
                height=height)
            perf_indicator = eval_mean_score
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            # record test scores
            with open(os.path.join(final_output_dir, 'test_scores'),
                      'a+') as f:
                out_s = "TEST: {}, {}, {}, {}\n".format(
                    str(total_num_steps), str(eval_mean_score),
                    str(eval_max_score),
                    [str(_eval_scores) for _eval_scores in eval_scores])
                print(out_s, end="", file=f)
                logger.info(out_s)
            writer.add_scalar('data/mean_score', eval_mean_score,
                              total_num_steps)
            writer.add_scalar('data/max_score', eval_max_score,
                              total_num_steps)

            writer.add_scalars('test', {'mean_score': eval_mean_score},
                               total_num_steps)

            # save for every interval-th episode or for the last epoch
            if (j % config.save_interval == 0
                    or j == num_updates - 1) and config.save_dir != "":

                logger.info(
                    "=> saving checkpoint to {}".format(final_output_dir))
                epoch = j / config.save_interval
                save_checkpoint(
                    {
                        'epoch':
                        epoch + 1,
                        'model':
                        get_model_name(config),
                        'state_dict':
                        actor_critic.state_dict(),
                        'perf':
                        perf_indicator,
                        'optimizer':
                        agent.optimizer.state_dict(),
                        'ob_rms':
                        getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                    }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(actor_critic.state_dict(), final_model_state_file)

    # export_scalars_to_json needs results from add scalars
    writer.export_scalars_to_json(os.path.join(tb_log_dir, 'all_scalars.json'))
    writer.close()
Example #3
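Rolls out a trained Policy (restored from a saved checkpoint) in a single environment for num_episodes and records observations, actions, rewards, and episode starts, presumably to build an expert dataset for imitation learning.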
def record_trajectories():
    args = get_args()
    print(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         1,
                         args.gamma,
                         log_dir,
                         device,
                         True,
                         training=False)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={
            'recurrent': args.recurrent_policy,
            'env': args.env_name
        },
        activation=activation,
    )
    actor_critic.to(device)

    # Load from previous model
    if args.load_model_name:
        loaddata = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))
        state = loaddata[0]
        try:
            obs_rms, ret_rms = loaddata[1:]
            # Feed it into the env
            envs.obs_rms = None
            envs.ret_rms = None
        except Exception:
            print("Couldn't load obs_rms")
            obs_rms = ret_rms = None
        try:
            actor_critic.load_state_dict(state)
        except Exception:
            actor_critic = state
    else:
        raise NotImplementedError

    # Record trajectories
    actions = []
    rewards = []
    observations = []
    episode_starts = []

    for eps in range(args.num_episodes):
        obs = envs.reset()
        # Init variables for storing
        episode_starts.append(True)
        reward = 0
        while True:
            # Take action
            act = actor_critic.act(obs, None, None, None)[1]
            next_state, rew, done, info = envs.step(act)
            #print(obs.shape, act.shape, rew.shape, done)
            reward += rew
            # Add the current observation and act
            observations.append(obs.data.cpu().numpy()[0])  # [C, H, W]
            actions.append(act.data.cpu().numpy()[0])  # [A]
            rewards.append(rew[0, 0].data.cpu().numpy())
            if done[0]:
                break
            episode_starts.append(False)
            obs = next_state + 0
        print("Total reward: {}".format(reward[0, 0].data.cpu().numpy()))

    # Save these values
    save_trajectories_images(observations, actions, rewards, episode_starts)
Example #4
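A behavioral-cloning-style variant of the training script: it loads expert data through the GAIL dataset utilities, still collects on-policy rollouts, but optimizes the policy with agent.update_bc on expert batches, monitors a validation action loss for early stopping, and saves the best model.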
def main():
    args = get_args()

    # Record trajectories
    if args.record_trajectories:
        record_trajectories()
        return

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent': args.recurrent_policy,
                              'env': args.env_name
                          },
                          activation=activation)
    actor_critic.to(device)
    # Load from previous model
    if args.load_model_name:
        state = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))[0]
        try:
            actor_critic.load_state_dict(state)
        except Exception:
            actor_critic = state

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        if len(envs.observation_space.shape) == 1:
            discr = gail.Discriminator(
                envs.observation_space.shape[0] + envs.action_space.shape[0],
                100, device)
            file_name = os.path.join(
                args.gail_experts_dir,
                "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

            expert_dataset = gail.ExpertDataset(file_name,
                                                num_trajectories=3,
                                                subsample_frequency=1)
            expert_dataset_test = gail.ExpertDataset(file_name,
                                                     num_trajectories=1,
                                                     start=3,
                                                     subsample_frequency=1)
            drop_last = len(expert_dataset) > args.gail_batch_size
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=drop_last)
            gail_test_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset_test,
                batch_size=args.gail_batch_size,
                shuffle=False,
                drop_last=False)
            print(len(expert_dataset), len(expert_dataset_test))
        else:
            # env observation shape is 3 => its an image
            assert len(envs.observation_space.shape) == 3
            discr = gail.CNNDiscriminator(envs.observation_space.shape,
                                          envs.action_space, 100, device)
            file_name = os.path.join(args.gail_experts_dir, 'expert_data.pkl')

            expert_dataset = gail.ExpertImageDataset(file_name, train=True)
            test_dataset = gail.ExpertImageDataset(file_name, train=False)
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=len(expert_dataset) > args.gail_batch_size,
            )
            gail_test_loader = torch.utils.data.DataLoader(
                dataset=test_dataset,
                batch_size=args.gail_batch_size,
                shuffle=False,
                drop_last=len(test_dataset) > args.gail_batch_size,
            )
            print('Dataloader size', len(gail_train_loader))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    #num_updates = int(
    #args.num_env_steps) // args.num_steps // args.num_processes
    num_updates = args.num_steps
    print(num_updates)

    # count the number of times validation loss increases
    val_loss_increase = 0
    prev_val_action = np.inf
    best_val_loss = np.inf

    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                try:
                    envs.venv.eval()
                except Exception:
                    pass

            gail_epoch = args.gail_epoch
            #if j < 10:
            #gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                #discr.update(gail_train_loader, rollouts,
                #None)
                pass

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        #value_loss, action_loss, dist_entropy = agent.update(rollouts)
        value_loss = 0
        dist_entropy = 0
        for data in gail_train_loader:
            expert_states, expert_actions = data
            expert_states = Variable(expert_states).to(device)
            expert_actions = Variable(expert_actions).to(device)
            loss = agent.update_bc(expert_states, expert_actions)
            action_loss = loss.data.cpu().numpy()
        print("Epoch: {}, Loss: {}".format(j, action_loss))

        with torch.no_grad():
            cnt = 0
            val_action_loss = 0
            for data in gail_test_loader:
                expert_states, expert_actions = data
                expert_states = Variable(expert_states).to(device)
                expert_actions = Variable(expert_actions).to(device)
                loss = agent.get_action_loss(expert_states, expert_actions)
                val_action_loss += loss.data.cpu().numpy()
                cnt += 1
            val_action_loss /= cnt
            print("Val Loss: {}".format(val_action_loss))

        #rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":

            if val_action_loss < best_val_loss:
                val_loss_increase = 0
                best_val_loss = val_action_loss
                save_path = os.path.join(args.save_dir, args.model_name)
                try:
                    os.makedirs(save_path)
                except OSError:
                    pass

                torch.save([
                    actor_critic.state_dict(),
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None),
                    getattr(utils.get_vec_normalize(envs), 'ret_rms', None)
                ],
                           os.path.join(
                               save_path,
                               args.model_name + "_{}.pt".format(args.seed)))
            elif val_action_loss > prev_val_action:
                val_loss_increase += 1
                if val_loss_increase == 10:
                    print("Val loss increasing too much, breaking here...")
                    break
            elif val_action_loss < prev_val_action:
                val_loss_increase = 0

            # Update prev val action
            prev_val_action = val_action_loss

        # log interval
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
Example #5
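A training loop in which the policy also receives a tokenized execution trace (tobs, built from the environment info) alongside the regular observation; otherwise it follows the usual collect-rollouts / compute-returns / update / save / log pattern.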
def main():
    args = get_args()
    trace_size = args.trace_size
    toke = tokenizer(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    # TensorBoard writer used by the reward logging below (assumes SummaryWriter
    # has been imported, e.g. from torch.utils.tensorboard)
    writer = SummaryWriter(log_dir=log_dir)

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    tobs = torch.zeros((args.num_processes, trace_size), dtype=torch.long)
    #print (tobs.dtype)
    rollouts.obs[0].copy_(obs)
    rollouts.tobs[0].copy_(tobs)

    rollouts.to(device)

    episode_rewards = deque(maxlen=args.num_processes)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    save_path = os.path.join(args.save_dir, args.algo)

    if args.load:
        # The checkpoint saved below stores [actor_critic, ob_rms]; copy the
        # weights into the existing model so the agent keeps the same parameters.
        saved_model, _ = torch.load(os.path.join(save_path,
                                                 args.env_name + ".pt"))
        actor_critic.load_state_dict(saved_model.state_dict())
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):

            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.tobs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            tobs = []
            envs.render()
            for info in infos:
                if 'episode' in info.keys():
                    #print ("episode ", info['episode'])
                    episode_rewards.append(info['episode']['r'])
                trace = info['trace'][0:trace_size]
                trace = [x[2] for x in trace]
                word_to_ix = toke.tokenize(trace)
                seq = prepare_sequence(trace, word_to_ix)
                if len(seq) < trace_size:
                    seq = torch.zeros((trace_size), dtype=torch.long)
                seq = seq[:trace_size]
                #print (seq.dtype)
                tobs.append(seq)
            tobs = torch.stack(tobs)
            #print (tobs)
            #print (tobs.size())
            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, tobs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.tobs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        #"""
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))
            pickle.dump(toke.word_to_ix, open("save.p", "wb"))

        #"""

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            writer.add_scalar(
                'mean reward',
                np.mean(episode_rewards),
                total_num_steps,
            )
            writer.add_scalar(
                'median reward',
                np.median(episode_rewards),
                total_num_steps,
            )
            writer.add_scalar(
                'max reward',
                np.max(episode_rewards),
                total_num_steps,
            )

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
Example #6
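A fragment of an evaluation script for a multi-drone grid environment: it builds one recurrent Policy per drone, loads per-drone A2C weights, and prints the initial grid state (the snippet is truncated at both ends).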
    step_size=1.0,
    n_anamolous=args.n_anamolous,
    n_drones=args.n_drones,
)

state_size = env.state_size
action_size = env.action_size
n_drones = env.n_drones

actor_critics = []
for i in range(n_drones):
    actor_critic = Policy(
        (state_size + 2,), (action_size,), base_kwargs={"recurrent": True}
    )
    actor_critic.load_state_dict(
        torch.load(args.model_path + f"A2C_drone_MARL_{i}.bin")
    )
    actor_critics.append(actor_critic)

recurrent_hidden_states = torch.zeros(1, actor_critic.recurrent_hidden_state_size)
masks = torch.zeros(1, 1)

obs = env.reset()
done = False

step = 0
print(f"Step: {step+1}")
print(f"Drone positions:{env.n_drones_pos}")
temp_st = obs.reshape((args.grid_size, args.grid_size))
print(temp_st, "\n")
Example #7
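Loads a pretrained PongNoFrameskip-v4 A2C checkpoint into a Policy built on a CNNBase, restores the observation-normalization statistics into the vectorized environment, and prepares a RolloutStorage and recurrent hidden states for inference.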
from a2c_ppo_acktr.model import Policy, CNNBase

a = CNNBase(env.observation_space.shape[0], recurrent=False)

actor_critic = Policy(
    env.observation_space.shape,
    env.action_space,
    # (obs_shape[0], ** base_kwargs)
    base=a,
    # base_kwargs={'recurrent': args.recurrent_policy}
)
actor_critic.to(device)
save_model, ob_rms = torch.load('./trained_models/a2c/PongNoFrameskip-v4.pt')
#save_model2, ob_rms2 = torch.load('./trained_models/a2c/DemonAttackNoFrameskip-v4.pt')

actor_critic.load_state_dict(save_model.state_dict())
actor_critic.to(device)
actor_critic.eval()

from a2c_ppo_acktr.storage import RolloutStorage

rollouts = RolloutStorage(5, 16, env.observation_space.shape, env.action_space,
                          actor_critic.recurrent_hidden_state_size)

vec_norm = get_vec_normalize(env)
if vec_norm is not None:
    vec_norm.eval()
    vec_norm.ob_rms = ob_rms

recurrent_hidden_states = torch.zeros(1,
                                      actor_critic.recurrent_hidden_state_size)
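The snippet stops before actually running the loaded policy; a minimal rollout sketch with the objects prepared above might look like the following (an assumption: `env` is a single-process vectorized environment, so a batch dimension of 1 matches the hidden states and masks).

# Sketch only: roll the loaded policy for one episode and accumulate reward.
obs = env.reset()
masks = torch.zeros(1, 1, device=device)
recurrent_hidden_states = torch.zeros(
    1, actor_critic.recurrent_hidden_state_size, device=device)
episode_reward = 0.0
while True:
    with torch.no_grad():
        value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
            obs, recurrent_hidden_states, masks)
    obs, reward, done, infos = env.step(action)
    episode_reward += reward[0].item()
    masks.fill_(0.0 if done[0] else 1.0)
    if done[0]:
        break
print('Episode reward:', episode_reward)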
Example #8
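A multi-task PPO setup with an attention-based MLP policy: a main agent trains on the training environments while a second "validation" agent updates only the attention parameters on separate validation environments; evaluation results on several held-out environments decide whether to keep or revert the latest weights.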
def main():
    args = get_args()
    import random
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    logdir = args.env_name + '_' + args.algo + '_num_arms_' + str(
        args.num_processes) + '_' + time.strftime("%d-%m-%Y_%H-%M-%S")
    if args.use_privacy:
        logdir = logdir + '_privacy'
    elif args.use_noisygrad:
        logdir = logdir + '_noisygrad'
    elif args.use_pcgrad:
        logdir = logdir + '_pcgrad'
    elif args.use_testgrad:
        logdir = logdir + '_testgrad'
    elif args.use_median_grad:
        logdir = logdir + '_mediangrad'
    logdir = os.path.join('runs', logdir)
    logdir = os.path.join(os.path.expanduser(args.log_dir), logdir)
    utils.cleanup_log_dir(logdir)

    # Ugly but simple logging
    log_dict = {
        'task_steps': args.task_steps,
        'grad_noise_ratio': args.grad_noise_ratio,
        'max_task_grad_norm': args.max_task_grad_norm,
        'use_noisygrad': args.use_noisygrad,
        'use_pcgrad': args.use_pcgrad,
        'use_testgrad': args.use_testgrad,
        'use_testgrad_median': args.use_testgrad_median,
        'testgrad_quantile': args.testgrad_quantile,
        'median_grad': args.use_median_grad,
        'use_meanvargrad': args.use_meanvargrad,
        'meanvar_beta': args.meanvar_beta,
        'no_special_grad_for_critic': args.no_special_grad_for_critic,
        'use_privacy': args.use_privacy,
        'seed': args.seed,
        'recurrent': args.recurrent_policy,
        'obs_recurrent': args.obs_recurrent,
        'cmd': ' '.join(sys.argv[1:])
    }
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        log_dict[eval_disp_name] = []

    summary_writer = SummaryWriter()
    summary_writer.add_hparams(
        {
            'task_steps': args.task_steps,
            'grad_noise_ratio': args.grad_noise_ratio,
            'max_task_grad_norm': args.max_task_grad_norm,
            'use_noisygrad': args.use_noisygrad,
            'use_pcgrad': args.use_pcgrad,
            'use_testgrad': args.use_testgrad,
            'use_testgrad_median': args.use_testgrad_median,
            'testgrad_quantile': args.testgrad_quantile,
            'median_grad': args.use_median_grad,
            'use_meanvargrad': args.use_meanvargrad,
            'meanvar_beta': args.meanvar_beta,
            'no_special_grad_for_critic': args.no_special_grad_for_critic,
            'use_privacy': args.use_privacy,
            'seed': args.seed,
            'recurrent': args.recurrent_policy,
            'obs_recurrent': args.obs_recurrent,
            'cmd': ' '.join(sys.argv[1:])
        }, {})

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    print('making envs...')
    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         steps=args.task_steps,
                         free_exploration=args.free_exploration,
                         recurrent=args.recurrent_policy,
                         obs_recurrent=args.obs_recurrent,
                         multi_task=True)

    val_envs = make_vec_envs(args.val_env_name,
                             args.seed,
                             args.num_processes,
                             args.gamma,
                             args.log_dir,
                             device,
                             False,
                             steps=args.task_steps,
                             free_exploration=args.free_exploration,
                             recurrent=args.recurrent_policy,
                             obs_recurrent=args.obs_recurrent,
                             multi_task=True)

    eval_envs_dic = {}
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        eval_envs_dic[eval_disp_name] = make_vec_envs(
            eval_env_name[0],
            args.seed,
            args.num_processes,
            None,
            logdir,
            device,
            True,
            steps=args.task_steps,
            recurrent=args.recurrent_policy,
            obs_recurrent=args.obs_recurrent,
            multi_task=True,
            free_exploration=args.free_exploration)
    prev_eval_r = {}
    print('done')
    if args.hard_attn:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base=MLPHardAttnBase,
                              base_kwargs={
                                  'recurrent':
                                  args.recurrent_policy or args.obs_recurrent
                              })
    else:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              base=MLPAttnBase,
                              base_kwargs={
                                  'recurrent':
                                  args.recurrent_policy or args.obs_recurrent
                              })
    actor_critic.to(device)

    if (args.continue_from_epoch > 0) and args.save_dir != "":
        save_path = os.path.join(args.save_dir, args.algo)
        actor_critic_, loaded_obs_rms_ = torch.load(
            os.path.join(
                save_path, args.env_name +
                "-epoch-{}.pt".format(args.continue_from_epoch)))
        actor_critic.load_state_dict(actor_critic_.state_dict())

    if args.algo != 'ppo':
        raise ValueError("only PPO is supported")
    agent = algo.PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.value_loss_coef,
                     args.entropy_coef,
                     lr=args.lr,
                     eps=args.eps,
                     num_tasks=args.num_processes,
                     attention_policy=False,
                     max_grad_norm=args.max_grad_norm,
                     weight_decay=args.weight_decay)
    val_agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.val_lr,
                         eps=args.eps,
                         num_tasks=args.num_processes,
                         attention_policy=True,
                         max_grad_norm=args.max_grad_norm,
                         weight_decay=args.weight_decay)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    val_rollouts = RolloutStorage(args.num_steps, args.num_processes,
                                  val_envs.observation_space.shape,
                                  val_envs.action_space,
                                  actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    val_obs = val_envs.reset()
    val_rollouts.obs[0].copy_(val_obs)
    val_rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    save_copy = True
    for j in range(args.continue_from_epoch,
                   args.continue_from_epoch + num_updates):

        # policy rollouts
        for step in range(args.num_steps):
            # Sample actions
            actor_critic.eval()
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            actor_critic.train()

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    for k, v in info['episode'].items():
                        summary_writer.add_scalar(
                            f'training/{k}', v,
                            j * args.num_processes * args.num_steps +
                            args.num_processes * step)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        actor_critic.eval()
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
        actor_critic.train()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if save_copy:
            prev_weights = copy.deepcopy(actor_critic.state_dict())
            prev_opt_state = copy.deepcopy(agent.optimizer.state_dict())
            prev_val_opt_state = copy.deepcopy(
                val_agent.optimizer.state_dict())
            save_copy = False

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # validation rollouts
        for val_iter in range(args.val_agent_steps):
            for step in range(args.num_steps):
                # Sample actions
                actor_critic.eval()
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        val_rollouts.obs[step],
                        val_rollouts.recurrent_hidden_states[step],
                        val_rollouts.masks[step])
                actor_critic.train()

                # Observe reward and next obs
                obs, reward, done, infos = val_envs.step(action)

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos])
                val_rollouts.insert(obs, recurrent_hidden_states, action,
                                    action_log_prob, value, reward, masks,
                                    bad_masks)

            actor_critic.eval()
            with torch.no_grad():
                next_value = actor_critic.get_value(
                    val_rollouts.obs[-1],
                    val_rollouts.recurrent_hidden_states[-1],
                    val_rollouts.masks[-1]).detach()
            actor_critic.train()

            val_rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                         args.gae_lambda,
                                         args.use_proper_time_limits)

            val_value_loss, val_action_loss, val_dist_entropy = val_agent.update(
                val_rollouts)
            val_rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ], os.path.join(save_path,
                            args.env_name + "-epoch-{}.pt".format(j)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
        revert = False
        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            actor_critic.eval()
            obs_rms = utils.get_vec_normalize(envs).obs_rms
            eval_r = {}
            printout = f'Seed {args.seed} Iter {j} '
            for eval_disp_name, eval_env_name in EVAL_ENVS.items():
                eval_r[eval_disp_name] = evaluate(
                    actor_critic,
                    obs_rms,
                    eval_envs_dic,
                    eval_disp_name,
                    args.seed,
                    args.num_processes,
                    eval_env_name[1],
                    logdir,
                    device,
                    steps=args.task_steps,
                    recurrent=args.recurrent_policy,
                    obs_recurrent=args.obs_recurrent,
                    multi_task=True,
                    free_exploration=args.free_exploration)
                if eval_disp_name in prev_eval_r:
                    diff = np.array(eval_r[eval_disp_name]) - np.array(
                        prev_eval_r[eval_disp_name])
                    if eval_disp_name == 'many_arms':
                        if np.sum(diff > 0) - np.sum(
                                diff < 0) < args.val_improvement_threshold:
                            print('no update')
                            revert = True

                summary_writer.add_scalar(f'eval/{eval_disp_name}',
                                          np.mean(eval_r[eval_disp_name]),
                                          (j + 1) * args.num_processes *
                                          args.num_steps)
                log_dict[eval_disp_name].append([
                    (j + 1) * args.num_processes * args.num_steps,
                    eval_r[eval_disp_name]
                ])
                printout += eval_disp_name + ' ' + str(
                    np.mean(eval_r[eval_disp_name])) + ' '
            # summary_writer.add_scalars('eval_combined', eval_r, (j+1) * args.num_processes * args.num_steps)
            if revert:
                actor_critic.load_state_dict(prev_weights)
                agent.optimizer.load_state_dict(prev_opt_state)
                val_agent.optimizer.load_state_dict(prev_val_opt_state)
            else:
                print(printout)
                prev_eval_r = eval_r.copy()
            save_copy = True
            actor_critic.train()

    save_obj(log_dict, os.path.join(logdir, 'log_dict.pkl'))
    envs.close()
    val_envs.close()
    for eval_disp_name, eval_env_name in EVAL_ENVS.items():
        eval_envs_dic[eval_disp_name].close()
Example #9
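Sets up online and target actor-critic policies (the target is either synced from the online network or shared outright when penetration_type is "constant"), selects A2C/PPO/ACKTR, and initializes a RolloutStorage over occupancy and sign observations; the training loop is cut off mid-way.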
def main():
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    ## Make environments
    envs = make_vec_envs(args, device)

    ## Setup Policy / network architecture
    if args.load_path != '':
        if os.path.isfile(os.path.join(args.load_path, "best_model.pt")):
            import_name = "best_model.pt"
        else:
            import_name = "model.pt"
        online_actor_critic = torch.load(
            os.path.join(args.load_path, import_name))
        target_actor_critic = torch.load(
            os.path.join(args.load_path, import_name))
        if args.cuda:
            target_actor_critic = target_actor_critic.cuda()
            online_actor_critic = online_actor_critic.cuda()
    else:
        online_actor_critic = Policy(occ_obs_shape, sign_obs_shape,
                                     args.state_rep, envs.action_space,
                                     args.recurrent_policy)
        online_actor_critic.to(device)
        target_actor_critic = Policy(occ_obs_shape, sign_obs_shape,
                                     args.state_rep, envs.action_space,
                                     args.recurrent_policy)
        target_actor_critic.to(device)
        target_actor_critic.load_state_dict(online_actor_critic.state_dict())

    if args.penetration_type == "constant":
        target_actor_critic = online_actor_critic

    ## Choose algorithm to use
    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(online_actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(online_actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(online_actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    ## Initiate memory buffer
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              occ_obs_shape, sign_obs_shape, envs.action_space,
                              target_actor_critic.recurrent_hidden_state_size)

    ## Start env with first observation
    occ_obs, sign_obs = envs.reset()
    if args.state_rep == 'full':
        rollouts.occ_obs[0].copy_(occ_obs)
    rollouts.sign_obs[0].copy_(sign_obs)
    rollouts.to(device)

    # Rolling window of recent rewards; choose a different queue length for a different averaging window
    episode_rewards = deque(maxlen=args.num_steps)
    reward_track = []
    best_eval_rewards = 0
    start = time.time()

    ## Loop over every policy update
    for j in range(num_updates):

        ## Setup parameter decays
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            if args.algo == "acktr":
                # use optimizer's learning rate since it's hard-coded in kfac.py
                update_linear_schedule(agent.optimizer, j, num_updates,
                                       agent.optimizer.lr)
            else:
                update_linear_schedule(agent.optimizer, j, num_updates,
                                       args.lr)

        if args.algo == 'ppo' and args.use_linear_clip_decay:
            agent.clip_param = args.clip_param * (1 - j / float(num_updates))

        ## Loop over num_steps environment updates to form trajectory
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                # Pass observation through network and get outputs
                value, action, action_log_prob, recurrent_hidden_states = target_actor_critic.act(
                    rollouts.occ_obs[step], rollouts.sign_obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Do action in environment and save reward
            occ_obs, sign_obs, reward, done, _ = envs.step(action)
            episode_rewards.append(reward.numpy())

            # Masks the processes which are done
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            # Insert step information in buffer
            rollouts.insert(occ_obs, sign_obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        ## Get state value of current env state
        with torch.no_grad():
            next_value = target_actor_critic.get_value(
                rollouts.occ_obs[-1], rollouts.sign_obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        ## Computes the num_step return (next_value approximates reward after num_step) see Supp Material of https://arxiv.org/pdf/1804.02717.pdf
        ## Can use Generalized Advantage Estimation
        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        # Update the policy with the rollouts
        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        # Clean the rollout by cycling the last elements to the first ones
        rollouts.after_update()

        if (args.penetration_type == "linear") and (j % update_period == 0):
            target_actor_critic.load_state_dict(
                online_actor_critic.state_dict())

        ## Save model
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":

            # A really ugly way to save a model to CPU
            save_model = target_actor_critic
            if args.cuda:
                save_model = copy.deepcopy(target_actor_critic).cpu()

            torch.save(save_model, os.path.join(save_path, "model.pt"))

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if args.vis:
            # Add the average reward of update to reward tracker
            reward_track.append(np.mean(episode_rewards))

        ## Log progress
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.3f}/{:.3f}, min/max reward {:.3f}/{:.3f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy))

        ## Evaluate the model on fresh environments (collect 3000 evaluation step rewards)
        percentage = 100 * total_num_steps // args.num_env_steps
        if (args.eval_interval is not None and percentage > 1
                and (j % args.eval_interval == 0 or j == num_updates - 1)):
            print("###### EVALUATING #######")
            args_eval = copy.deepcopy(args)
            args_eval.num_processes = 1
            eval_envs = make_vec_envs(args_eval, device, no_logging=True)

            eval_episode_rewards = []

            occ_obs, sign_obs = eval_envs.reset()
            eval_recurrent_hidden_states = torch.zeros(
                args_eval.num_processes,
                target_actor_critic.recurrent_hidden_state_size,
                device=device)
            eval_masks = torch.zeros(args_eval.num_processes, 1, device=device)

            while len(eval_episode_rewards) < 3000:
                with torch.no_grad():
                    _, action, _, eval_recurrent_hidden_states = target_actor_critic.act(
                        occ_obs,
                        sign_obs,
                        eval_recurrent_hidden_states,
                        eval_masks,
                        deterministic=True)

                # Observe reward and next obs
                occ_obs, sign_obs, reward, done, infos = eval_envs.step(action)

                eval_masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done])

                eval_episode_rewards.append(reward)

            eval_envs.close()

            if np.mean(eval_episode_rewards) > best_eval_rewards:
                best_eval_rewards = np.mean(eval_episode_rewards)
                save_model = target_actor_critic
                if args.cuda:
                    save_model = copy.deepcopy(target_actor_critic).cpu()
                torch.save(save_model, os.path.join(save_path,
                                                    'best_model.pt'))

    ## Visualize tracked rewards (averaged over num_steps) over time
    if args.vis:
        visualize(reward_track, args.algo, save_path)
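Example #9 keeps an online network for learning and a frozen target copy for acting, refreshing the target every `update_period` updates when the penetration type is linear. A hedged, minimal sketch of that sync pattern, with a toy MLP standing in for the Policy class and an assumed interval:

import copy

import torch.nn as nn

# Hypothetical stand-ins for the online/target policy pair above.
online = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 2))
target = copy.deepcopy(online)   # start from identical weights
update_period = 10               # assumed sync interval, not taken from the snippet

for j in range(100):
    # ... collect rollouts with `target`, update `online` with the agent ...
    if j % update_period == 0:
        target.load_state_dict(online.state_dict())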
Example #10
def job(rank, args, device, shared_model):
    torch.manual_seed(args.seed + rank)
    time.sleep(rank * 4)
    episode_rewards = deque(maxlen=10)
    envs = TorcsEnv(rank, None, throttle=True, gear_change=False)

    actor_critic = Policy(24,
                          3,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)
    actor_critic.load_state_dict(shared_model.state_dict())
    state_in = np.zeros(24)
    action_out = np.zeros(3)
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              state_in.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    obs = envs.reset()
    obs = torch.from_numpy(obs)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    acc_r = 0
    done = [False]

    for step in range(args.num_steps):
        if done[0]:
            episode_rewards.append(acc_r)
            obs = envs.reset()
            obs = torch.from_numpy(obs)
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            acc_r = 0

        # Sample actions
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                rollouts.masks[step])

        # Observe reward and next obs
        target_action = action.numpy()[0]
        obs, reward, done, infos = envs.step(target_action)

        acc_r += reward
        obs = torch.from_numpy(obs).float().to(device)
        reward = torch.from_numpy(np.array([reward])).unsqueeze(dim=1).float()
        done = [done]

        masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                   for done_ in done])
        bad_masks = torch.FloatTensor([[1.0] for done_ in done])
        rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob,
                        value, reward, masks, bad_masks)
    # print(rank, np.mean(episode_rewards),np.median(episode_rewards),
    # np.min(episode_rewards), np.max(episode_rewards))
    s = "{},{:.2f},{:.2f},{:.2f},{:.2f}\n".format(rank,
                                                  np.mean(episode_rewards),
                                                  np.median(episode_rewards),
                                                  np.min(episode_rewards),
                                                  np.max(episode_rewards))
    print(s)
    with open("logs/{}_mp.csv".format(args.env_name), 'a') as fl:
        fl.write(s)
    return rollouts
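The masks written into the rollout above are what stop value bootstrapping from leaking across episode boundaries when returns are computed. The sketch below shows the plain n-step return recursion under that convention; shapes and names are illustrative, not the repo's exact RolloutStorage API.

import torch

def compute_n_step_returns(rewards, masks, next_value, gamma=0.99):
    # rewards, masks: [T, N, 1]; next_value: [N, 1]. In this sketch masks[t]
    # is 0.0 where the episode terminated at step t, which cuts the bootstrap
    # term at episode boundaries (an illustrative convention, not the repo's
    # exact indexing).
    returns = torch.zeros_like(rewards)
    running = next_value
    for t in reversed(range(rewards.size(0))):
        running = rewards[t] + gamma * masks[t] * running
        returns[t] = running
    return returns

rewards = torch.ones(5, 1, 1)
masks = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0]).view(5, 1, 1)  # episode ends at step 2
returns = compute_n_step_returns(rewards, masks, next_value=torch.zeros(1, 1))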
Example #11
if record_video:
    env = wrappers.Monitor(env,
                           os.path.join(save_path, 'videos',
                                        record_video_filename),
                           force=True)

policy = Policy(env.observation_space.shape,
                env.action_space,
                base_kwargs={
                    'recurrent': False,
                    'layernorm': args.layernorm
                },
                obj_num=env.obj_dim)

state_dict = torch.load(state_dict_path)
policy.load_state_dict(state_dict)
policy = policy.to(device)
policy.double()
policy.eval()

ob_rms = None
if args.env_params is not None and os.path.exists(args.env_params):
    with open(args.env_params, 'rb') as fp:
        env_params = pickle.load(fp)
    ob_rms = env_params['ob_rms']

while True:
    obs = env.reset()
    obj = np.zeros(env.obj_dim)
    t = time()
    done = False
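Example #11 loads `ob_rms` from the pickled env_params but does not show how it is applied. A hedged sketch of the usual VecNormalize-style normalization at inference time follows; the `mean`/`var` attribute names are an assumption, not something shown in this snippet.

from types import SimpleNamespace

import numpy as np

def normalize_obs(obs, ob_rms, clip=10.0, eps=1e-8):
    # Apply the stored running mean/variance and clip the result; returns the
    # observation unchanged when no statistics are available.
    if ob_rms is None:
        return obs
    return np.clip((obs - ob_rms.mean) / np.sqrt(ob_rms.var + eps), -clip, clip)

# Toy statistics standing in for the pickled env_params['ob_rms'] object.
fake_rms = SimpleNamespace(mean=np.zeros(4), var=np.ones(4))
normalized = normalize_obs(np.ones(4), fake_rms)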
Example #12
    torch.cuda.manual_seed_all(args.seed)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # make envs
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    # load policy
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space.shape, # envs.action_space
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)
    actor_critic.load_state_dict(model_dict['policy'])

    # ob_rms = model_dict['ob_rms']
    ob_rms = None

    # run evaluation
    evaluate(actor_critic=actor_critic, 
        ob_rms=ob_rms, 
        env_name=args.env_name, 
        seed=args.seed, 
        num_processes=1, 
        eval_log_dir=args.log_dir,
        device=device, 
        eval_envs=envs)
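Example #12 assumes `model_dict` is a checkpoint dictionary with a 'policy' state_dict (and, commented out above, 'ob_rms'). A tiny round-trip sketch of that assumed layout, using a toy network in place of Policy and a hypothetical file name:

import torch
import torch.nn as nn

# Toy network standing in for Policy; the checkpoint layout is an assumption
# inferred from the model_dict['policy'] access above.
net = nn.Linear(4, 2)
torch.save({"policy": net.state_dict(), "ob_rms": None}, "checkpoint_example.pt")

model_dict = torch.load("checkpoint_example.pt")
net.load_state_dict(model_dict["policy"])
ob_rms = model_dict.get("ob_rms")   # running obs statistics, disabled above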

Example #13
def record_trajectories():
    args = get_args()
    print(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         1,
                         args.gamma,
                         log_dir,
                         device,
                         True,
                         training=False,
                         red=True)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={
            'recurrent': args.recurrent_policy,
            'env': args.env_name
        },
        activation=activation,
    )
    actor_critic.to(device)

    # Load from previous model
    if args.load_model_name and args.model_name != 'random':
        loaddata = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))
        state = loaddata[0]
        try:
            #print(envs.venv, loaddata[2])
            if not args.load_model_name.endswith('BC'):
                obs_rms, ret_rms = loaddata[1:]
                envs.venv.ob_rms = obs_rms
                envs.venv.ret_rms = None
                #envs.venv.ret_rms = ret_rms
                print(obs_rms)
                print("Loaded obs rms")
            else:
                envs.venv.ob_rms = None
                envs.venv.ret_rms = None
                print("Its BC, no normalization from env")
        except:
            print("Couldnt load obsrms")
            obs_rms = ret_rms = None
        try:
            actor_critic.load_state_dict(state)
        except:
            actor_critic = state
    elif args.load_model_name and args.model_name == 'random':
        print("Using random policy...")
        envs.venv.ret_rms = None
        envs.venv.ob_rms = None
    else:
        raise NotImplementedError

    # Record trajectories
    actions = []
    rewards = []
    observations = []
    episode_starts = []
    total_rewards = []

    last_length = 0

    for eps in range(args.num_episodes):
        obs = envs.reset()
        # Init variables for storing
        episode_starts.append(True)
        reward = 0
        while True:
            # Take action
            act = actor_critic.act(obs, None, None, None)[1]
            next_state, rew, done, info = envs.step(act)
            #print(obs.shape, act.shape, rew.shape, done)
            reward += rew
            # Add the current observation and act
            observations.append(obs.data.cpu().numpy()[0])  # [C, H, W]
            actions.append(act.data.cpu().numpy()[0])  # [A]
            rewards.append(rew[0, 0].data.cpu().numpy())
            if done[0]:
                break
            episode_starts.append(False)
            obs = next_state + 0
        print("Total reward: {}".format(reward[0, 0].data.cpu().numpy()))
        print("Total length: {}".format(len(observations) - last_length))
        last_length = len(observations)
        total_rewards.append(reward[0, 0].data.cpu().numpy())

    # Save these values
    '''
    if len(envs.observation_space.shape) == 3:
        save_trajectories_images(observations, actions, rewards, episode_starts)
    else:
        save_trajectories(observations, actions, rewards, episode_starts)
    '''
    pathname = args.load_model_name if args.model_name != 'random' else 'random'
    prefix = 'length' if args.savelength else ''
    with open(
            os.path.join(args.save_dir, args.load_model_name,
                         pathname + '_{}{}.txt'.format(prefix, args.seed)),
            'wt') as f:
        if not args.savelength:
            for rew in total_rewards:
                f.write('{}\n'.format(rew))
        else:
            avg_length = len(observations) / args.num_episodes
            f.write("{}\n".format(avg_length))
Example #14
def main():
    args = get_args()

    # Record trajectories
    if args.record_trajectories:
        record_trajectories()
        return

    print(args)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Append the model name
    log_dir = os.path.expanduser(args.log_dir)
    log_dir = os.path.join(log_dir, args.model_name, str(args.seed))

    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, log_dir, device, False)
    obs_shape = len(envs.observation_space.shape)

    # Take activation for carracing
    print("Loaded env...")
    activation = None
    if args.env_name == 'CarRacing-v0' and args.use_activation:
        activation = torch.tanh
    print(activation)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'recurrent': args.recurrent_policy,
                              'env': args.env_name
                          },
                          activation=activation)
    actor_critic.to(device)
    # Load from previous model
    if args.load_model_name:
        state = torch.load(
            os.path.join(args.save_dir, args.load_model_name,
                         args.load_model_name + '_{}.pt'.format(args.seed)))[0]
        try:
            actor_critic.load_state_dict(state)
        except:
            actor_critic = state

    # If BCGAIL, then decay factor and gamma should be float
    if args.bcgail:
        assert type(args.decay) == float
        assert type(args.gailgamma) == float
        if args.decay < 0:
            args.decay = 1
        elif args.decay > 1:
            args.decay = 0.5**(1. / args.decay)

        print('Gamma: {}, decay: {}'.format(args.gailgamma, args.decay))
        print('BCGAIL used')

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         gamma=args.gailgamma,
                         decay=args.decay,
                         act_space=envs.action_space,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        if len(envs.observation_space.shape) == 1:

            # Load RED here
            red = None
            if args.red:
                red = gail.RED(
                    envs.observation_space.shape[0] +
                    envs.action_space.shape[0], 100, device, args.redsigma,
                    args.rediters)

            discr = gail.Discriminator(envs.observation_space.shape[0] +
                                       envs.action_space.shape[0],
                                       100,
                                       device,
                                       red=red,
                                       sail=args.sail,
                                       learn=args.learn)
            file_name = os.path.join(
                args.gail_experts_dir,
                "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

            expert_dataset = gail.ExpertDataset(file_name,
                                                num_trajectories=args.num_traj,
                                                subsample_frequency=1)
            args.gail_batch_size = min(args.gail_batch_size,
                                       len(expert_dataset))
            drop_last = len(expert_dataset) > args.gail_batch_size
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=drop_last)
            print("Data loader size", len(expert_dataset))
        else:
            # env observation has 3 dimensions => it's an image
            assert len(envs.observation_space.shape) == 3
            discr = gail.CNNDiscriminator(envs.observation_space.shape,
                                          envs.action_space, 100, device)
            file_name = os.path.join(args.gail_experts_dir, 'expert_data.pkl')

            expert_dataset = gail.ExpertImageDataset(file_name,
                                                     act=envs.action_space)
            gail_train_loader = torch.utils.data.DataLoader(
                dataset=expert_dataset,
                batch_size=args.gail_batch_size,
                shuffle=True,
                drop_last=len(expert_dataset) > args.gail_batch_size,
            )
            print('Dataloader size', len(gail_train_loader))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print(num_updates)
    for j in range(num_updates):
        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                try:
                    envs.venv.eval()
                except:
                    pass

            gail_epoch = args.gail_epoch
            if j < 10 and obs_shape == 1:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                if obs_shape == 1:
                    discr.update(gail_train_loader, rollouts,
                                 utils.get_vec_normalize(envs)._obfilt)
                else:
                    discr.update(gail_train_loader, rollouts, None)
            if obs_shape == 3:
                obfilt = None
            else:
                obfilt = utils.get_vec_normalize(envs)._rev_obfilt

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step], obfilt
                )  # The reverse function is passed down for RED to receive unnormalized obs which it is trained on

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        if args.bcgail:
            if obs_shape == 3:
                obfilt = None
            else:
                obfilt = utils.get_vec_normalize(envs)._obfilt
            value_loss, action_loss, dist_entropy = agent.update(
                rollouts, gail_train_loader, obfilt)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.model_name)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic.state_dict(),
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None),
                getattr(utils.get_vec_normalize(envs), 'ret_rms', None)
            ],
                       os.path.join(
                           save_path,
                           args.model_name + "_{}.pt".format(args.seed)))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
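When `--gail` is set, the example above overwrites the environment rewards with `discr.predict_reward(...)`. The exact implementation is not shown here; a common GAIL formulation scores each state-action pair with the discriminator and uses log D - log(1 - D) (the discriminator logit) as the reward, as in this hedged sketch with a toy discriminator:

import torch
import torch.nn as nn

class TinyDiscriminator(nn.Module):
    # Illustrative stand-in; the real gail.Discriminator's predict_reward is
    # not shown in this snippet.
    def __init__(self, obs_dim, act_dim, hidden=100):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, 1))

    def predict_reward(self, obs, action):
        with torch.no_grad():
            logit = self.trunk(torch.cat([obs, action], dim=1))
            d = torch.sigmoid(logit)
            return torch.log(d + 1e-8) - torch.log(1 - d + 1e-8)

disc = TinyDiscriminator(obs_dim=4, act_dim=2)
r = disc.predict_reward(torch.randn(8, 4), torch.randn(8, 2))   # [8, 1] rewards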
Example #15
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)
    # import pdb; pdb.set_trace()
    save_path = os.path.join(args.save_dir, args.algo)
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)
    # import pdb; pdb.set_trace()
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})

    # transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

    # graph = hl.build_graph(actor_critic, torch.zeros([1, 1, 64, 64]), transforms=transforms)
    # graph.theme = hl.graph.THEMES['blue'].copy()
    # graph.save('rnn_hiddenlayer2', format='png')
    # print(args.re)
    # import pdb; pdb.set_trace()
    my_model_state_dict = actor_critic.state_dict()
    count = 0
    pretrained_weights = torch.load('net_main_4rh_v2_64.pth')
    # pretrained_weights = torch.load(os.path.join(save_path, args.env_name + "_ft.pt"))
    # pretrained_weights['']

    old_names = list(pretrained_weights.items())
    pretrained_weights_items = list(pretrained_weights.items())
    for key, value in my_model_state_dict.items():
        layer_name, weights = pretrained_weights_items[count]
        my_model_state_dict[key] = weights
        print(count)
        print(layer_name)
        count += 1
        if layer_name == 'enc_dense.bias':
            break
    # pretrained_weights = torch.load(os.path.join(save_path, args.env_name + "_random.pt"))[1]
    actor_critic.load_state_dict(my_model_state_dict)
    start_epoch = 0
    ka = 0

    # for param in actor_critic.parameters():
    #     ka += 1
    #     # import pdb; pdb.set_trace()
    #     param.requires_grad = False
    #     if ka == 14:
    #         break
    count = 0
    # import pdb; pdb.set_trace()

    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(file_name,
                                            num_trajectories=4,
                                            subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    rewards_mean = []
    rewards_median = []
    val_loss = []
    act_loss = []
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(start_epoch, num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                             utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                actor_critic.state_dict(),
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + "_finetune.pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            rewards_mean.append(np.mean(episode_rewards))
            rewards_median.append(np.median(episode_rewards))
            val_loss.append(value_loss)
            act_loss.append(action_loss)
            torch.save(
                rewards_mean,
                "./plot_data/" + args.env_name + "_avg_rewards_finetune.pt")
            torch.save(
                rewards_median,
                "./plot_data/" + args.env_name + "_median_rewards_finetune.pt")
            # torch.save(val_loss, "./plot_data/"+args.env_name+"_val_loss_enc_weights.pt")
            # torch.save(act_loss, "./plot_data/"+args.env_name+"_act_loss_enc_weights.pt")

            plt.plot(rewards_mean)
            # print(plt_points2)
            plt.savefig("./imgs/" + args.env_name + "avg_reward_finetune.png")
            # plt.show(block = False)

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
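Example #15 copies pretrained encoder weights into the new policy purely by position in the state dict, which silently breaks if the layer order ever changes. A hedged alternative sketch that transfers only entries with matching names and shapes (an explicit name map would still be needed when the two naming schemes differ, as they appear to here):

import torch.nn as nn

def copy_matching_weights(model, pretrained_state):
    # Copy only entries whose names and shapes match, then load the merged
    # state dict; non-matching entries keep their current values.
    own_state = model.state_dict()
    matched = {k: v for k, v in pretrained_state.items()
               if k in own_state and own_state[k].shape == v.shape}
    own_state.update(matched)
    model.load_state_dict(own_state)
    return sorted(matched)   # names that were actually transferred

# Toy usage with identically named layers (hypothetical stand-ins).
src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
transferred = copy_matching_weights(dst, src.state_dict())   # ['bias', 'weight']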