Beispiel #1
0
class trpo_agent:
    """TRPO agent whose ``learn`` loop also checkpoints, plots and logs rewards.

    The actor is updated with a natural-gradient step (conjugate gradient on
    the Fisher-vector product followed by a line search); the critic is fitted
    to the GAE returns with Adam.
    """

    def __init__(self, env, args):
        """Build the networks, the critic optimizer and the output directories.

        Args:
            env: environment exposing ``observation_space.shape`` and
                ``action_space.shape`` and the 4-tuple ``step`` API used in
                ``learn``.
            args: hyper-parameter namespace; this class reads ``lr``,
                ``save_dir``, ``env_name``, ``nsteps``, ``gamma``, ``tau``,
                ``max_kl``, ``damping``, ``vf_itrs``, ``batch_size`` and
                ``display_interval``.
        """
        self.env = env
        self.args = args
        obs_dim = self.env.observation_space.shape[0]
        act_dim = self.env.action_space.shape[0]

        # current policy/value network and the copy used as pi_old
        self.net = Network(obs_dim, act_dim)
        self.old_net = Network(obs_dim, act_dim)
        # make sure the net and old net start with the same parameters
        self.old_net.load_state_dict(self.net.state_dict())

        # Adam only trains the critic; the actor is moved by the TRPO step.
        self.optimizer = torch.optim.Adam(self.net.critic.parameters(),
                                          lr=self.args.lr)

        # running mean observation filter, clipped at 5
        self.running_state = ZFilter((obs_dim, ), clip=5)

        # makedirs copes with nested save_dir paths and avoids the
        # exists()/mkdir() check-then-create race.
        os.makedirs(self.args.save_dir, exist_ok=True)
        self.model_path = self.args.save_dir + self.args.env_name
        os.makedirs(self.model_path, exist_ok=True)

        # first update index used by learn(); overwritten when resuming
        self.start_episode = 0

    def learn(self):
        """Run the TRPO training loop, checkpointing and logging every update.

        Outputs go under ``<model_path>/<USER_SAVE_DATE>/``: one checkpoint per
        update in ``best/``, a ``[state_dict, running_state]`` file per update,
        and every ``display_interval`` updates a reward plot in ``plots/`` plus
        the new rows of ``rewards/rewards.csv``.
        """
        # configuration
        USER_SAVE_DATE = '3006'
        USER_SAVE_MODEL = 'mymodel.pt'
        CONTINUE_TRAINING = False  # False for new training, True for improving the existing model

        # paths
        date = USER_SAVE_DATE
        all_model_path = self.model_path + '/' + date
        plot_dir = all_model_path + '/plots/'
        plot_path = plot_dir + 'plot_'
        best_model_path = all_model_path + '/best/'
        reward_path = all_model_path + '/rewards/'
        # These sub-directories were never created, so the first save/plot
        # raised FileNotFoundError.
        for directory in (plot_dir, best_model_path, reward_path):
            os.makedirs(directory, exist_ok=True)

        load_model = CONTINUE_TRAINING
        best_model = all_model_path + '/' + USER_SAVE_MODEL
        all_final_rewards = []
        # how many entries of all_final_rewards are already in rewards.csv
        rewards_flushed = 0

        num_updates = 1000000

        final_reward = 0
        episode_reward = 0
        self.dones = False

        # Load the best model for continuing training
        if load_model:
            print("=> Loading checkpoint...")
            # NOTE(review): the checkpoint pickles the ZFilter object; torch.load
            # unpickles arbitrary objects (trusted files only), and torch>=2.6
            # defaults to weights_only=True, which rejects it.
            checkpoint = torch.load(best_model)
            # 'update' is the last *completed* update, so resume at the next one.
            self.start_episode = checkpoint['update'] + 1
            self.net.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.running_state = checkpoint['running_state']
            final_reward = checkpoint['reward']
            all_final_rewards.append(final_reward)

        # Reset after a possible load, so the first observation goes through
        # the filter that is actually used for training.
        obs = self.running_state(self.env.reset())

        for update in range(self.start_episode, num_updates):
            mb_obs, mb_rewards, mb_actions, mb_dones, mb_values = [], [], [], [], []
            for _ in range(self.args.nsteps):
                with torch.no_grad():
                    obs_tensor = self._get_tensors(obs)
                    value, pi = self.net(obs_tensor)
                # select actions
                actions = select_actions(pi)
                # store information (the done flag is the one *before* this step)
                mb_obs.append(np.copy(obs))
                mb_actions.append(actions)
                mb_dones.append(self.dones)
                mb_values.append(value.detach().numpy().squeeze())
                # execute the action in the environment
                obs_, reward, done, _ = self.env.step(actions)
                self.dones = done
                mb_rewards.append(reward)
                if done:
                    obs_ = self.env.reset()
                obs = self.running_state(obs_)
                # final_reward holds the return of the last finished episode
                episode_reward += reward
                mask = 0.0 if done else 1.0
                final_reward *= mask
                final_reward += (1 - mask) * episode_reward
                episode_reward *= mask
            # process the rollouts
            mb_obs = np.asarray(mb_obs, dtype=np.float32)
            mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
            mb_actions = np.asarray(mb_actions, dtype=np.float32)
            # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
            mb_dones = np.asarray(mb_dones, dtype=bool)
            mb_values = np.asarray(mb_values, dtype=np.float32)
            # value of the observation that follows the rollout (bootstrap)
            with torch.no_grad():
                obs_tensor = self._get_tensors(obs)
                last_value, _ = self.net(obs_tensor)
                last_value = last_value.detach().numpy().squeeze()
            mb_advs, mb_returns = self._compute_gae(mb_rewards, mb_values,
                                                    mb_dones, last_value,
                                                    self.dones)
            # normalize the advantages
            mb_advs = (mb_advs - mb_advs.mean()) / (mb_advs.std() + 1e-5)
            # freeze the current policy as pi_old before updating
            self.old_net.load_state_dict(self.net.state_dict())
            policy_loss, value_loss = self._update_network(
                mb_obs, mb_actions, mb_returns, mb_advs)

            # Works for Python floats, NumPy scalars and size-1 arrays alike;
            # a bare `.item()` crashed when the env returned plain floats.
            reward_value = np.asarray(final_reward).item()
            rounded_reward = str(round(reward_value, 2))

            print('Episode: {} / {}, Iteration: {}, Reward: {:.3f}'.format(
                update, num_updates, (update + 1) * self.args.nsteps,
                reward_value))

            all_final_rewards.append(reward_value)
            self.save_model_for_training(update,
                                         reward_value,
                                         filepath=best_model_path +
                                         rounded_reward + '_' + str(update) +
                                         '.pt')

            torch.save([self.net.state_dict(), self.running_state],
                       all_model_path + "/" + rounded_reward + str(update) +
                       '_testing' + ".pt")

            if update % self.args.display_interval == 0:
                fig, ax = plt.subplots()
                ax.plot(np.arange(len(all_final_rewards)), all_final_rewards)
                ax.set_ylabel('Reward')
                ax.set_xlabel('Episode #')
                fig.savefig(plot_path + str(update) + '.png')
                # close the figure; otherwise every interval leaks one
                plt.close(fig)
                # append only the rows not yet written (the old code
                # re-appended the entire history each time)
                reward_df = pd.DataFrame(
                    all_final_rewards[rewards_flushed:],
                    index=range(rewards_flushed, len(all_final_rewards)))
                with open(reward_path + 'rewards.csv', 'a') as f:
                    reward_df.to_csv(f, header=False)
                rewards_flushed = len(all_final_rewards)

    def _compute_gae(self, mb_rewards, mb_values, mb_dones, last_value,
                     last_done):
        """Compute GAE advantages and the matching value targets.

        ``mb_dones[t]`` is the done flag recorded *before* step ``t`` was
        taken, so the bootstrap of step ``t`` is masked by ``mb_dones[t + 1]``.
        For the final step, ``last_done``/``last_value`` describe the
        observation that follows the rollout. ``self.args.tau`` is the GAE
        lambda.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        gamma = self.args.gamma
        tau = self.args.tau
        nsteps = len(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        for t in reversed(range(nsteps)):
            if t == nsteps - 1:
                nextnonterminal = 1.0 - last_done
                nextvalues = last_value
            else:
                nextnonterminal = 1.0 - mb_dones[t + 1]
                nextvalues = mb_values[t + 1]
            delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        return mb_advs, mb_returns

    def save_model_for_training(self, num_of_iteration, reward, filepath):
        """Save a resumable checkpoint (weights, critic optimizer, filter, reward)."""
        checkpoint = {
            'update': num_of_iteration,
            'state_dict': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'running_state': self.running_state,
            'reward': reward
        }
        torch.save(checkpoint, filepath)

    def _update_network(self, mb_obs, mb_actions, mb_returns, mb_advs):
        """Apply one TRPO step to the actor, then fit the critic.

        Returns:
            ``(surrogate_loss, last_value_loss)`` as Python floats. The
            surrogate loss is measured before the actor step.
        """
        mb_obs_tensor = torch.tensor(mb_obs, dtype=torch.float32)
        mb_actions_tensor = torch.tensor(mb_actions, dtype=torch.float32)
        mb_advs_tensor = torch.tensor(mb_advs,
                                      dtype=torch.float32).unsqueeze(1)
        # pi_old stays fixed for the whole update, so no graph is needed
        with torch.no_grad():
            _, pi_old = self.old_net(mb_obs_tensor)
        surr_loss = self._get_surrogate_loss(mb_obs_tensor, mb_advs_tensor,
                                             mb_actions_tensor, pi_old)
        # g = d(surr)/d(actor params); solve F x = -g where F is the Fisher matrix
        surr_grad = torch.autograd.grad(surr_loss, self.net.actor.parameters())
        flat_surr_grad = torch.cat([grad.view(-1) for grad in surr_grad]).data
        # conjugate gradient (10 iterations) gives the natural-gradient direction
        nature_grad = conjugated_gradient(self._fisher_vector_product,
                                          -flat_surr_grad, 10, mb_obs_tensor,
                                          pi_old)
        # rescale the step so that 0.5 * s^T F s equals max_kl
        non_scale_kl = 0.5 * (nature_grad * self._fisher_vector_product(
            nature_grad, mb_obs_tensor, pi_old)).sum(0, keepdim=True)
        scale_ratio = torch.sqrt(non_scale_kl / self.args.max_kl)
        final_nature_grad = nature_grad / scale_ratio[0]
        # first-order predicted improvement of the scaled step
        expected_improve = (-flat_surr_grad * nature_grad).sum(
            0, keepdim=True) / scale_ratio[0]
        prev_params = torch.cat(
            [param.data.view(-1) for param in self.net.actor.parameters()])
        # NOTE(review): line_search's success flag is ignored; confirm it
        # returns prev_params when no acceptable step is found.
        _, new_params = line_search(self.net.actor, self._get_surrogate_loss,
                                    prev_params, final_nature_grad,
                                    expected_improve, mb_obs_tensor,
                                    mb_advs_tensor, mb_actions_tensor, pi_old)
        set_flat_params_to(self.net.actor, new_params)
        # critic: vf_itrs epochs of shuffled mini-batch regression onto returns
        inds = np.arange(mb_obs.shape[0])
        for _ in range(self.args.vf_itrs):
            np.random.shuffle(inds)
            for start in range(0, mb_obs.shape[0], self.args.batch_size):
                mbinds = inds[start:start + self.args.batch_size]
                mini_obs = torch.tensor(mb_obs[mbinds], dtype=torch.float32)
                mini_returns = torch.tensor(mb_returns[mbinds],
                                            dtype=torch.float32).unsqueeze(1)
                values, _ = self.net(mini_obs)
                v_loss = (mini_returns - values).pow(2).mean()
                self.optimizer.zero_grad()
                v_loss.backward()
                self.optimizer.step()
        return surr_loss.item(), v_loss.item()

    def _get_surrogate_loss(self, obs, adv, actions, pi_old):
        """Return the negated importance-weighted advantage (to be minimised)."""
        _, pi = self.net(obs)
        log_prob = eval_actions(pi, actions)
        old_log_prob = eval_actions(pi_old, actions).detach()
        surr_loss = -torch.exp(log_prob - old_log_prob) * adv
        return surr_loss.mean()

    def _fisher_vector_product(self, v, obs, pi_old):
        """Return ``(F + damping * I) v`` via double back-propagation of the mean KL."""
        kl = self._get_kl(obs, pi_old).mean()
        # first derivative kept in the graph so it can be differentiated again
        kl_grads = torch.autograd.grad(kl,
                                       self.net.actor.parameters(),
                                       create_graph=True)
        flat_kl_grads = torch.cat([grad.view(-1) for grad in kl_grads])
        # torch.autograd.Variable is a deprecated no-op wrapper around tensors
        kl_v = (flat_kl_grads * v).sum()
        kl_second_grads = torch.autograd.grad(kl_v,
                                              self.net.actor.parameters())
        flat_kl_second_grads = torch.cat(
            [grad.contiguous().view(-1) for grad in kl_second_grads]).data
        return flat_kl_second_grads + self.args.damping * v

    def _get_kl(self, obs, pi_old):
        """KL(pi_old || pi) between diagonal Gaussians, summed over dim 1 (keepdim)."""
        mean_old, std_old = pi_old
        _, pi = self.net(obs)
        mean, std = pi
        kl = -torch.log(std / std_old) + (
            std.pow(2) + (mean - mean_old).pow(2)) / (2 * std_old.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    def _get_tensors(self, obs):
        """Convert one observation into a float32 tensor with a leading batch dim."""
        return torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
Beispiel #2
0
                    args)
            losses.append(loss)

            logging.info(
                'train_epoch: %d, loss: %.4f, loss_valid: %.4f, time spent: %.4f'
                % (train_epoch, loss, loss_valid, time() - search_start))

            logging.info('genotype: %s' % g)
            print('genotype_p: %s' % gp)
            if train_epoch > 1500:
                if abs(losses[-2] - losses[-1]) / losses[
                        -1] < 1e-4 / train_queue[0].shape[0] or np.isnan(
                            losses[-1]):
                    break
        if args.mode == 'binarydarts_mlp' or args.mode == 'traindarts':
            torch.save(model.state_dict(),
                       '/data3/chenxiangning/models/' + save_name)

    elif args.mode == 'hyperopt':
        start = time()
        from hyperopt import fmin, tpe, hp

        # def get_cfg_performance(cfg):
        # 	if dim == 2:
        # 		arch = {'unary': {'p': cfg['p_unary'], 'q': cfg['q_unary']},
        # 				'assist': {'p': cfg['p_assist'], 'q': cfg['q_assist']},
        # 				'binary': cfg['binary']}
        # 		rmse = get_arch_performance(arch, num_users, num_items, train_queue, test_queue, args)
        # 	elif dim == 3:
        # 		arch = {'unary': {'p': cfg['p_unary'], 'q': cfg['q_unary'], 'r': cfg['r_unary']},
        # 				'assist': {'p': cfg['p_assist'], 'q': cfg['q_assist'], 'r': cfg['r_assist']},
Beispiel #3
0
def train(train_feats,
          train_caps,
          val_feats,
          val_caps,
          train_prefix="",
          val_prefix="",
          epochs=EPOCHS,
          batch_size=BATCH_SIZE,
          max_seq_len=MAX_LEN,
          hidden_dim=HIDDEN_DIM,
          emb_dim=EMB_DIM,
          enc_seq_len=ENC_SEQ_LEN,
          enc_dim=ENC_DIM,
          clip_val=CLIP_VAL,
          teacher_force=TEACHER_FORCE_RAT,
          dropout_p=0.1,
          attn_activation="relu",
          epsilon=0.0005,
          weight_decay=WEIGHT_DECAY,
          lr=LEARNING_RATE,
          early_stopping=True,
          scheduler="step",
          attention=None,
          deep_out=False,
          checkpoint="",
          out_dir="Pytorch_Exp_Out",
          decoder=None):
    """Train a captioning network; write vocab, weights and logs to *out_dir*.

    Args:
        train_feats, val_feats: text files with one feature path per line,
            joined onto ``train_prefix`` / ``val_prefix``.
        train_caps, val_caps: UTF-8 text files with one caption per line,
            aligned line-by-line with the feature files. Validation is
            skipped when ``val_caps`` is falsy.
        early_stopping: stop once validation BLEU-4 drops below the previous
            epoch's score.
        checkpoint: optional path of a state dict used to initialise the net.
        out_dir: receives ``vocab.txt``, ``net.pt`` and the ``*_log.txt`` files.
        The remaining arguments are forwarded to the network, the optimizer,
        the scheduler and the epoch/evaluation helpers.
    """
    print("EXPERIMENT START ", time.asctime())

    os.makedirs(out_dir, exist_ok=True)

    # 1. Load the data (`with` closes the files deterministically)

    with open(train_caps, mode='r', encoding='utf-8') as f:
        train_captions = f.read().strip().split('\n')
    with open(train_feats, mode='r') as f:
        train_features = f.read().strip().split('\n')
    train_features = [os.path.join(train_prefix, z) for z in train_features]

    assert len(train_captions) == len(train_features), \
        "train captions and features differ in length"

    if val_caps:
        with open(val_caps, mode='r', encoding='utf-8') as f:
            val_captions = f.read().strip().split('\n')
        with open(val_feats, mode='r') as f:
            val_features = f.read().strip().split('\n')
        val_features = [os.path.join(val_prefix, z) for z in val_features]

        assert len(val_captions) == len(val_features), \
            "validation captions and features differ in length"

    # 2. Preprocess the data

    train_captions = normalize_strings(train_captions)
    train_data = list(zip(train_captions, train_features))
    train_data = filter_inputs(train_data)
    print("Total training instances: ", len(train_data))

    if val_caps:
        val_captions = normalize_strings(val_captions)
        val_data = list(zip(val_captions, val_features))
        val_data = filter_inputs(val_data)
        print("Total validation instances: ", len(val_data))

    vocab = Vocab()
    vocab.build_vocab(map(lambda x: x[0], train_data), max_size=10000)
    vocab.save(path=os.path.join(out_dir, 'vocab.txt'))
    print("Vocabulary size: ", vocab.n_words)

    # 3. Initialize the network, optimizer & loss function

    net = Network(hid_dim=hidden_dim,
                  out_dim=vocab.n_words,
                  sos_token=0,
                  eos_token=1,
                  pad_token=2,
                  teacher_forcing_rat=teacher_force,
                  emb_dim=emb_dim,
                  enc_seq_len=enc_seq_len,
                  enc_dim=enc_dim,
                  dropout_p=dropout_p,
                  deep_out=deep_out,
                  decoder=decoder,
                  attention=attention)
    net.to(DEVICE)

    if checkpoint:
        net.load_state_dict(torch.load(checkpoint))

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    loss_function = nn.NLLLoss()

    scheduler = set_scheduler(scheduler, optimizer)

    # 4. Train

    total_instances = 0
    total_steps = 0
    train_loss_log = []
    train_loss_log_batches = []
    train_penalty_log = []
    val_loss_log = []
    val_loss_log_batches = []
    val_bleu_log = []
    # BLEU is higher-is-better. Starting at sys.maxsize made epoch 1 look like
    # a regression: early stopping fired at once and net.pt was never saved.
    prev_bleu = float('-inf')

    train_data = DataLoader(captions=map(lambda x: x[0], train_data),
                            sources=map(lambda x: x[1], train_data),
                            batch_size=batch_size,
                            vocab=vocab,
                            max_seq_len=max_seq_len)

    if val_caps:
        val_data = DataLoader(captions=map(lambda x: x[0], val_data),
                              sources=map(lambda x: x[1], val_data),
                              batch_size=batch_size,
                              vocab=vocab,
                              max_seq_len=max_seq_len,
                              val_multiref=True)

    training_start_time = time.time()

    for e in range(1, epochs + 1):
        print("Epoch ", e)

        tfr = _teacher_force(epochs, e, teacher_force)

        # train one epoch
        train_l, inst, steps, t, l_log, pen = train_epoch(
            model=net,
            loss_function=loss_function,
            optimizer=optimizer,
            data_iter=train_data,
            max_len=max_seq_len,
            clip_val=clip_val,
            epsilon=epsilon,
            teacher_forcing_rat=tfr)

        if scheduler is not None:
            scheduler.step()

        # epoch logs
        print("Training loss:\t", train_l)
        print("Instances:\t", inst)
        print("Steps:\t", steps)
        hours = t // 3600
        mins = (t % 3600) // 60
        secs = (t % 60)
        print("Time:\t{0}:{1}:{2}".format(hours, mins, secs))
        total_instances += inst
        total_steps += steps
        train_loss_log.append(train_l)
        train_loss_log_batches += l_log
        train_penalty_log.append(pen)
        print()

        # evaluate
        if val_caps:
            val_l, l_log, bleu = evaluate(model=net,
                                          loss_function=loss_function,
                                          data_iter=val_data,
                                          max_len=max_seq_len,
                                          epsilon=epsilon)

            # validation logs
            print("Validation loss: ", val_l)
            print("Validation BLEU-4: ", bleu)
            # keep the weights whenever BLEU improves on the previous epoch
            if bleu > prev_bleu:
                torch.save(net.state_dict(), os.path.join(out_dir, 'net.pt'))
            val_loss_log.append(val_l)
            val_bleu_log.append(bleu)
            val_loss_log_batches += l_log

        # sample model
        print("Sampling training data...")
        print()
        samples = sample(net,
                         train_data,
                         vocab,
                         samples=3,
                         max_len=max_seq_len)
        for target, predicted in samples:
            print("Target:\t", target)
            print("Predicted:\t", predicted)
            print()

        if val_caps:
            # Wrap training as soon as validation BLEU drops below the
            # previous epoch's score.
            if early_stopping and bleu < prev_bleu:
                # e is already 1-based: e epochs have been run
                print("\nWrapping training after {0} epochs.\n".format(e))
                break

            prev_bleu = bleu

    if not val_caps:
        # Without validation there is no BLEU-based checkpointing; keep the
        # final weights instead of discarding the whole run.
        torch.save(net.state_dict(), os.path.join(out_dir, 'net.pt'))

    # Experiment summary logs.
    tot_time = time.time() - training_start_time
    hours = tot_time // 3600
    mins = (tot_time % 3600) // 60
    secs = (tot_time % 60)
    print("Total training time:\t{0}:{1}:{2}".format(hours, mins, secs))
    print("Total training instances:\t", total_instances)
    print("Total training steps:\t", total_steps)
    print()

    _write_loss_log("train_loss_log.txt", out_dir, train_loss_log)
    _write_loss_log("train_loss_log_batches.txt", out_dir,
                    train_loss_log_batches)
    _write_loss_log("train_penalty.txt", out_dir, train_penalty_log)

    if val_caps:
        _write_loss_log("val_loss_log.txt", out_dir, val_loss_log)
        _write_loss_log("val_loss_log_batches.txt", out_dir,
                        val_loss_log_batches)
        _write_loss_log("val_bleu4_log.txt", out_dir, val_bleu_log)

    print("EXPERIMENT END ", time.asctime())
Beispiel #4
0
class Agent():
    """Dueling double-DQN agent with epsilon-greedy exploration.

    ``q_eval`` is the network that is trained. ``q_next`` is the target
    network, refreshed from ``q_eval`` every ``replace`` learning steps. Both
    networks' ``forward`` return a ``(value, advantage)`` pair.
    """

    # The original signature read `n_actions=, input_dims`, which is a
    # SyntaxError; n_actions is a required positional argument.
    def __init__(self, gamma, epsilon, lr, n_actions, input_dims,
                 mem_size, batch_size, eps_min=0.01, eps_dec=5e-7,
                 replace=1000, chkpt_dir='tmp/dueling_ddqn'):
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.replace_target_cnt = replace
        self.chkpt_dir = chkpt_dir
        self.action_space = list(range(self.n_actions))
        self.learn_step_counter = 0

        self.memory = ReplayBuffer(mem_size, input_dims, n_actions)

        self.q_eval = Network(self.lr, self.n_actions,
                              input_dims=self.input_dims,
                              name='lunar_lander_dueling_ddqn_q_eval',
                              chkpt_dir=self.chkpt_dir)

        self.q_next = Network(self.lr, self.n_actions,
                              input_dims=self.input_dims,
                              name='lunar_lander_dueling_ddqn_q_next',
                              chkpt_dir=self.chkpt_dir)

    def choose_action(self, observation):
        """Return the greedy action with probability 1 - epsilon, else a random one."""
        if np.random.random() > self.epsilon:
            # np.array first: building a tensor from a list of ndarrays is slow
            state = torch.tensor(np.array([observation]),
                                 dtype=torch.float).to(self.q_eval.device)
            # V(s) is shared by all actions, so argmax over A(s, .) is greedy
            _, advantage = self.q_eval.forward(state)
            action = torch.argmax(advantage).item()
        else:
            action = np.random.choice(self.action_space)

        return action

    def store_transition(self, state, action, reward, state_, done):
        """Forward one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def replace_target_network(self):
        """Copy q_eval into q_next every ``replace_target_cnt`` learning steps."""
        if self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def decrement_epsilon(self):
        """Linearly decay epsilon by eps_dec without going below eps_min."""
        # The old conditional form could overshoot below eps_min by up to eps_dec.
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def save_models(self):
        """Checkpoint both networks."""
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        """Restore both networks from their checkpoints."""
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def learn(self):
        """Do one double-DQN gradient step on a sampled mini-batch.

        The step is skipped until the buffer holds at least ``batch_size``
        transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()

        self.replace_target_network()

        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)

        states = torch.tensor(state).to(self.q_eval.device)
        rewards = torch.tensor(reward).to(self.q_eval.device)
        dones = torch.tensor(done).to(self.q_eval.device)
        actions = torch.tensor(action).to(self.q_eval.device)
        states_ = torch.tensor(new_state).to(self.q_eval.device)

        indices = np.arange(self.batch_size)

        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)) for the taken actions
        V_s, A_s = self.q_eval.forward(states)
        q_pred = torch.add(V_s,
                           (A_s - A_s.mean(dim=1, keepdim=True)))[indices, actions]

        # The target is a constant for the regression. Computing it without a
        # graph leaves q_eval's gradients unchanged, and it stops stray
        # gradients from accumulating in the target network.
        with torch.no_grad():
            V_s_, A_s_ = self.q_next.forward(states_)
            V_s_eval, A_s_eval = self.q_eval.forward(states_)
            q_next = torch.add(V_s_, (A_s_ - A_s_.mean(dim=1, keepdim=True)))
            q_eval = torch.add(V_s_eval,
                               (A_s_eval - A_s_eval.mean(dim=1, keepdim=True)))
            # double DQN: online net picks the action, target net evaluates it
            max_actions = torch.argmax(q_eval, dim=1)
            # NOTE(review): assumes the buffer returns a boolean done mask
            q_next[dones] = 0.0
            q_target = rewards + self.gamma * q_next[indices, max_actions]

        loss = self.q_eval.loss(q_target, q_pred).to(self.q_eval.device)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1

        self.decrement_epsilon()
Beispiel #5
0
class trpo_agent:
    """TRPO agent: natural-gradient actor updates plus an Adam-fitted critic.

    Each update collects ``args.nsteps`` transitions, computes GAE advantages
    and takes one constrained actor step (conjugate gradient plus line
    search). It then fits the critic and saves ``[state_dict, running_state]``
    to ``<save_dir><env_name>/model.pt``.
    """

    def __init__(self, env, args):
        """Build the networks, the critic optimizer and the model directory.

        Args:
            env: environment exposing ``observation_space.shape`` and
                ``action_space.shape`` and the 4-tuple ``step`` API used in
                ``learn``.
            args: hyper-parameter namespace; this class reads ``lr``,
                ``save_dir``, ``env_name``, ``total_timesteps``, ``nsteps``,
                ``gamma``, ``tau``, ``max_kl``, ``damping``, ``vf_itrs`` and
                ``batch_size``.
        """
        self.env = env
        self.args = args
        obs_dim = self.env.observation_space.shape[0]
        act_dim = self.env.action_space.shape[0]
        # current network and the copy used as pi_old
        self.net = Network(obs_dim, act_dim)
        self.old_net = Network(obs_dim, act_dim)
        # make sure the net and old net have the same parameters
        self.old_net.load_state_dict(self.net.state_dict())
        # Adam only trains the critic; the actor is moved by the TRPO step
        self.optimizer = torch.optim.Adam(self.net.critic.parameters(),
                                          lr=self.args.lr)
        # running mean observation filter, clipped at 5
        self.running_state = ZFilter((obs_dim, ), clip=5)
        # makedirs copes with nested paths and avoids the exists()/mkdir() race
        os.makedirs(self.args.save_dir, exist_ok=True)
        self.model_path = self.args.save_dir + self.args.env_name + '/'
        os.makedirs(self.model_path, exist_ok=True)

    def learn(self):
        """Run ``total_timesteps // nsteps`` TRPO updates, saving after each one."""
        num_updates = self.args.total_timesteps // self.args.nsteps
        obs = self.running_state(self.env.reset())
        final_reward = 0
        episode_reward = 0
        self.dones = False
        for update in range(num_updates):
            mb_obs, mb_rewards, mb_actions, mb_dones, mb_values = [], [], [], [], []
            for _ in range(self.args.nsteps):
                with torch.no_grad():
                    obs_tensor = self._get_tensors(obs)
                    value, pi = self.net(obs_tensor)
                actions = select_actions(pi)
                # the done flag stored here is the one *before* this step
                mb_obs.append(np.copy(obs))
                mb_actions.append(actions)
                mb_dones.append(self.dones)
                mb_values.append(value.detach().numpy().squeeze())
                # execute the action in the environment
                obs_, reward, done, _ = self.env.step(actions)
                self.dones = done
                mb_rewards.append(reward)
                if done:
                    obs_ = self.env.reset()
                obs = self.running_state(obs_)
                # final_reward holds the return of the last finished episode
                episode_reward += reward
                mask = 0.0 if done else 1.0
                final_reward *= mask
                final_reward += (1 - mask) * episode_reward
                episode_reward *= mask
            # process the rollouts
            mb_obs = np.asarray(mb_obs, dtype=np.float32)
            mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
            mb_actions = np.asarray(mb_actions, dtype=np.float32)
            # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
            mb_dones = np.asarray(mb_dones, dtype=bool)
            mb_values = np.asarray(mb_values, dtype=np.float32)
            # value of the observation following the rollout (bootstrap)
            with torch.no_grad():
                obs_tensor = self._get_tensors(obs)
                last_value, _ = self.net(obs_tensor)
                last_value = last_value.detach().numpy().squeeze()
            mb_advs, mb_returns = self._compute_gae(mb_rewards, mb_values,
                                                    mb_dones, last_value,
                                                    self.dones)
            # normalize the advantages
            mb_advs = (mb_advs - mb_advs.mean()) / (mb_advs.std() + 1e-5)
            # freeze the current policy as pi_old before updating
            self.old_net.load_state_dict(self.net.state_dict())
            policy_loss, value_loss = self._update_network(
                mb_obs, mb_actions, mb_returns, mb_advs)
            torch.save([self.net.state_dict(), self.running_state],
                       self.model_path + 'model.pt')
            print('[{}] Update: {} / {}, Frames: {}, Reward: {:.3f}, VL: {:.3f}, PL: {:.3f}'.format(datetime.now(), update, \
                    num_updates, (update + 1)*self.args.nsteps, final_reward, value_loss, policy_loss))

    def _compute_gae(self, mb_rewards, mb_values, mb_dones, last_value,
                     last_done):
        """Compute GAE advantages and the matching value targets.

        ``mb_dones[t]`` is the done flag recorded *before* step ``t``, so the
        bootstrap of step ``t`` is masked by ``mb_dones[t + 1]``. For the final
        step, ``last_done``/``last_value`` describe the observation after the
        rollout. ``self.args.tau`` is the GAE lambda.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        gamma = self.args.gamma
        tau = self.args.tau
        nsteps = len(mb_rewards)
        mb_advs = np.zeros_like(mb_rewards)
        lastgaelam = 0
        for t in reversed(range(nsteps)):
            if t == nsteps - 1:
                nextnonterminal = 1.0 - last_done
                nextvalues = last_value
            else:
                nextnonterminal = 1.0 - mb_dones[t + 1]
                nextvalues = mb_values[t + 1]
            delta = mb_rewards[t] + gamma * nextvalues * nextnonterminal - mb_values[t]
            mb_advs[t] = lastgaelam = delta + gamma * tau * nextnonterminal * lastgaelam
        mb_returns = mb_advs + mb_values
        return mb_advs, mb_returns

    def _update_network(self, mb_obs, mb_actions, mb_returns, mb_advs):
        """Apply one TRPO step to the actor, then fit the critic.

        Returns:
            ``(surrogate_loss, last_value_loss)`` as Python floats. The
            surrogate loss is measured before the actor step.
        """
        mb_obs_tensor = torch.tensor(mb_obs, dtype=torch.float32)
        mb_actions_tensor = torch.tensor(mb_actions, dtype=torch.float32)
        mb_advs_tensor = torch.tensor(mb_advs,
                                      dtype=torch.float32).unsqueeze(1)
        # pi_old stays fixed for the whole update, so no graph is needed
        with torch.no_grad():
            _, pi_old = self.old_net(mb_obs_tensor)
        surr_loss = self._get_surrogate_loss(mb_obs_tensor, mb_advs_tensor,
                                             mb_actions_tensor, pi_old)
        # g = d(surr)/d(actor params); solve F x = -g where F is the Fisher matrix
        surr_grad = torch.autograd.grad(surr_loss, self.net.actor.parameters())
        flat_surr_grad = torch.cat([grad.view(-1) for grad in surr_grad]).data
        # conjugate gradient (10 iterations) gives the natural-gradient direction
        nature_grad = conjugated_gradient(self._fisher_vector_product,
                                          -flat_surr_grad, 10, mb_obs_tensor,
                                          pi_old)
        # rescale the step so that 0.5 * s^T F s equals max_kl
        non_scale_kl = 0.5 * (nature_grad * self._fisher_vector_product(
            nature_grad, mb_obs_tensor, pi_old)).sum(0, keepdim=True)
        scale_ratio = torch.sqrt(non_scale_kl / self.args.max_kl)
        final_nature_grad = nature_grad / scale_ratio[0]
        # first-order predicted improvement of the scaled step
        expected_improve = (-flat_surr_grad * nature_grad).sum(
            0, keepdim=True) / scale_ratio[0]
        prev_params = torch.cat(
            [param.data.view(-1) for param in self.net.actor.parameters()])
        # NOTE(review): line_search's success flag is ignored; confirm it
        # returns prev_params when no acceptable step is found.
        _, new_params = line_search(self.net.actor, self._get_surrogate_loss,
                                    prev_params, final_nature_grad,
                                    expected_improve, mb_obs_tensor,
                                    mb_advs_tensor, mb_actions_tensor, pi_old)
        set_flat_params_to(self.net.actor, new_params)
        # critic: vf_itrs epochs of shuffled mini-batch regression onto returns
        inds = np.arange(mb_obs.shape[0])
        for _ in range(self.args.vf_itrs):
            np.random.shuffle(inds)
            for start in range(0, mb_obs.shape[0], self.args.batch_size):
                mbinds = inds[start:start + self.args.batch_size]
                mini_obs = torch.tensor(mb_obs[mbinds], dtype=torch.float32)
                mini_returns = torch.tensor(mb_returns[mbinds],
                                            dtype=torch.float32).unsqueeze(1)
                values, _ = self.net(mini_obs)
                v_loss = (mini_returns - values).pow(2).mean()
                self.optimizer.zero_grad()
                v_loss.backward()
                self.optimizer.step()
        return surr_loss.item(), v_loss.item()

    def _get_surrogate_loss(self, obs, adv, actions, pi_old):
        """Return the negated importance-weighted advantage (to be minimised)."""
        _, pi = self.net(obs)
        log_prob = eval_actions(pi, actions)
        old_log_prob = eval_actions(pi_old, actions).detach()
        surr_loss = -torch.exp(log_prob - old_log_prob) * adv
        return surr_loss.mean()

    def _fisher_vector_product(self, v, obs, pi_old):
        """Return ``(F + damping * I) v`` via double back-propagation of the mean KL."""
        kl = self._get_kl(obs, pi_old).mean()
        # first derivative kept in the graph so it can be differentiated again
        kl_grads = torch.autograd.grad(kl,
                                       self.net.actor.parameters(),
                                       create_graph=True)
        flat_kl_grads = torch.cat([grad.view(-1) for grad in kl_grads])
        # torch.autograd.Variable is a deprecated no-op wrapper around tensors
        kl_v = (flat_kl_grads * v).sum()
        kl_second_grads = torch.autograd.grad(kl_v,
                                              self.net.actor.parameters())
        flat_kl_second_grads = torch.cat(
            [grad.contiguous().view(-1) for grad in kl_second_grads]).data
        return flat_kl_second_grads + self.args.damping * v

    def _get_kl(self, obs, pi_old):
        """KL(pi_old || pi) between diagonal Gaussians, summed over dim 1 (keepdim)."""
        mean_old, std_old = pi_old
        _, pi = self.net(obs)
        mean, std = pi
        kl = -torch.log(std / std_old) + (
            std.pow(2) + (mean - mean_old).pow(2)) / (2 * std_old.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    def _get_tensors(self, obs):
        """Convert one observation into a float32 tensor with a leading batch dim."""
        return torch.tensor(obs, dtype=torch.float32).unsqueeze(0)