Example #1
    def step(policy_net, optimizer_policy, states, actions, advantages, traj_num):
        """update policy, objective function is average trajectory return"""
        log_probs = policy_net.get_log_prob(Variable(states), Variable(actions))
        print("old log probs", log_probs)
        policy_loss = -(log_probs * Variable(advantages)).sum() / traj_num

        # flat_grad = torch_utils.compute_flat_grad(policy_loss, policy_net.parameters(), create_graph=True).detach().numpy()
        # logger.log("gradient:" + str(flat_grad))
        #
        # # check what would be the outcome if we just add gradient to the current parameters
        # prev_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
        # logger.log("old_parameters" + str(prev_params))
        # logger.log("new_parameters handcoded plus" + str(prev_params + flat_grad * 0.01))
        # logger.log("new_parameters handcoded minus" + str(prev_params - flat_grad * 0.01))

        prev_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
        logger.log("old_parameters" + str(prev_params))

        optimizer_policy.zero_grad()
        policy_loss.backward()
        logger.record_tabular("policy_loss before", policy_loss.item())
        # for param in policy_net.parameters():
        #     logger.log("parameter_grad:" + str(param.grad))

        optimizer_policy.step()

        new_params = torch_utils.get_flat_params_from(policy_net).detach().numpy()
        logger.log("old_parameters" + str(new_params))

        # calculate new loss
        log_probs = policy_net.get_log_prob(Variable(states), Variable(actions))
        print("new log probs",log_probs)
        policy_loss = -(log_probs * Variable(advantages)).sum() / traj_num
        logger.record_tabular("policy_loss after", policy_loss.item())
Example #2
    def get_kl_diff(old_param, new_param):
        """calculates the mean KL difference between two parameter settings"""
        prev_params = torch_utils.get_flat_params_from(policy_net)
        with torch.no_grad():
            torch_utils.set_flat_params_to(policy_net, old_param)
            log_old_prob = torch.clamp(policy_net.get_log_prob(
                Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
            torch_utils.set_flat_params_to(policy_net, new_param)
            log_new_prob = torch.clamp(policy_net.get_log_prob(
                Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
        torch_utils.set_flat_params_to(policy_net, prev_params)
        return torch.mean(torch.exp(log_old_prob) * (log_old_prob - log_new_prob)).numpy()
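get_kl_diff estimates the mean KL divergence between the old and the new policy from clamped sample log probabilities. If the policy is a diagonal Gaussian (an assumption; the policy class is not shown here), the same quantity can also be computed in closed form with torch.distributions, as in this minimal sketch:

import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu_old, std_old, mu_new, std_new):
    # mean KL(old || new) over a batch of diagonal-Gaussian action distributions
    kl = kl_divergence(Normal(mu_old, std_old), Normal(mu_new, std_new))
    return kl.sum(dim=-1).mean()

# usage with a dummy batch of 32 two-dimensional actions
mu_old, std_old = torch.zeros(32, 2), torch.ones(32, 2)
mu_new, std_new = 0.1 * torch.ones(32, 2), 1.1 * torch.ones(32, 2)
print(gaussian_kl(mu_old, std_old, mu_new, std_new))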
Example #3
    def _train_SGD(self):

        # TODO: we need to get here the right observations, actions and next_observations for the model
        # expert_observations, expert_actions, expert_next_observations = create_torch_var_from_paths(self.expert_data)
        # now train the imitation policy on the collected batch of expert_data via MLE on the log prob, since we have a Gaussian
        # TODO: do we train mean and variance? or only mean

        torch_input_batch, torch_output_batch = self.create_torch_var_from_paths(self.expert_data)

        # split data randomly into training and validation set, let's go with 70 - 30 split
        numTotalSamples = torch_input_batch.size(0)
        trainingSize = int(numTotalSamples*0.7)
        randomIndices = np.random.permutation(np.arange(numTotalSamples))
        trainingIndices = randomIndices[:trainingSize]
        validationIndices = randomIndices[trainingSize:]
        validation_input_batch = torch_input_batch[validationIndices]
        validation_output_batch = torch_output_batch[validationIndices]
        torch_input_batch = torch_input_batch[trainingIndices]
        torch_output_batch = torch_output_batch[trainingIndices]

        best_loss = np.inf
        losses = np.array([best_loss] * 25)
        with tqdm(total=self.n_itr, file=sys.stdout) as pbar:
            for epoch in range(self.n_itr+1):
                with logger.prefix('epoch #%d | ' % epoch):
                    # split into mini batches for training
                    total_batchsize = torch_input_batch.size(0)

                    logger.record_tabular('Iteration', epoch)
                    indices = np.random.permutation(np.arange(total_batchsize))
                    if isinstance(self.imitationModel, CartPoleModel):
                        logger.record_tabular("theta", str(self.imitationModel.theta.detach().numpy()))
                        logger.record_tabular("std", str(self.imitationModel.std.detach().numpy()))
                    # go through the whole batch
                    for k in range(int(total_batchsize/self.mini_batchsize)):
                        idx = indices[self.mini_batchsize*k:self.mini_batchsize*(k+1)]
                        # TODO: how about numerical stability?

                        log_prob = self.imitationModel.get_log_prob(torch_input_batch[idx, :], torch_output_batch[idx, :])

                        # note that L2 regularization is in weight decay of optimizer
                        loss = -torch.mean(log_prob) # negative since we want to minimize and not maximize
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                    # calculate the loss on the whole batch
                    log_prob = self.imitationModel.get_log_prob(validation_input_batch, validation_output_batch)
                    loss = -torch.mean(log_prob)
                    # Note: here we add L2 regularization to the loss to log the proper loss
                    # weight decay
                    for param in self.imitationModel.parameters():
                        loss += param.pow(2).sum() * self.l2_reg
                    logger.record_tabular("loss", loss.item())

                    # check if loss has decreased in the last 25 itr on the validation set, if not stop training
                    # and return the best found parameters
                    losses[1:] = losses[0:-1]
                    losses[0] = loss

                    if epoch == 0:
                        best_loss = np.min(losses)
                        best_flat_parameters = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
                        logger.record_tabular("current_best_loss", best_loss)
                    elif np.min(losses) <= best_loss and not (np.mean(losses) == best_loss):  # second condition guards against every entry in losses being identical
                        # set best loss to new one if smaller or keep it
                        best_loss = np.min(losses)
                        best_flat_parameters = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
                        logger.record_tabular("current_best_loss", best_loss)
                    else:
                        pbar.close()
                        print("best loss did not decrease in last 25 steps")
                        print("saving best result...")
                        logger.log("best loss did not decrease in last 25 steps")
                        torch_utils.set_flat_params_to(self.imitationModel, torch_utils.torch.from_numpy(best_flat_parameters))
                        logger.log("SGD converged")
                        logger.log("saving best result...")
                        params, torch_params = self.get_itr_snapshot(epoch)
                        if params is not None:
                            params["algo"] = self
                        logger.save_itr_params(self.n_itr, params, torch_params)
                        logger.log("saved")
                        break

                    pbar.set_description('epoch: %d' % (1 + epoch))
                    pbar.update(1)

                # save result
                logger.log("saving snapshot...")
                params, torch_params = self.get_itr_snapshot(epoch)
                if params is not None:
                    params["algo"] = self
                logger.save_itr_params(epoch, params, torch_params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
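_train_SGD assumes imitationModel.get_log_prob returns per-sample log likelihoods of a Gaussian model, so -torch.mean(log_prob) is the maximum-likelihood loss and the L2 term corresponds to the optimizer's weight decay. A minimal, self-contained sketch of such a model and one training step under that assumption (GaussianRegressor and the dimensions below are hypothetical, for illustration only):

import torch
import torch.nn as nn

class GaussianRegressor(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mean_net = nn.Linear(in_dim, out_dim)
        self.log_std = nn.Parameter(torch.zeros(out_dim))

    def get_log_prob(self, inputs, targets):
        # diagonal-Gaussian log likelihood, summed over output dimensions
        dist = torch.distributions.Normal(self.mean_net(inputs), self.log_std.exp())
        return dist.log_prob(targets).sum(dim=-1, keepdim=True)

model = GaussianRegressor(in_dim=5, out_dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
inputs, targets = torch.randn(64, 5), torch.randn(64, 1)
loss = -torch.mean(model.get_log_prob(inputs, targets))
optimizer.zero_grad()
loss.backward()
optimizer.step()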
Example #4
    def _train_BGFS(self):
        if not isinstance(self.imitationModel, CartPoleModel):
            raise NotImplementedError("_train_BGFS can only be called with a CartPoleModel")
        expert_observations = torch.from_numpy(self.expert_data["observations"]).float()
        expert_actions = torch.from_numpy(self.expert_data["actions"]).float()
        expert_obs_diff = torch.from_numpy(self.expert_data["env_infos"]["obs_diff"]).float()
        # now train the imitation policy on the collected batch of expert_data via MLE on the log prob, since we have a Gaussian
        # TODO: do we train mean and variance? or only mean

        if self.mode == "imitate_env":
            input = torch.cat([expert_observations, expert_actions], dim=1)
            output = expert_obs_diff
        else:
            raise ValueError("invalid mode")

        imitation_model = self.imitationModel
        total_batchsize = input.size(0)

        def get_negative_likelihood_loss(flat_params):
            torch_utils.set_flat_params_to(imitation_model, torch_utils.torch.from_numpy(flat_params))
            for param in imitation_model.parameters():
                if param.grad is not None:
                    param.grad.data.fill_(0)

            indices = np.random.permutation(np.arange(total_batchsize))

            loss = - torch.mean(imitation_model.get_log_prob(input[indices[:self.mini_batchsize]], output[indices[:self.mini_batchsize]]))

            # weight decay
            for param in imitation_model.parameters():
                loss += param.pow(2).sum() * self.l2_reg
            loss.backward()

            # FIX: removed [0] since mean already reduces the loss to a scalar (behavior of newer torch versions)
            return loss.detach().numpy(), \
                   torch_utils.get_flat_grad_from(
                       imitation_model.parameters()).detach().numpy(). \
                       astype(np.float64)

        curr_itr = 0

        def callback_fun(flat_params):
            nonlocal curr_itr
            torch_utils.set_flat_params_to(imitation_model, torch_utils.torch.from_numpy(flat_params))
            # calculate the loss of the whole batch
            loss = - torch.mean(imitation_model.get_log_prob(input, output))
            # weight decay
            for param in imitation_model.parameters():
                loss += param.pow(2).sum() * self.l2_reg
            loss.backward()
            if isinstance(self.imitationModel, CartPoleModel):
                logger.record_tabular("theta", str(self.imitationModel.theta.detach().numpy()))
                logger.record_tabular("std", str(self.imitationModel.std.detach().numpy()))
            logger.record_tabular('Iteration', curr_itr)
            logger.record_tabular("loss", loss.item())
            logger.dump_tabular(with_prefix=False)
            curr_itr += 1

        x0 = torch_utils.get_flat_params_from(self.imitationModel).detach().numpy()
        # only allow non-negative values since we know the masses and the variance cannot be negative
        bounds = [(0, np.inf) for _ in x0]

        flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(
            get_negative_likelihood_loss,
            x0, maxiter=self.n_itr, bounds=bounds, callback=callback_fun)
        logger.log(str(opt_info))
        torch_utils.set_flat_params_to(self.imitationModel, torch.from_numpy(flat_params))

        # save result
        logger.log("saving snapshot...")
        params, torch_params = self.get_itr_snapshot(0)
        params["algo"] = self
        logger.save_itr_params(self.n_itr, params, torch_params)
        logger.log("saved")
    def step(self, policy_net, value_net, states, actions, returns, advantages):

        """update critic"""
        values_target = Variable(returns)

        """calculates the mean kl difference between 2 parameter settings"""
        def get_kl_diff(old_param, new_param):
            prev_params = torch_utils.get_flat_params_from(policy_net)
            with torch.no_grad():
                torch_utils.set_flat_params_to(policy_net, old_param)
                log_old_prob = torch.clamp(policy_net.get_log_prob(
                    Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
                torch_utils.set_flat_params_to(policy_net, new_param)
                log_new_prob = torch.clamp(policy_net.get_log_prob(
                    Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6))
            torch_utils.set_flat_params_to(policy_net, prev_params)
            return torch.mean(torch.exp(log_old_prob) * (log_old_prob-log_new_prob)).numpy()

        def get_value_loss(flat_params):
            torch_utils.set_flat_params_to(value_net,
                                           torch_utils.torch.from_numpy(flat_params))
            for param in value_net.parameters():
                if param.grad is not None:
                    param.grad.data.fill_(0)
            values_pred = value_net(Variable(states))
            value_loss = (values_pred - values_target).pow(2).mean()

            # weight decay
            for param in value_net.parameters():
                value_loss += param.pow(2).sum() * self.l2_reg
            value_loss.backward()

            # FIX: removed [0] since mean already reduces the loss to a scalar (behavior of newer torch versions)
            return value_loss.data.cpu().numpy(), \
                   torch_utils.get_flat_grad_from(
                       value_net.parameters()).data.cpu().numpy(). \
                       astype(np.float64)

        flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(
            get_value_loss,
            torch_utils.get_flat_params_from(value_net).cpu().numpy(), maxiter=25)
        torch_utils.set_flat_params_to(value_net, torch.from_numpy(flat_params))

        """update policy"""
        fixed_log_probs = torch.clamp(policy_net.get_log_prob(
            Variable(states, volatile=True), Variable(actions)), min=np.log(1e-6)).data
        """define the loss function for TRPO"""

        def get_loss(volatile=False):
            # more numerically stable: clamp to a minimum value so that we don't get -inf
            log_probs = torch.clamp(policy_net.get_log_prob(
                Variable(states, volatile=volatile), Variable(actions)), min=np.log(1e-6))
            ent = policy_net.get_entropy(Variable(states, volatile=volatile), Variable(actions)).mean()
            action_loss = -Variable(advantages) * torch.exp(
                log_probs - Variable(fixed_log_probs))
            # logger.log("advantage"+str(advantages))
            # logger.log("log_probs"+str(log_probs))
            # logger.log("mean"+str(torch.mean(torch.exp(
            #     log_probs - Variable(fixed_log_probs)))))
            # logger.log("action_loss_no_mean"+str(-action_loss))
            return action_loss.mean() - self.entropy_coeff * ent

        """use fisher information matrix for Hessian*vector"""

        def Fvp_fim(v):
            M, mu, info = policy_net.get_fim(Variable(states))
            mu = mu.view(-1)
            filter_input_ids = set() if policy_net.is_disc_action else \
                {info['std_id']}

            t = M.new(mu.size())
            t[:] = 1
            t = Variable(t, requires_grad=True)
            mu_t = (mu * t).sum()
            Jt = torch_utils.compute_flat_grad(mu_t, policy_net.parameters(),
                                               filter_input_ids=filter_input_ids,
                                               create_graph=True)
            Jtv = (Jt * Variable(v)).sum()
            Jv = torch.autograd.grad(Jtv, t, retain_graph=True)[0]
            MJv = Variable(M * Jv.data)
            mu_MJv = (MJv * mu).sum()
            JTMJv = torch_utils.compute_flat_grad(mu_MJv, policy_net.parameters(),
                                                  filter_input_ids=filter_input_ids,
                                                  retain_graph=True).data
            JTMJv /= states.shape[0]
            if not policy_net.is_disc_action:
                std_index = info['std_index']
                JTMJv[std_index: std_index + M.shape[0]] += \
                    2 * v[std_index: std_index + M.shape[0]]
            return JTMJv + v * self.damping

        """directly compute Hessian*vector from KL"""

        def Fvp_direct(v):
            kl = policy_net.get_kl(Variable(states))
            kl = kl.mean()

            grads = torch.autograd.grad(kl, policy_net.parameters(),
                                        create_graph=True)
            flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

            kl_v = (flat_grad_kl * Variable(v)).sum()
            grads = torch.autograd.grad(kl_v, policy_net.parameters())
            flat_grad_grad_kl = torch.cat(
                [grad.contiguous().view(-1) for grad in grads]).data

            return flat_grad_grad_kl + v * self.damping

        Fvp = Fvp_fim if self.use_fim else Fvp_direct

        loss = get_loss()
        grads = torch.autograd.grad(loss, policy_net.parameters())
        loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
        stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

        shs = (stepdir.dot(Fvp(stepdir)))
        lm = np.sqrt(2 * self.max_kl / (shs + 1e-8))
        if np.isnan(lm):
            lm = 1.
        fullstep = stepdir * lm
        expected_improve = -loss_grad.dot(fullstep)

        prev_params = torch_utils.get_flat_params_from(policy_net)
        success, new_params = \
            line_search(policy_net, get_loss, prev_params, fullstep, expected_improve, get_kl_diff, self.max_kl)
        logger.record_tabular('TRPO_linesearch_success', int(success))
        logger.record_tabular("KL_diff", get_kl_diff(prev_params,new_params))
        torch_utils.set_flat_params_to(policy_net, new_params)
        logger.log("old_parameters" + str(prev_params.detach().numpy()))
        logger.log("new_parameters" + str(new_params.detach().numpy()))
        return success
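The TRPO update above assumes conjugate_gradients(Fvp, b, n) solves F x = b using only the Hessian-vector product Fvp, and that line_search backtracks along the resulting direction until the surrogate loss improves within the KL constraint. A minimal sketch of a standard conjugate gradient routine under that assumption (the repo's own conjugate_gradients and line_search utilities are not shown in these examples):

import torch

def conjugate_gradients(Avp, b, n_iters, residual_tol=1e-10):
    # iteratively solve A x = b where A is only available through Avp(v) = A v
    x = torch.zeros_like(b)
    r = b.clone()   # residual b - A x, with x = 0 initially
    p = b.clone()   # current search direction
    rdotr = r.dot(r)
    for _ in range(n_iters):
        Ap = Avp(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x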