Example #1
    def optimize_epoch(self, num_epochs):
        if self.optimizer is None:
            raise ValueError('Learning rate is not set!')
        if self.data_loader is None:
            # lazily construct the data loader over self.memory
            self.data_loader = DataLoader(self.memory,
                                          self.batch_size,
                                          shuffle=True)
        average_value_loss = 0
        average_policy_loss = 0
        for epoch in range(num_epochs):
            value_loss = 0
            policy_loss = 0
            logging.debug('{}-th epoch starts'.format(epoch))
            for data in self.data_loader:
                inputs, values, _, actions = data
                self.optimizer.zero_grad()
                # # outputs_val, outputs_mu, outputs_cov = self.model(inputs)
                # action_log_probs = MultivariateNormal(outputs_mu, outputs_cov).log_prob(actions)
                outputs_val, alpha_beta_1, alpha_beta_2 = self.model(inputs)
                vx_dist = Beta(alpha_beta_1[:, 0], alpha_beta_1[:, 1])
                vy_dist = Beta(alpha_beta_2[:, 0], alpha_beta_2[:, 1])
                # squeeze the actions from [0, 1] into the open interval (0, 1)
                # so that the Beta log-density stays finite at the boundaries
                p = torch.Tensor([1 + 1e-6]).to(self.device)
                q = torch.Tensor([1e-8]).to(self.device)
                action_log_probs = (vx_dist.log_prob(actions[:, 0] / p + q)).unsqueeze(1) + \
                                   (vy_dist.log_prob(actions[:, 1] / p + q)).unsqueeze(1)

                values = values.to(self.device)
                dist_entropy = vx_dist.entropy().mean() + vy_dist.entropy().mean()

                loss1 = self.criterion_val(outputs_val, values)
                loss2 = -action_log_probs.mean()
                loss = loss1 + loss2 - dist_entropy * self.entropy_coef
                # loss = loss1 + loss2
                loss.backward()
                self.optimizer.step()
                value_loss += loss1.item()
                policy_loss += loss2.item()
            logging.debug('{}-th epoch ends'.format(epoch))
            average_value_loss = value_loss / len(self.memory)
            average_policy_loss = policy_loss / len(self.memory)
            self.writer.add_scalar('IL/average_value_loss', average_value_loss,
                                   epoch)
            self.writer.add_scalar('IL/average_policy_loss',
                                   average_policy_loss, epoch)
            logging.info('Average value, policy loss in epoch %d: %.2E, %.2E',
                         epoch, average_value_loss, average_policy_loss)

        return average_value_loss
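A note on the p and q constants above (my own explanation, not from the original project): the Beta log-density diverges at 0 and 1, so an action lying exactly on the boundary of the unit interval would yield an infinite loss, and the tiny rescale keeps every action strictly inside (0, 1). A standalone sketch with made-up values:

import torch
from torch.distributions import Beta

dist = Beta(torch.tensor(2.0), torch.tensor(2.0))
raw = torch.tensor([0.0, 0.5, 1.0])        # boundary actions are problematic
print(dist.log_prob(raw))                  # tensor([-inf, 0.4055, -inf])

squeezed = raw / (1 + 1e-6) + 1e-8         # same trick as p and q above
print(dist.log_prob(squeezed))             # finite everywhere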
Example #2
File: utils.py Project: DAIM-2020/DAIM
def evaluate_actions(pi, actions, dist_type, env_type):
    if env_type == 'atari':
        cate_dist = Categorical(pi)
        log_prob = cate_dist.log_prob(actions).unsqueeze(-1)
        entropy = cate_dist.entropy().mean()
    else:
        if dist_type == 'gauss':
            mean, std = pi
            normal_dist = Normal(mean, std)
            log_prob = normal_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = normal_dist.entropy().mean()
        elif dist_type == 'beta':
            alpha, beta = pi
            beta_dist = Beta(alpha, beta)
            log_prob = beta_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = beta_dist.entropy().mean()
    return log_prob, entropy
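The helper above handles Categorical, Gaussian, and Beta policy heads through one interface. A minimal usage sketch with dummy tensors (the shapes and the 'mujoco' tag are assumptions of mine, not part of the DAIM project):

import torch
# evaluate_actions comes from the snippet above, together with its
# Categorical / Normal / Beta imports from torch.distributions

alpha = torch.rand(4, 2) + 1.0    # concentration1 of a Beta policy head
beta = torch.rand(4, 2) + 1.0     # concentration0
actions = torch.rand(4, 2)        # Beta support is the unit interval

log_prob, entropy = evaluate_actions((alpha, beta), actions,
                                     dist_type='beta', env_type='mujoco')
print(log_prob.shape)             # torch.Size([4, 1])
print(entropy.shape)              # torch.Size([]) -- a scalar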
Example #3
    def train_on_batch(self, batch):
        """perform optimization step.

        Args:
          batch (tuple): tuple of batches of environment observations, calling programs, lstm's hidden and cell states

        Returns:
          policy loss, value loss, total loss combining policy and value losses
        """
        e_t = torch.FloatTensor(np.stack(batch[0]))
        i_t = batch[1]
        lstm_states = batch[2]
        h_t, c_t = zip(*lstm_states)
        h_t, c_t = torch.squeeze(torch.stack(list(h_t))), torch.squeeze(
            torch.stack(list(c_t)))

        policy_labels = torch.squeeze(torch.stack(batch[3]))
        value_labels = torch.stack(batch[4]).view(-1, 1)

        self.optimizer.zero_grad()
        policy_predictions, value_predictions, _, _ = self.predict_on_batch(
            e_t, i_t, h_t, c_t)

        # policy_loss = -torch.mean(policy_labels * torch.log(policy_predictions), dim=-1).mean()

        beta = Beta(policy_predictions[0], policy_predictions[1])
        policy_action = beta.sample()
        prob_action = beta.log_prob(policy_action)

        log_mcts = self.temperature * torch.log(policy_labels)
        with torch.no_grad():
            modified_kl = prob_action - log_mcts

        # reduce both terms to scalars so the combined loss can be backpropagated
        policy_loss = (-modified_kl * (torch.log(modified_kl) + prob_action)).mean()
        entropy_loss = self.entropy_lambda * beta.entropy().mean()

        policy_network_loss = policy_loss + entropy_loss
        value_network_loss = torch.pow(value_predictions - value_labels,
                                       2).mean()

        total_loss = (policy_network_loss + value_network_loss) / 2
        total_loss.backward()
        self.optimizer.step()

        return policy_network_loss, value_network_loss, total_loss
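The step above relies on Beta.log_prob being differentiable with respect to the concentration parameters even though the action itself comes from sample(). A tiny standalone check (parameter values are mine):

import torch
from torch.distributions import Beta

alpha = torch.tensor([2.0, 1.5], requires_grad=True)        # concentration1
beta_param = torch.tensor([3.0, 1.5], requires_grad=True)   # concentration0
dist = Beta(alpha, beta_param)

action = dist.sample()              # sampling carries no gradient
log_prob = dist.log_prob(action)    # but log_prob is differentiable in the parameters
log_prob.sum().backward()
print(alpha.grad, beta_param.grad)  # finite gradients for the policy update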
Example #4
    def optimize_batch(self, num_batches, episode=None):
        if self.optimizer is None:
            raise ValueError('Learning rate is not set!')
        if self.data_loader is None:
            self.data_loader = DataLoader(self.memory,
                                          self.batch_size,
                                          shuffle=True)
        value_losses = 0
        policy_losses = 0
        entropy = 0
        l2_losses = 0
        batch_count = 0
        for data in self.data_loader:
            inputs, values, rewards, actions, returns, old_action_log_probs, adv_targ = data
            self.optimizer.zero_grad()
            # outputs_vals, outputs_mu, outputs_cov = self.model(inputs)
            # dist = MultivariateNormal(outputs_mu, outputs_cov)
            # action_log_probs = dist.log_prob(actions)
            outputs_vals, alpha_beta_1, alpha_beta_2 = self.model(inputs)
            vx_dist = Beta(alpha_beta_1[:, 0], alpha_beta_1[:, 1])
            vy_dist = Beta(alpha_beta_2[:, 0], alpha_beta_2[:, 1])
            action_log_probs = vx_dist.log_prob(actions[:, 0]).unsqueeze(1) + \
                               vy_dist.log_prob(actions[:, 1]).unsqueeze(1)

            # TODO: check why entropy is negative
            dist_entropy = vx_dist.entropy().mean() + vy_dist.entropy().mean()

            ratio = torch.exp(action_log_probs - old_action_log_probs)
            assert ratio.shape[1] == 1
            surr1 = ratio * adv_targ
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                1.0 + self.clip_param) * adv_targ
            loss1 = -torch.min(surr1, surr2).mean()
            loss2 = self.criterion_val(outputs_vals,
                                       values) * 0.5 * self.value_loss_coef
            loss3 = -dist_entropy * self.entropy_coef

            # speed_square_diff = torch.sum(torch.pow(outputs_mu, 2), dim=1) - torch.Tensor([1]).to(self.device).double()
            # loss4 = torch.pow(torch.max(speed_square_diff, torch.Tensor([0]).to(self.device).double()), 2).mean() * 1

            loss = loss1 + loss2 + loss3
            loss.backward()
            self.optimizer.step()

            policy_losses += loss1.item()
            value_losses += loss2.item()
            entropy += float(dist_entropy.cpu())
            # l2_losses += loss4.data.item()
            batch_count += 1
            if batch_count >= num_batches:
                break

        average_value_loss = value_losses / num_batches
        average_policy_loss = policy_losses / num_batches
        average_entropy = entropy / num_batches
        average_l2_loss = l2_losses / num_batches
        logging.info('Average value, policy loss : %.2E, %.2E',
                     average_value_loss, average_policy_loss)
        self.writer.add_scalar('train/average_value_loss', average_value_loss,
                               episode)
        self.writer.add_scalar('train/average_policy_loss',
                               average_policy_loss, episode)
        self.writer.add_scalar('train/average_entropy', average_entropy,
                               episode)
        # self.writer.add_scalar('train/average_l2_loss', average_l2_loss, episode)

        return average_value_loss
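For reference, loss1 in the loop above is the standard PPO clipped surrogate. A standalone sketch with made-up numbers (illustrative only, independent of the project's tensors):

import torch

clip_param = 0.2
action_log_probs = torch.tensor([[-1.0], [-0.5], [-2.0]])       # new policy
old_action_log_probs = torch.tensor([[-1.1], [-0.4], [-1.0]])   # behavior policy
adv_targ = torch.tensor([[0.5], [-0.3], [1.2]])                 # advantage estimates

ratio = torch.exp(action_log_probs - old_action_log_probs)      # pi_new / pi_old
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
policy_loss = -torch.min(surr1, surr2).mean()                   # corresponds to loss1
print(policy_loss)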
Example #5
    def get_entropy(self, state):
        bsize = state.size(0)
        alpha, beta = self.forward(state)
        dist = Beta(concentration1=alpha, concentration0=beta)
        entropy = dist.entropy().view(bsize, 1)
        return entropy
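Beta.entropy() returns differential entropy, and on the unit interval the uniform Beta(1, 1) is the maximum-entropy case with entropy 0, so any more concentrated Beta has negative entropy; this is presumably what the TODO in Example #4 is observing. A minimal check (parameter values are mine):

import torch
from torch.distributions import Beta

dist = Beta(concentration1=torch.tensor([1.0, 0.5]),
            concentration0=torch.tensor([1.0, 0.5]))
print(dist.entropy())   # tensor([ 0.0000, -0.2416]); only the uniform case reaches 0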