def optimize_epoch(self, num_epochs):
    if self.optimizer is None:
        raise ValueError('Learning rate is not set!')
    if self.data_loader is None:
        self.data_loader = DataLoader(self.memory, self.batch_size, shuffle=True)
    average_value_loss = 0
    average_policy_loss = 0
    for epoch in range(num_epochs):
        value_loss = 0
        policy_loss = 0
        logging.debug('{}-th epoch starts'.format(epoch))
        for data in self.data_loader:
            inputs, values, _, actions = data
            self.optimizer.zero_grad()

            # Alternative: Gaussian policy head
            # outputs_val, outputs_mu, outputs_cov = self.model(inputs)
            # action_log_probs = MultivariateNormal(outputs_mu, outputs_cov).log_prob(actions)

            # Beta policy head: one (alpha, beta) pair per action dimension
            outputs_val, alpha_beta_1, alpha_beta_2 = self.model(inputs)
            vx_dist = Beta(alpha_beta_1[:, 0], alpha_beta_1[:, 1])
            vy_dist = Beta(alpha_beta_2[:, 0], alpha_beta_2[:, 1])

            # Rescale actions into the open interval (0, 1) so the Beta
            # log-density stays finite at the boundaries of its support.
            p = torch.Tensor([1 + 1e-6]).to(self.device)
            q = torch.Tensor([1e-8]).to(self.device)
            action_log_probs = (vx_dist.log_prob(actions[:, 0] / p + q)).unsqueeze(1) + \
                               (vy_dist.log_prob(actions[:, 1] / p + q)).unsqueeze(1)

            values = values.to(self.device)
            dist_entropy = vx_dist.entropy().mean() + vy_dist.entropy().mean()

            loss1 = self.criterion_val(outputs_val, values)  # value regression loss
            loss2 = -action_log_probs.mean()                 # behavior-cloning negative log-likelihood
            loss = loss1 + loss2 - dist_entropy * self.entropy_coef
            # loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()

            value_loss += loss1.data.item()
            policy_loss += loss2.data.item()

        logging.debug('{}-th epoch ends'.format(epoch))
        average_value_loss = value_loss / len(self.memory)
        average_policy_loss = policy_loss / len(self.memory)
        self.writer.add_scalar('IL/average_value_loss', average_value_loss, epoch)
        self.writer.add_scalar('IL/average_policy_loss', average_policy_loss, epoch)
        logging.info('Average value, policy loss in epoch %d: %.2E, %.2E',
                     epoch, average_value_loss, average_policy_loss)

    return average_value_loss
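# Standalone illustration (not part of the class above) of why the actions are
# rescaled with p = 1 + 1e-6 and q = 1e-8 before calling Beta.log_prob: the Beta
# density lives on (0, 1), so an action of exactly 0.0 or 1.0 yields an infinite
# log-probability. The concentration parameters below are arbitrary, chosen only
# for this sketch.
import torch
from torch.distributions import Beta

dist = Beta(torch.tensor([2.0]), torch.tensor([3.0]))
a = torch.tensor([1.0])             # action on the boundary of the support
print(dist.log_prob(a))             # -inf (or an error, depending on validation settings)
p, q = 1 + 1e-6, 1e-8
print(dist.log_prob(a / p + q))     # finite log-probability just inside (0, 1)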
def evaluate_actions(pi, actions, dist_type, env_type):
    if env_type == 'atari':
        # discrete action space
        cate_dist = Categorical(pi)
        log_prob = cate_dist.log_prob(actions).unsqueeze(-1)
        entropy = cate_dist.entropy().mean()
    else:
        if dist_type == 'gauss':
            mean, std = pi
            normal_dist = Normal(mean, std)
            log_prob = normal_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = normal_dist.entropy().mean()
        elif dist_type == 'beta':
            alpha, beta = pi
            beta_dist = Beta(alpha, beta)
            log_prob = beta_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = beta_dist.entropy().mean()
        else:
            raise ValueError('Unknown dist_type: {}'.format(dist_type))
    return log_prob, entropy
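# Hypothetical usage sketch of evaluate_actions above, assuming a batch of 4 and a
# 2-dimensional continuous action space; the tensors and the env_type string are
# made up for illustration.
import torch
from torch.distributions import Beta, Categorical, Normal

alpha = torch.rand(4, 2) + 1.0    # concentration1 per action dimension
beta = torch.rand(4, 2) + 1.0     # concentration0 per action dimension
actions = torch.rand(4, 2)        # Beta actions live in the unit interval
log_prob, entropy = evaluate_actions((alpha, beta), actions, dist_type='beta', env_type='mujoco')
print(log_prob.shape)             # torch.Size([4, 1]) -- summed over action dimensions
print(entropy)                    # scalar: mean entropy over the batch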
def train_on_batch(self, batch):
    """Perform one optimization step.

    Args:
        batch (tuple): tuple of batches of environment observations,
            calling programs, LSTM hidden and cell states, policy labels,
            and value labels.

    Returns:
        Policy loss, value loss, and the total loss combining policy
        and value losses.
    """
    e_t = torch.FloatTensor(np.stack(batch[0]))
    i_t = batch[1]
    lstm_states = batch[2]
    h_t, c_t = zip(*lstm_states)
    h_t, c_t = torch.squeeze(torch.stack(list(h_t))), torch.squeeze(torch.stack(list(c_t)))
    policy_labels = torch.squeeze(torch.stack(batch[3]))
    value_labels = torch.stack(batch[4]).view(-1, 1)

    self.optimizer.zero_grad()
    policy_predictions, value_predictions, _, _ = self.predict_on_batch(e_t, i_t, h_t, c_t)

    # Cross-entropy alternative:
    # policy_loss = -torch.mean(policy_labels * torch.log(policy_predictions), dim=-1).mean()

    # Beta policy head: sample an action and compare its log-probability
    # against the temperature-scaled MCTS visit distribution.
    beta = Beta(policy_predictions[0], policy_predictions[1])
    policy_action = beta.sample()
    prob_action = beta.log_prob(policy_action)
    log_mcts = self.temperature * torch.log(policy_labels)
    with torch.no_grad():
        modified_kl = prob_action - log_mcts
    # reduce both terms to scalars so backward() works on a batched loss
    policy_loss = (-modified_kl * (torch.log(modified_kl) + prob_action)).mean()
    entropy_loss = self.entropy_lambda * beta.entropy().mean()
    policy_network_loss = policy_loss + entropy_loss
    value_network_loss = torch.pow(value_predictions - value_labels, 2).mean()
    total_loss = (policy_network_loss + value_network_loss) / 2
    total_loss.backward()
    self.optimizer.step()
    return policy_network_loss, value_network_loss, total_loss
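# Shape sketch for the LSTM-state handling in train_on_batch above, under the
# assumption that each stored entry holds (h, c) tensors of shape (1, hidden_dim);
# the dimensions are illustrative, not taken from the original code.
import torch

hidden_dim, batch_size = 64, 8
lstm_states = [(torch.zeros(1, hidden_dim), torch.zeros(1, hidden_dim)) for _ in range(batch_size)]
h_t, c_t = zip(*lstm_states)                  # two tuples of length batch_size
h_t = torch.squeeze(torch.stack(list(h_t)))   # (batch, 1, hidden_dim) -> (batch, hidden_dim)
c_t = torch.squeeze(torch.stack(list(c_t)))
print(h_t.shape, c_t.shape)                   # torch.Size([8, 64]) twice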
def optimize_batch(self, num_batches, episode=None):
    if self.optimizer is None:
        raise ValueError('Learning rate is not set!')
    if self.data_loader is None:
        self.data_loader = DataLoader(self.memory, self.batch_size, shuffle=True)
    value_losses = 0
    policy_losses = 0
    entropy = 0
    l2_losses = 0
    batch_count = 0
    for data in self.data_loader:
        inputs, values, rewards, actions, returns, old_action_log_probs, adv_targ = data
        self.optimizer.zero_grad()

        # Alternative: Gaussian policy head
        # outputs_vals, outputs_mu, outputs_cov = self.model(inputs)
        # dist = MultivariateNormal(outputs_mu, outputs_cov)
        # action_log_probs = dist.log_prob(actions)

        # Beta policy head: one (alpha, beta) pair per action dimension
        outputs_vals, alpha_beta_1, alpha_beta_2 = self.model(inputs)
        vx_dist = Beta(alpha_beta_1[:, 0], alpha_beta_1[:, 1])
        vy_dist = Beta(alpha_beta_2[:, 0], alpha_beta_2[:, 1])
        action_log_probs = vx_dist.log_prob(actions[:, 0]).unsqueeze(1) + \
            vy_dist.log_prob(actions[:, 1]).unsqueeze(1)
        # TODO: check why entropy is negative
        # (the differential entropy of a continuous Beta can legitimately be
        # negative when the density is sharply peaked)
        dist_entropy = vx_dist.entropy().mean() + vy_dist.entropy().mean()

        # PPO clipped surrogate objective
        ratio = torch.exp(action_log_probs - old_action_log_probs)
        assert ratio.shape[1] == 1
        surr1 = ratio * adv_targ
        surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
        loss1 = -torch.min(surr1, surr2).mean()
        loss2 = self.criterion_val(outputs_vals, values) * 0.5 * self.value_loss_coef
        loss3 = -dist_entropy * self.entropy_coef
        # speed_square_diff = torch.sum(torch.pow(outputs_mu, 2), dim=1) - torch.Tensor([1]).to(self.device).double()
        # loss4 = torch.pow(torch.max(speed_square_diff, torch.Tensor([0]).to(self.device).double()), 2).mean() * 1
        loss = loss1 + loss2 + loss3
        loss.backward()
        self.optimizer.step()

        policy_losses += loss1.data.item()
        value_losses += loss2.data.item()
        entropy += float(dist_entropy.cpu())
        # l2_losses += loss4.data.item()

        batch_count += 1
        if batch_count >= num_batches:  # stop after exactly num_batches mini-batches
            break

    average_value_loss = value_losses / num_batches
    average_policy_loss = policy_losses / num_batches
    average_entropy = entropy / num_batches
    average_l2_loss = l2_losses / num_batches
    logging.info('Average value, policy loss : %.2E, %.2E', average_value_loss, average_policy_loss)
    self.writer.add_scalar('train/average_value_loss', average_value_loss, episode)
    self.writer.add_scalar('train/average_policy_loss', average_policy_loss, episode)
    self.writer.add_scalar('train/average_entropy', average_entropy, episode)
    # self.writer.add_scalar('train/average_l2_loss', average_l2_loss, episode)
    return average_value_loss
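# Minimal numeric illustration of the PPO clipped surrogate used in optimize_batch
# above; the probability ratios and advantages are made up, and clip_param = 0.2 is
# an assumed value.
import torch

clip_param = 0.2
ratio = torch.tensor([[0.5], [1.0], [1.6]])      # exp(new_log_prob - old_log_prob)
adv_targ = torch.tensor([[1.0], [1.0], [1.0]])   # advantage estimates
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss)                                # tensor(-0.9000): the 1.6 ratio is clipped to 1.2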
def get_entropy(self, state):
    bsize = state.size(0)
    alpha, beta = self.forward(state)
    dist = Beta(concentration1=alpha, concentration0=beta)
    entropy = dist.entropy().view(bsize, 1)
    return entropy
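# Hypothetical usage sketch of get_entropy above, assuming self.forward returns one
# (alpha, beta) pair per sample (a 1-D action space); the small network below is
# made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta

class BetaActor(nn.Module):
    def __init__(self, obs_dim):
        super().__init__()
        self.head = nn.Linear(obs_dim, 2)

    def forward(self, state):
        # softplus + 1 keeps both concentration parameters above 1
        params = F.softplus(self.head(state)) + 1.0
        return params[:, 0], params[:, 1]

    def get_entropy(self, state):
        bsize = state.size(0)
        alpha, beta = self.forward(state)
        dist = Beta(concentration1=alpha, concentration0=beta)
        return dist.entropy().view(bsize, 1)

entropy = BetaActor(obs_dim=4).get_entropy(torch.randn(8, 4))
print(entropy.shape)    # torch.Size([8, 1]) -- one entropy value per sample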