Example #1
def main():
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')
    seed = np.random.randint(0, 100)

    env = ObstacleTowerEnv('../ObstacleTower/obstacletower', worker_id=seed,
                               retro=True, config={'total-floors': 12}, greyscale=True, timeout_wait=300)
    env._flattener = ActionFlattener([2, 3, 2, 1])
    env._action_space = env._flattener.action_space
    input_size = env.observation_space.shape
    output_size = env.action_space.n

    env.close()

    is_render = False
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_path = os.path.join(args.save_dir, 'main.model')
    predictor_path = os.path.join(args.save_dir, 'main.pred')
    target_path = os.path.join(args.save_dir, 'main.target')

    writer = SummaryWriter()  # log_dir=args.log_dir

    discounted_reward = RewardForwardFilter(args.ext_gamma)

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    rnd = RNDModel(input_size, output_size)
    model = model.to(device)
    rnd = rnd.to(device)
    optimizer = optim.Adam(list(model.parameters()) + list(rnd.predictor.parameters()), lr=args.lr)
   
    if args.load_model:
        print('Loading model...')
        if args.cuda:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(torch.load(model_path, map_location='cpu'))


    works = []
    parent_conns = []
    child_conns = []
    for idx in range(args.num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(
            args.env_name,
            is_render,
            idx,
            child_conn,
            sticky_action=args.sticky_action,
            p=args.sticky_action_prob,
            max_episode_steps=args.max_episode_steps)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([args.num_worker, 4, 84, 84])

    sample_env_index = 0   # Sample Environment index to log
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    print("Load RMS =", args.load_rms)
    if args.load_rms:
        print("Loading RMS values for observation and reward normalization")
        with open('reward_rms.pkl', 'rb') as f:
            reward_rms = dill.load(f)
        with open('obs_rms.pkl', 'rb') as f:
            obs_rms = dill.load(f)
    else:
        reward_rms = RunningMeanStd()
        obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))

        # normalize observation
        print('Initializing observation normalization...')
        next_obs = []
        for step in range(args.num_step * args.pre_obs_norm_steps):
            actions = np.random.randint(0, output_size, size=(args.num_worker,))

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            for parent_conn in parent_conns:
                next_state, reward, done, realdone, log_reward = parent_conn.recv()
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            if len(next_obs) % (args.num_step * args.num_worker) == 0:
                next_obs = np.stack(next_obs)
                obs_rms.update(next_obs)
                next_obs = []
        with open('reward_rms.pkl', 'wb') as f:
            dill.dump(reward_rms, f)
        with open('obs_rms.pkl', 'wb') as f:
            dill.dump(obs_rms, f)

    print('Training...')
    while True:
        (total_state, total_reward, total_done, total_next_state, total_action,
         total_int_reward, total_next_obs, total_ext_values, total_int_values,
         total_action_probs) = [], [], [], [], [], [], [], [], [], []
        global_step += (args.num_worker * args.num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(args.num_step):
            actions, value_ext, value_int, action_probs = get_action(model, device, np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                next_state, reward, done, real_done, log_reward = parent_conn.recv()
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)
                real_dones.append(real_done)
                log_rewards.append(log_reward)
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            # total reward = intrinsic reward + extrinsic reward
            intrinsic_reward = compute_intrinsic_reward(rnd, device,
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_index]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_action_probs.append(action_probs)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_index]

            sample_step += 1
            if real_dones[sample_env_index]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = get_action(model, device, np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape([-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose([1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_action_probs = np.vstack(total_action_probs)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([discounted_reward.update(reward_per_step) for reward_per_step in total_int_reward.T])
        mean, std, count = np.mean(total_reward_per_env), np.std(total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std ** 2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi', np.sum(total_int_reward) / args.num_worker, sample_episode)
        writer.add_scalar('data/int_reward_per_rollout', np.sum(total_int_reward) / args.num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # logging Max action probability
        writer.add_scalar('data/max_prob', total_logging_action_probs.max(1).mean(), sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic reward targets and advantages
        ext_target, ext_adv = make_train_data(total_reward,
                                              total_done,
                                              total_ext_values,
                                              args.ext_gamma,
                                              args.gae_lambda,
                                              args.num_step,
                                              args.num_worker,
                                              args.use_gae)

        # calculate intrinsic reward targets and advantages
        # non-episodic: dones are treated as all-zero for the intrinsic stream
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values,
                                              args.int_gamma,
                                              args.gae_lambda,
                                              args.num_step,
                                              args.num_worker,
                                              args.use_gae)

        # add ext adv and int adv
        total_adv = int_adv * args.int_coef + ext_adv * args.ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        train_model(args, device, output_size, model, rnd, optimizer,
                        np.float32(total_state) / 255., ext_target, int_target, total_action,
                        total_adv, ((total_next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5),
                        total_action_probs)

        if global_step % (args.num_worker * args.num_step * args.save_interval) == 0:
            print('Now Global Step: {}'.format(global_step))
            torch.save(model.state_dict(), model_path)
            torch.save(rnd.predictor.state_dict(), predictor_path)
            torch.save(rnd.target.state_dict(), target_path)

            """
            checkpoint_list = np.array([int(re.search(r"\d+(\.\d+)?", x)[0]) for x in glob.glob(os.path.join('trained_models', args.env_name+'*.model'))])
            if len(checkpoint_list) == 0:
                last_checkpoint = -1
            else:
                last_checkpoint = checkpoint_list.max()
            next_checkpoint = last_checkpoint + 1
            print("Latest Checkpoint is #{}, saving checkpoint is #{}.".format(last_checkpoint, next_checkpoint))

            incre_model_path = os.path.join(args.save_dir, args.env_name + str(next_checkpoint) + '.model')
            incre_predictor_path = os.path.join(args.save_dir, args.env_name + str(next_checkpoint) + '.pred')
            incre_target_path = os.path.join(args.save_dir, args.env_name + str(next_checkpoint) + '.target')
            with open(incre_model_path, 'wb') as f:
                torch.save(model.state_dict(), f)
            with open(incre_predictor_path, 'wb') as f:
                torch.save(rnd.predictor.state_dict(), f)
            with open(incre_target_path, 'wb') as f:
                torch.save(rnd.target.state_dict(), f)
            """
            if args.terminate and (global_step > args.terminate_steps):
                with open('reward_rms.pkl', 'wb') as f:
                    dill.dump(reward_rms, f)
                with open('obs_rms.pkl', 'wb') as f:
                    dill.dump(obs_rms, f)
                break
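
# Example #1 relies on two helpers that are defined elsewhere in the repo:
# RunningMeanStd (observation / intrinsic-reward normalization) and
# RewardForwardFilter (a discounted running sum used to scale intrinsic rewards).
# The minimal sketch below shows what RND-style implementations of them usually
# look like; the class names carry a "Sketch" suffix because the repo's own
# versions may differ in details such as the initial epsilon count.
import numpy as np

class RunningMeanStdSketch:
    """Tracks mean and variance with a parallel-moments (Chan et al.) update."""
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        new_var = (m_a + m_b + delta ** 2 * self.count * batch_count / total) / total
        self.mean, self.var, self.count = new_mean, new_var, total

class RewardForwardFilterSketch:
    """Keeps a discounted running sum of intrinsic rewards per environment."""
    def __init__(self, gamma):
        self.rewems = None
        self.gamma = gamma

    def update(self, rews):
        self.rewems = rews if self.rewems is None else self.rewems * self.gamma + rews
        return self.rewems
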
Example #2
class ICMAgent(object):
    def __init__(
            self,
            input_size,
            output_size,
            num_env,
            num_step,
            gamma,
            lam=0.95,
            learning_rate=1e-4,
            ent_coef=0.01,
            clip_grad_norm=0.5,
            epoch=3,
            batch_size=128,
            ppo_eps=0.1,
            eta=0.01,
            use_gae=True,
            use_cuda=False,
            use_noisy_net=False):
        self.model = CnnActorCriticNetwork(input_size, output_size, use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.eta = eta
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.icm = ICMModel(input_size, output_size, use_cuda)
        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.icm.parameters()),
                                    lr=learning_rate)
        self.icm = self.icm.to(self.device)

        self.model = self.model.to(self.device)

    def get_action(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        policy, value = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return action, value.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, state, next_state, action):
        state = torch.FloatTensor(state).to(self.device)
        next_state = torch.FloatTensor(next_state).to(self.device)
        action = torch.LongTensor(action).to(self.device)

        action_onehot = torch.FloatTensor(
            len(action), self.output_size).to(
            self.device)
        action_onehot.zero_()
        action_onehot.scatter_(1, action.view(len(action), -1), 1)

        real_next_state_feature, pred_next_state_feature, pred_action = self.icm(
            [state, next_state, action_onehot])
        intrinsic_reward = self.eta * F.mse_loss(real_next_state_feature, pred_next_state_feature, reduction='none').mean(-1)
        return intrinsic_reward.data.cpu().numpy()

    def train_model(self, s_batch, next_s_batch, target_batch, y_batch, adv_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_s_batch = torch.FloatTensor(next_s_batch).to(self.device)
        target_batch = torch.FloatTensor(target_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        ce = nn.CrossEntropyLoss()
        forward_mse = nn.MSELoss()

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(1, 0, 2).contiguous().view(-1, self.output_size).to(
                self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

                # --------------------------------------------------------------------------------
                # for Curiosity-driven
                action_onehot = torch.FloatTensor(self.batch_size, self.output_size).to(self.device)
                action_onehot.zero_()
                action_onehot.scatter_(1, y_batch[sample_idx].view(-1, 1), 1)
                real_next_state_feature, pred_next_state_feature, pred_action = self.icm(
                    [s_batch[sample_idx], next_s_batch[sample_idx], action_onehot])

                inverse_loss = ce(
                    pred_action, y_batch[sample_idx])

                forward_loss = forward_mse(
                    pred_next_state_feature, real_next_state_feature.detach())
                # ---------------------------------------------------------------------------------

                policy, value = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(
                    ratio,
                    1.0 - self.ppo_eps,
                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(
                    value.sum(1), target_batch[sample_idx])

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = (actor_loss + 0.5 * critic_loss - 0.001 * entropy) + forward_loss + inverse_loss
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                self.optimizer.step()
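
# Every agent in these examples samples actions with the same
# random_choice_prob_index helper. The standalone check below (the variable
# names and probabilities are made up) demonstrates the cumulative-sum trick:
# draw one uniform number per row, then take the first index whose cumulative
# probability exceeds it, which samples each row's categorical distribution
# without a Python loop.
import numpy as np

def random_choice_prob_index(p, axis=1):
    r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
    return (p.cumsum(axis=axis) > r).argmax(axis=axis)

probs = np.array([[0.1, 0.9], [0.8, 0.2]])
counts = np.zeros_like(probs)
for _ in range(10000):
    actions = random_choice_prob_index(probs)
    counts[np.arange(len(probs)), actions] += 1
print(counts / counts.sum(axis=1, keepdims=True))  # roughly [[0.1, 0.9], [0.8, 0.2]]
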
Example #3
class GANAgent(object):
    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False,
                 hidden_dim=512):
        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.netG = NetG(z_dim=hidden_dim)  #(input_size, z_dim=hidden_dim)
        self.netD = NetD(z_dim=1)
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)

        self.optimizer_policy = optim.Adam(list(self.model.parameters()),
                                           lr=learning_rate)
        self.optimizer_G = optim.Adam(list(self.netG.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(list(self.netD.parameters()),
                                      lr=learning_rate,
                                      betas=(0.5, 0.999))

        self.netG = self.netG.to(self.device)
        self.netD = self.netD.to(self.device)

        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        # netG returns (reconstruction, latent, re-encoded latent); keep the reconstruction
        reconstructed = self.netG(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.detach().cpu().numpy()

    def get_action(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return action, value_ext.data.cpu().numpy().squeeze(
        ), value_int.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        #embedding = self.vae.representation(obs)
        #reconstructed_embedding = self.vae.representation(self.vae(obs)[0]) # why use index[0]
        reconstructed_img, embedding, reconstructed_embedding = self.netG(obs)

        intrinsic_reward = (embedding - reconstructed_embedding
                            ).pow(2).sum(1) / 2  # the pixel reconstruction loss is not used here

        return intrinsic_reward.detach().cpu().numpy()

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        #reconstruction_loss = nn.MSELoss(reduction='none')
        l_adv = nn.MSELoss(reduction='none')
        l_con = nn.L1Loss(reduction='none')
        l_enc = nn.MSELoss(reduction='none')
        l_bce = nn.BCELoss(reduction='none')

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        #recon_losses = np.array([])
        #kld_losses = np.array([])
        mean_err_g_adv_per_batch = np.array([])
        mean_err_g_con_per_batch = np.array([])
        mean_err_g_enc_per_batch = np.array([])
        mean_err_d_per_batch = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for generative curiosity (GAN loss)
                #gen_next_state, mu, logvar = self.vae(next_obs_batch[sample_idx])
                ############### netG forward ##############################################
                gen_next_state, latent_i, latent_o = self.netG(
                    next_obs_batch[sample_idx])

                ############### netD forward ##############################################
                pred_real, feature_real = self.netD(next_obs_batch[sample_idx])
                pred_fake, feature_fake = self.netD(gen_next_state)

                #d = len(gen_next_state.shape)
                #recon_loss = reconstruction_loss(gen_next_state, next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))
                ############### netG backward #############################################
                self.optimizer_G.zero_grad()

                err_g_adv_per_img = l_adv(
                    self.netD(next_obs_batch[sample_idx])[1],
                    self.netD(gen_next_state)[1]).mean(
                        axis=list(range(1, len(feature_real.shape))))
                err_g_con_per_img = l_con(
                    next_obs_batch[sample_idx], gen_next_state).mean(
                        axis=list(range(1, len(gen_next_state.shape))))
                err_g_enc_per_img = l_enc(latent_i, latent_o).mean(-1)

                #kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                img_num = len(err_g_con_per_img)
                mask = torch.rand(img_num).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                mean_err_g_adv = (err_g_adv_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                mean_err_g_con = (err_g_con_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                mean_err_g_enc = (err_g_enc_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                # hyperparameter weights:
                w_adv = 1
                w_con = 50
                w_enc = 1

                mean_err_g = mean_err_g_adv * w_adv +\
                        mean_err_g_con * w_con +\
                        mean_err_g_enc * w_enc
                mean_err_g.backward(retain_graph=True)

                self.optimizer_G.step()

                mean_err_g_adv_per_batch = np.append(
                    mean_err_g_adv_per_batch,
                    mean_err_g_adv.detach().cpu().numpy())
                mean_err_g_con_per_batch = np.append(
                    mean_err_g_con_per_batch,
                    mean_err_g_con.detach().cpu().numpy())
                mean_err_g_enc_per_batch = np.append(
                    mean_err_g_enc_per_batch,
                    mean_err_g_enc.detach().cpu().numpy())

                ############## netD backward ##############################################
                self.optimizer_D.zero_grad()

                real_label = torch.ones_like(pred_real).to(self.device)
                fake_label = torch.zeros_like(pred_fake).to(self.device)

                err_d_real_per_img = l_bce(pred_real, real_label)
                err_d_fake_per_img = l_bce(pred_fake, fake_label)
                mean_err_d_real = (err_d_real_per_img *
                                   mask).sum() / torch.max(
                                       mask.sum(),
                                       torch.Tensor([1]).to(self.device))
                mean_err_d_fake = (err_d_fake_per_img *
                                   mask).sum() / torch.max(
                                       mask.sum(),
                                       torch.Tensor([1]).to(self.device))

                mean_err_d = (mean_err_d_real + mean_err_d_fake) / 2
                mean_err_d.backward()
                self.optimizer_D.step()

                mean_err_d_per_batch = np.append(
                    mean_err_d_per_batch,
                    mean_err_d.detach().cpu().numpy())

                if mean_err_d.item() < 1e-5:
                    self.netD.apply(weights_init)
                    print('Reinitializing netD weights')
                ############# policy update ###############################################

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer_policy.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy
                loss.backward()
                #global_grad_norm_(list(self.model.parameters())+list(self.vae.parameters())) do we need this step
                #global_grad_norm_(list(self.model.parameter())) or just norm policy
                self.optimizer_policy.step()

        return mean_err_g_adv_per_batch, mean_err_g_con_per_batch, mean_err_g_enc_per_batch, mean_err_d_per_batch

    def train_just_vae(self, s_batch, next_obs_batch):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))

        l_adv = nn.MSELoss(reduction='none')
        l_con = nn.L1Loss(reduction='none')
        l_enc = nn.MSELoss(reduction='none')
        l_bce = nn.BCELoss(reduction='none')

        mean_err_g_adv_per_batch = np.array([])
        mean_err_g_con_per_batch = np.array([])
        mean_err_g_enc_per_batch = np.array([])
        mean_err_d_per_batch = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                ############### netG forward ##############################################
                gen_next_state, latent_i, latent_o = self.netG(
                    next_obs_batch[sample_idx])

                ############### netD forward ##############################################
                pred_real, feature_real = self.netD(next_obs_batch[sample_idx])
                pred_fake, feature_fake = self.netD(gen_next_state)

                #d = len(gen_next_state.shape)
                #recon_loss = reconstruction_loss(gen_next_state, next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))
                ############### netG backward #############################################
                self.optimizer_G.zero_grad()

                err_g_adv_per_img = l_adv(
                    self.netD(next_obs_batch[sample_idx])[1],
                    self.netD(gen_next_state)[1]).mean(
                        axis=list(range(1, len(feature_real.shape))))
                err_g_con_per_img = l_con(
                    next_obs_batch[sample_idx], gen_next_state).mean(
                        axis=list(range(1, len(gen_next_state.shape))))
                err_g_enc_per_img = l_enc(latent_i, latent_o).mean(-1)

                #kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                img_num = len(err_g_con_per_img)
                mask = torch.rand(img_num).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                mean_err_g_adv = (err_g_adv_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                mean_err_g_con = (err_g_con_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                mean_err_g_enc = (err_g_enc_per_img * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                # hyperparameter weights:
                w_adv = 1
                w_con = 50
                w_enc = 1

                mean_err_g = mean_err_g_adv * w_adv +\
                        mean_err_g_con * w_con +\
                        mean_err_g_enc * w_enc
                mean_err_g.backward(retain_graph=True)

                self.optimizer_G.step()

                mean_err_g_adv_per_batch = np.append(
                    mean_err_g_adv_per_batch,
                    mean_err_g_adv.detach().cpu().numpy())
                mean_err_g_con_per_batch = np.append(
                    mean_err_g_con_per_batch,
                    mean_err_g_con.detach().cpu().numpy())
                mean_err_g_enc_per_batch = np.append(
                    mean_err_g_enc_per_batch,
                    mean_err_g_enc.detach().cpu().numpy())

                ############## netD backward ##############################################
                self.optimizer_D.zero_grad()

                real_label = torch.ones_like(pred_real).to(self.device)
                fake_label = torch.zeros_like(pred_fake).to(self.device)

                err_d_real_per_img = l_bce(pred_real, real_label)
                err_d_fake_per_img = l_bce(pred_fake, fake_label)
                mean_err_d_real = (err_d_real_per_img *
                                   mask).sum() / torch.max(
                                       mask.sum(),
                                       torch.Tensor([1]).to(self.device))
                mean_err_d_fake = (err_d_fake_per_img *
                                   mask).sum() / torch.max(
                                       mask.sum(),
                                       torch.Tensor([1]).to(self.device))

                mean_err_d = (mean_err_d_real + mean_err_d_fake) / 2
                mean_err_d.backward()
                self.optimizer_D.step()

                mean_err_d_per_batch = np.append(
                    mean_err_d_per_batch,
                    mean_err_d.detach().cpu().numpy())

        return mean_err_g_adv_per_batch, mean_err_g_con_per_batch, mean_err_g_enc_per_batch, mean_err_d_per_batch
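
# GANAgent above assumes a GANomaly-style pair of networks: NetG(obs) returns
# (reconstruction, encoder latent, re-encoded latent) and NetD(obs) returns
# (realness score, intermediate features). The tiny fully-connected sketch
# below only illustrates those return signatures for 1x84x84 observations; the
# real NetG/NetD are convolutional and their layer sizes are not shown here.
import torch
import torch.nn as nn

class TinyNetG(nn.Module):
    def __init__(self, z_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(84 * 84, z_dim))
        self.decoder = nn.Sequential(nn.Linear(z_dim, 84 * 84), nn.Sigmoid())
        self.re_encoder = nn.Sequential(nn.Flatten(), nn.Linear(84 * 84, z_dim))

    def forward(self, x):
        latent_i = self.encoder(x)                           # encode the observation
        recon = self.decoder(latent_i).view(-1, 1, 84, 84)   # reconstruct it
        latent_o = self.re_encoder(recon)                    # re-encode the reconstruction
        return recon, latent_i, latent_o

class TinyNetD(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(84 * 84, 256),
                                      nn.LeakyReLU(0.2))
        self.classifier = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        feat = self.features(x)
        return self.classifier(feat).squeeze(1), feat
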
class RNDAgent(object):
    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False):
        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.rnd = RNDModel(input_size, output_size)
        self.optimizer = optim.Adam(list(self.model.parameters()) +
                                    list(self.rnd.predictor.parameters()),
                                    lr=learning_rate)
        self.rnd = self.rnd.to(self.device)

        self.model = self.model.to(self.device)

    def get_action(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return action, value_ext.data.cpu().numpy().squeeze(
        ), value_int.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, next_obs):
        next_obs = torch.FloatTensor(next_obs).to(self.device)

        target_next_feature = self.rnd.target(next_obs)
        predict_next_feature = self.rnd.predictor(next_obs)
        intrinsic_reward = (target_next_feature -
                            predict_next_feature).pow(2).sum(1) / 2

        return intrinsic_reward.data.cpu().numpy()

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        forward_mse = nn.MSELoss(reduction='none')

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for Curiosity-driven(Random Network Distillation)
                predict_next_state_feature, target_next_state_feature = self.rnd(
                    next_obs_batch[sample_idx])

                forward_loss = forward_mse(
                    predict_next_state_feature,
                    target_next_state_feature.detach()).mean(-1)
                # Proportion of exp used for predictor update
                mask = torch.rand(len(forward_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                forward_loss = (forward_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                # ---------------------------------------------------------------------------------

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy + forward_loss
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) +
                    list(self.rnd.predictor.parameters()))
                self.optimizer.step()
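
# RNDAgent assumes an RNDModel holding two networks over the normalized next
# observation: a fixed, randomly initialized target and a trained predictor;
# the intrinsic reward is the predictor's error against the target. A minimal
# fully-connected sketch of that structure (the real RNDModel is convolutional
# and the sizes below are placeholders):
import torch
import torch.nn as nn

class TinyRNDModel(nn.Module):
    def __init__(self, obs_dim=84 * 84, feature_dim=512):
        super().__init__()
        self.predictor = nn.Sequential(nn.Flatten(), nn.Linear(obs_dim, 512),
                                       nn.ReLU(), nn.Linear(512, feature_dim))
        self.target = nn.Sequential(nn.Flatten(), nn.Linear(obs_dim, 512),
                                    nn.ReLU(), nn.Linear(512, feature_dim))
        for param in self.target.parameters():
            param.requires_grad = False  # the target network is never trained

    def forward(self, next_obs):
        # Same unpacking order as self.rnd(...) above: (predicted, target) features.
        return self.predictor(next_obs), self.target(next_obs)
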
Example #5
def main():
    args = get_args()
    device = torch.device('cuda' if args.cuda else 'cpu')

    env = gym.make(args.env_name)

    input_size = env.observation_space.shape
    output_size = env.action_space.n

    if 'Breakout' in args.env_name:
        output_size -= 1

    env.close()

    is_render = False
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_path = os.path.join(args.save_dir, args.env_name + '.model')
    predictor_path = os.path.join(args.save_dir, args.env_name + '.pred')
    target_path = os.path.join(args.save_dir, args.env_name + '.target')

    writer = SummaryWriter(log_dir=args.log_dir)

    reward_rms = RunningMeanStd()
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    discounted_reward = RewardForwardFilter(args.ext_gamma)

    model = CnnActorCriticNetwork(input_size, output_size, args.use_noisy_net)
    rnd = RNDModel(input_size, output_size)
    model = model.to(device)
    rnd = rnd.to(device)
    optimizer = optim.Adam(list(model.parameters()) +
                           list(rnd.predictor.parameters()),
                           lr=args.lr)

    if args.load_model:
        if args.cuda:
            model.load_state_dict(torch.load(model_path))
        else:
            model.load_state_dict(torch.load(model_path, map_location='cpu'))

    works = []
    parent_conns = []
    child_conns = []
    for idx in range(args.num_worker):
        parent_conn, child_conn = Pipe()
        work = AtariEnvironment(args.env_name,
                                is_render,
                                idx,
                                child_conn,
                                sticky_action=args.sticky_action,
                                p=args.sticky_action_prob,
                                max_episode_steps=args.max_episode_steps)
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    states = np.zeros([args.num_worker, 4, 84, 84])

    sample_env_index = 0  # Sample Environment index to log
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # normalize observation
    print('Initializing observation normalization...')
    next_obs = []
    for step in range(args.num_step * args.pre_obs_norm_steps):
        actions = np.random.randint(0, output_size, size=(args.num_worker, ))

        for parent_conn, action in zip(parent_conns, actions):
            parent_conn.send(action)

        for parent_conn in parent_conns:
            next_state, reward, done, realdone, log_reward = parent_conn.recv()
            next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

        if len(next_obs) % (args.num_step * args.num_worker) == 0:
            next_obs = np.stack(next_obs)
            obs_rms.update(next_obs)
            next_obs = []

    print('Training...')
    while True:
        (total_state, total_reward, total_done, total_next_state, total_action,
         total_int_reward, total_next_obs, total_ext_values, total_int_values,
         total_action_probs) = [], [], [], [], [], [], [], [], [], []
        global_step += (args.num_worker * args.num_step)
        global_update += 1

        # Step 1. n-step rollout
        for _ in range(args.num_step):
            actions, value_ext, value_int, action_probs = get_action(
                model, device,
                np.float32(states) / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_states, rewards, dones, real_dones, log_rewards, next_obs = [], [], [], [], [], []
            for parent_conn in parent_conns:
                next_state, reward, done, real_done, log_reward = parent_conn.recv(
                )
                next_states.append(next_state)
                rewards.append(reward)
                dones.append(done)
                real_dones.append(real_done)
                log_rewards.append(log_reward)
                next_obs.append(next_state[3, :, :].reshape([1, 84, 84]))

            next_states = np.stack(next_states)
            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)
            next_obs = np.stack(next_obs)

            # total reward = intrinsic reward + extrinsic reward
            intrinsic_reward = compute_intrinsic_reward(
                rnd, device,
                ((next_obs - obs_rms.mean) / np.sqrt(obs_rms.var)).clip(-5, 5))
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_index]

            total_next_obs.append(next_obs)
            total_int_reward.append(intrinsic_reward)
            total_state.append(states)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_action_probs.append(action_probs)

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_index]

            sample_step += 1
            if real_dones[sample_env_index]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall,
                                  sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall,
                                  global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = get_action(model, device,
                                                np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        total_state = np.stack(total_state).transpose([1, 0, 2, 3, 4]).reshape(
            [-1, 4, 84, 84])
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_next_obs = np.stack(total_next_obs).transpose(
            [1, 0, 2, 3, 4]).reshape([-1, 1, 84, 84])
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_action_probs = np.vstack(total_action_probs)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([
            discounted_reward.update(reward_per_step)
            for reward_per_step in total_int_reward.T
        ])
        mean, std, count = np.mean(total_reward_per_env), np.std(
            total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std**2, count)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi',
                          np.sum(total_int_reward) / args.num_worker,
                          sample_episode)
        writer.add_scalar('data/int_reward_per_rollout',
                          np.sum(total_int_reward) / args.num_worker,
                          global_update)
        # -------------------------------------------------------------------------------------------

        # logging Max action probability
        writer.add_scalar('data/max_prob',
                          total_logging_action_probs.max(1).mean(),
                          sample_episode)

        # Step 3. make target and advantage
        # calculate extrinsic reward targets and advantages
        ext_target, ext_adv = make_train_data(total_reward, total_done,
                                              total_ext_values, args.ext_gamma,
                                              args.gae_lambda, args.num_step,
                                              args.num_worker, args.use_gae)

        # calculate intrinsic reward targets and advantages
        # non-episodic: dones are treated as all-zero for the intrinsic stream
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values, args.int_gamma,
                                              args.gae_lambda, args.num_step,
                                              args.num_worker, args.use_gae)

        # add ext adv and int adv
        total_adv = int_adv * args.int_coef + ext_adv * args.ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        train_model(args, device, output_size, model, rnd, optimizer,
                    np.float32(total_state) / 255., ext_target, int_target,
                    total_action, total_adv,
                    ((total_next_obs - obs_rms.mean) /
                     np.sqrt(obs_rms.var)).clip(-5, 5), total_action_probs)

        if global_step % (args.num_worker * args.num_step *
                          args.save_interval) == 0:
            print('Now Global Step: {}'.format(global_step))
            torch.save(model.state_dict(), model_path)
            torch.save(rnd.predictor.state_dict(), predictor_path)
            torch.save(rnd.target.state_dict(), target_path)
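
# Both training scripts call make_train_data(...) to turn rollout buffers into
# value targets and advantages, but the helper itself is not shown. The sketch
# below restates the per-worker GAE recursion it is assumed to perform when
# use_gae is set; the repo's version may lay out shapes and the non-GAE branch
# differently.
import numpy as np

def gae_sketch(rewards, dones, values, gamma, lam):
    # rewards, dones: [num_step]; values: [num_step + 1], bootstrap value last.
    num_step = len(rewards)
    adv = np.zeros(num_step, dtype=np.float64)
    last_gae = 0.0
    for t in reversed(range(num_step)):
        nonterminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        last_gae = delta + gamma * lam * nonterminal * last_gae
        adv[t] = last_gae
    target = adv + values[:num_step]
    return target, adv
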
class GenerativeAgent(object):
    def __init__(self,
                 input_size,
                 output_size,
                 num_env,
                 num_step,
                 gamma,
                 history_size=4,
                 lam=0.95,
                 learning_rate=1e-4,
                 ent_coef=0.01,
                 clip_grad_norm=0.5,
                 epoch=3,
                 batch_size=128,
                 ppo_eps=0.1,
                 update_proportion=0.25,
                 use_gae=True,
                 use_cuda=False,
                 use_noisy_net=False,
                 hidden_dim=512):
        self.model = CnnActorCriticNetwork(input_size, output_size,
                                           use_noisy_net, history_size)
        self.num_env = num_env
        self.output_size = output_size
        self.input_size = input_size
        self.num_step = num_step
        self.gamma = gamma
        self.lam = lam
        self.epoch = epoch
        self.batch_size = batch_size
        self.use_gae = use_gae
        self.ent_coef = ent_coef
        self.ppo_eps = ppo_eps
        self.clip_grad_norm = clip_grad_norm
        self.update_proportion = update_proportion
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.vae = VAE(input_size, z_dim=hidden_dim)
        self.optimizer = optim.Adam(list(self.model.parameters()) +
                                    list(self.vae.parameters()),
                                    lr=learning_rate)
        self.vae = self.vae.to(self.device)

        self.model = self.model.to(self.device)

    def reconstruct(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        reconstructed = self.vae(state.unsqueeze(0))[0].squeeze(0)
        return reconstructed.detach().cpu().numpy()

    def get_action(self, state):
        state = torch.Tensor(state).to(self.device)
        state = state.float()
        policy, value_ext, value_int = self.model(state)
        action_prob = F.softmax(policy, dim=-1).data.cpu().numpy()

        action = self.random_choice_prob_index(action_prob)

        return action, value_ext.data.cpu().numpy().squeeze(
        ), value_int.data.cpu().numpy().squeeze(), policy.detach()

    @staticmethod
    def random_choice_prob_index(p, axis=1):
        r = np.expand_dims(np.random.rand(p.shape[1 - axis]), axis=axis)
        return (p.cumsum(axis=axis) > r).argmax(axis=axis)

    def compute_intrinsic_reward(self, obs):
        obs = torch.FloatTensor(obs).to(self.device)
        embedding = self.vae.representation(obs)
        reconstructed_embedding = self.vae.representation(self.vae(obs)[0])

        intrinsic_reward = (embedding -
                            reconstructed_embedding).pow(2).sum(1) / 2

        return intrinsic_reward.detach().cpu().numpy()

    def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                    adv_batch, next_obs_batch, old_policy):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
        target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
        y_batch = torch.LongTensor(y_batch).to(self.device)
        adv_batch = torch.FloatTensor(adv_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        reconstruction_loss = nn.MSELoss(reduction='none')

        with torch.no_grad():
            policy_old_list = torch.stack(old_policy).permute(
                1, 0, 2).contiguous().view(-1,
                                           self.output_size).to(self.device)

            m_old = Categorical(F.softmax(policy_old_list, dim=-1))
            log_prob_old = m_old.log_prob(y_batch)
            # ------------------------------------------------------------

        recon_losses = np.array([])
        kld_losses = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for generative curiosity (VAE loss)
                gen_next_state, mu, logvar = self.vae(
                    next_obs_batch[sample_idx])

                d = len(gen_next_state.shape)
                recon_loss = reconstruction_loss(
                    gen_next_state,
                    next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

                kld_loss = -0.5 * (1 + logvar - mu.pow(2) -
                                   logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                mask = torch.rand(len(recon_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                recon_loss = (recon_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                kld_loss = (kld_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                recon_losses = np.append(recon_losses,
                                         recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses,
                                       kld_loss.detach().cpu().numpy())
                # ---------------------------------------------------------------------------------

                policy, value_ext, value_int = self.model(s_batch[sample_idx])
                m = Categorical(F.softmax(policy, dim=-1))
                log_prob = m.log_prob(y_batch[sample_idx])

                ratio = torch.exp(log_prob - log_prob_old[sample_idx])

                surr1 = ratio * adv_batch[sample_idx]
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                    1.0 + self.ppo_eps) * adv_batch[sample_idx]

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                             target_ext_batch[sample_idx])
                critic_int_loss = F.mse_loss(value_int.sum(1),
                                             target_int_batch[sample_idx])

                critic_loss = critic_ext_loss + critic_int_loss

                entropy = m.entropy().mean()

                self.optimizer.zero_grad()
                loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy + recon_loss + kld_loss
                loss.backward()
                global_grad_norm_(
                    list(self.model.parameters()) +
                    list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses

    def train_just_vae(self, s_batch, next_obs_batch):
        s_batch = torch.FloatTensor(s_batch).to(self.device)
        next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

        sample_range = np.arange(len(s_batch))
        reconstruction_loss = nn.MSELoss(reduction='none')

        recon_losses = np.array([])
        kld_losses = np.array([])

        for i in range(self.epoch):
            np.random.shuffle(sample_range)
            for j in range(int(len(s_batch) / self.batch_size)):
                sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                          (j + 1)]

                # --------------------------------------------------------------------------------
                # for generative curiosity (VAE loss)
                gen_next_state, mu, logvar = self.vae(
                    next_obs_batch[sample_idx])

                d = len(gen_next_state.shape)
                recon_loss = -1 * pytorch_ssim.ssim(gen_next_state,
                                                    next_obs_batch[sample_idx],
                                                    size_average=False)
                # recon_loss = reconstruction_loss(gen_next_state, next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

                kld_loss = -0.5 * (1 + logvar - mu.pow(2) -
                                   logvar.exp()).sum(axis=1)

                # TODO: keep this proportion of experience used for VAE update?
                # Proportion of experience used for VAE update
                mask = torch.rand(len(recon_loss)).to(self.device)
                mask = (mask < self.update_proportion).type(
                    torch.FloatTensor).to(self.device)
                recon_loss = (recon_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))
                kld_loss = (kld_loss * mask).sum() / torch.max(
                    mask.sum(),
                    torch.Tensor([1]).to(self.device))

                recon_losses = np.append(recon_losses,
                                         recon_loss.detach().cpu().numpy())
                kld_losses = np.append(kld_losses,
                                       kld_loss.detach().cpu().numpy())
                # ---------------------------------------------------------------------------------

                self.optimizer.zero_grad()
                loss = recon_loss + kld_loss
                loss.backward()
                global_grad_norm_(list(self.vae.parameters()))
                self.optimizer.step()

        return recon_losses, kld_losses
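
# A pattern shared by RNDAgent, GANAgent and GenerativeAgent above: the
# curiosity model is trained on only a random update_proportion of each
# minibatch, and the masked sum is divided by the mask count clamped to one.
# Restated as a standalone helper (the name masked_mean is mine):
import torch

def masked_mean(per_sample_loss, update_proportion):
    mask = (torch.rand_like(per_sample_loss) < update_proportion).float()
    return (per_sample_loss * mask).sum() / torch.max(
        mask.sum(), torch.ones(1, device=per_sample_loss.device))

# e.g. forward_loss = masked_mean(forward_mse(pred, target.detach()).mean(-1), 0.25)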