def focal_loss2d(self, x, y):
    '''Experimental sigmoid focal loss on a 2-class label map (author: "not works yet").

    Args:
      x: (tensor) logits; the shape this expects is not established by the
         code below — TODO confirm against the caller.
      y: (tensor) class indices in {0, 1} (one-hot width is 2); flattened to
         (N, -1) below.

    Return:
      (tensor) focal loss reduced with the default reduction of
      F.binary_cross_entropy_with_logits (mean), since none is passed.

    NOTE(review): after ``y.view(y.size(0), -1)`` the one-hot target is
    presumably ``(N, H*W, 2)``. In that case ``t[:, 1:]`` drops the first
    *spatial* position instead of the background channel, which is a likely
    reason this version does not work. Verify before reusing.
    '''
    alpha = 0.25
    gamma = 2
    y = y.view(y.size(0), -1)
    if self.using_gpu is True:
        t = one_hot_embedding(y.data.cpu(), 2)
    else:
        t = one_hot_embedding(y.data, 2)
    t = t[:, 1:]  # meant to exclude background; see NOTE(review) about the sliced axis
    if self.using_gpu is True:
        t = Variable(t).cuda()
    else:
        t = Variable(t)
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)  # focal modulation (1 - pt)^gamma
    return F.binary_cross_entropy_with_logits(x, t, w)
def focal_loss2d(self, x, y):
    '''Sigmoid focal loss for dense per-pixel predictions, averaged over elements.

    Args:
      x: (tensor) logits, sized [N,C,H,W]. C was hard-coded to 2 and is now
         read from ``x`` (2 keeps the original behaviour).
      y: (tensor) class indices in [0, C); presumably [N,H,W]. Only the
         element count N*H*W is required.

    Return:
      (tensor) scalar focal loss: the mean over all N*H*W*C terms (the default
      reduction of F.binary_cross_entropy_with_logits).
    '''
    alpha = 0.25
    gamma = 2
    num_channels = x.size(1)
    y = y.reshape(-1)
    # [N,C,H,W] -> [N*H*W, C] so each row lines up with one flattened label.
    x = x.permute(0, 2, 3, 1).reshape(-1, num_channels)
    # Build the one-hot on the CPU, then follow x's device/dtype. This replaces
    # the duplicated `self.using_gpu is True` branches. Those used an identity
    # test, so a truthy non-bool flag (e.g. 1) left `t` on the CPU next to CUDA
    # logits and crashed.
    t = one_hot_embedding(y.detach().cpu(), num_channels).to(x)
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    return F.binary_cross_entropy_with_logits(x, t, w)
def focal_loss(self, x, y):
    '''Sigmoid focal loss over foreground classes.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all N*D elements.
    '''
    alpha = 0.25
    gamma = 2
    # One-hot over background + classes, built on the CPU. Dropping column 0
    # makes background rows all-zero targets.
    t = one_hot_embedding(y.detach().cpu(), 1 + self.num_classes)
    # Follow x's device/dtype. This replaces the `self.using_gpu is True`
    # branches, whose identity test left `t` on the CPU for truthy non-bool
    # flags (e.g. 1) while the logits were on CUDA.
    t = t[:, 1:].to(x)
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss_sigmoid(self, x, y):
    """ Sigmoid version of focal loss.

    This is described in the original paper. With BCELoss, the background
    should not be counted in num_classes.

    Args:
      x: (tensor) predictions (logits), sized [N,D].
      y: (tensor) targets, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    """
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss_c(self, pred, targ):
    '''Focal loss computed explicitly (no weight argument to BCE).

    Args:
      pred: (tensor) logits, sized [#anchors, #num_classes]. Moved to CUDA if
        it is not already there.
      targ: (tensor) class indices, sized [#anchors, ]; 0 is background.

    Return:
      loss = sum(-1 * alpha_t * (1 - p_t)**gamma * log(p_t))
    '''
    import torch.nn.functional as F

    alpha = 0.25
    gamma = 2
    if not pred.is_cuda:
        pred = pred.cuda()
    targ = one_hot_embedding(targ.data.cpu(), 1 + self.num_classes)
    # Exclude background and follow pred's (CUDA) device/dtype
    # -> [#anchors, #num_classes].
    targ = targ[:, 1:].to(pred)
    p = pred.sigmoid()
    # p_t = p if targ == 1, p_t = 1 - p if targ == 0
    p_t = p * targ + (1 - p) * (1 - targ)
    alpha_t = alpha * targ + (1 - alpha) * (1 - targ)
    modulator = (1 - p_t).pow(gamma) * alpha_t
    # -log(p_t) computed from logits. For 0/1 targets this equals
    # -p_t.log(), but it never evaluates log(0) (-inf, NaN gradients)
    # once the sigmoid saturates.
    neg_log_p_t = F.binary_cross_entropy_with_logits(pred, targ, reduction='none')
    return torch.sum(modulator * neg_log_p_t)
def focal_loss(self, pred, targ):
    '''Sigmoid focal loss, summed over all anchors and classes.

    Args:
      pred: (tensor) logits, sized [#anchors, #num_classes]
      targ: (tensor) class indices, sized [#anchors, ]; 0 is background.

    Return:
      loss = sum(-1 * alpha_t * (1 - p_t)**gamma * log(p_t))
    '''
    alpha = 0.25
    gamma = 2
    targ = one_hot_embedding(targ.data.cpu(), 1 + self.num_classes)
    # Exclude background and follow pred's device/dtype
    # -> [#anchors, #num_classes].
    targ = targ[:, 1:].to(pred)
    p = pred.sigmoid()
    # p_t = p where targ == 1, 1 - p where targ == 0
    p_t = p * targ + (1 - p) * (1 - targ)
    # coeff = alpha where targ == 1, 1 - alpha where targ == 0
    coeff = alpha * targ + (1 - alpha) * (1 - targ)
    coeff = coeff * (1 - p_t).pow(gamma)
    # BCE-with-logits applies the sigmoid to `pred` only; `targ` is used as-is.
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(pred, targ, coeff, reduction='sum')
def focal_loss_alt(self, x, y):
    '''Focal loss alternative: sum of -w * log(sigmoid(2 * x_t + 1)) / 2.

    This is the FL* variant (gamma=2, beta=1) from Appendix A of "Focal Loss
    for Dense Object Detection", weighted here by alpha_t.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    import torch.nn.functional as F

    alpha = 0.25
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
    w = alpha * t + (1 - alpha) * (1 - t)
    # logsigmoid(z) == log(sigmoid(z)), but it stays finite for very negative z.
    # There sigmoid(z) rounds to 0 and .log() gave -inf (NaN gradients).
    loss = -w * F.logsigmoid(2 * xt + 1) / 2
    return loss.sum()
def focal_loss_ku(self, x, y):
    '''Sigmoid focal loss whose modulating weight is detached from the graph.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background -> [N, num_classes]
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    # Gradients flow only through the BCE term (w is detached). Only
    # reduction='sum' is passed: the extra size_average=False was redundant
    # and deprecated.
    return F.binary_cross_entropy_with_logits(x, t, w.detach(), reduction='sum')
def forward(self, x, y, classes=20):
    """Sigmoid focal loss with a detached modulating weight, summed over elements.

    Args:
      x: (tensor) logits, sized [N, classes].
      y: (tensor) class indices, sized [N,]; 0 is background.
      classes: number of foreground classes (background excluded).

    Returns:
      (tensor) scalar loss.
    """
    # One column per foreground class plus background. Slicing off column 0
    # leaves background rows all-zero.
    onehot = one_hot_embedding(y, 1 + classes)
    target = onehot[:, 1:]
    prob = x.sigmoid()
    is_pos = target > 0
    p_t = torch.where(is_pos, prob, 1 - prob)
    modulator = (1 - p_t).pow(self.gamma)
    pos_weight = self.alpha * modulator
    neg_weight = (1 - self.alpha) * modulator
    weight = torch.where(is_pos, pos_weight, neg_weight).detach()
    return F.binary_cross_entropy_with_logits(x, target, weight, reduction='sum')
def forward(self, inputs, targets):
    '''Sigmoid focal loss, summed over all elements.

    Args:
      inputs: (tensor) logits, sized [N, self.num_classes].
      targets: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) scalar focal loss.
    '''
    class_onehot = one_hot_embedding(targets.data.cpu().long(), 1 + self.num_classes)  # [N, 1 + num_classes]
    # Drop the background column and match the logits' device/dtype. This
    # replaces the legacy torch.Tensor(...) constructor and the hard-coded .cuda().
    class_onehot = class_onehot[:, 1:].to(inputs)
    prob = self.sigmoid(inputs)
    pt = prob * class_onehot + (1 - prob) * (1 - class_onehot)
    w = self.alpha * class_onehot + (1 - self.alpha) * (1 - class_onehot)
    w = w * (1 - pt).pow(self.gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(inputs, class_onehot, w, reduction='sum')
def focal_loss(self, x, y, gamma=2):
    """Softmax focal loss, summed over samples.

    Args:
      x: (tensor) logits, sized [N,D].
      y: (tensor) int64 class indices, sized [N,].
      gamma: focusing parameter; gamma=0 reduces to summed cross-entropy.

    Return:
      (tensor) scalar focal loss: sum((1 - p_y)**gamma * -log(p_y)).
    """
    # Per-sample -log p_y. reduction='none' replaces the deprecated reduce=False.
    loss = F.cross_entropy(x, y, reduction='none')  # [N,]
    # p_y == exp(-CE). This avoids a second softmax and a one-hot target
    # built on the CPU (plus the host round-trip). The value and gradient
    # are the same.
    pt = (-loss).exp()  # [N,]
    return ((1 - pt).pow(gamma) * loss).sum()
def focal_loss(self, x, y):
    '''Sigmoid focal loss over foreground classes.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    alpha = 0.25
    gamma = 2
    t = utils.one_hot_embedding(y.data.cpu(), 1+self.num_classes)
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p*t + (1-p)*(1-t)  # pt = p if t > 0 else 1-p
    w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1-pt).pow(gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss(self, x, y):
    """Sigmoid focal loss with a detached weight, summed over all elements.

    Args:
      x: (tensor) logits, sized [N, self.num_classes].
      y: (tensor) class indices, sized [N,]; 0 is background.

    Returns:
      (tensor) scalar loss.
    """
    alpha, gamma = 0.25, 2
    # Drop the background column of the (1 + num_classes)-wide one-hot.
    target = one_hot_embedding(y, 1 + self.num_classes)[:, 1:]
    neg = 1 - target
    prob = x.sigmoid()
    p_t = prob * target + (1 - prob) * neg  # p if target else 1 - p
    alpha_t = alpha * target + (1 - alpha) * neg  # alpha if target else 1 - alpha
    weight = (alpha_t * (1 - p_t).pow(gamma)).detach()
    return F.binary_cross_entropy_with_logits(
        input=x, target=target, weight=weight, reduction='sum')
def focal_loss_alt(self, x, y):
    # https://github.com/kuangliu/pytorch-retinanet/issues/52 @miramind
    """Softmax focal loss over foreground classes.

    Only the positive one-hot entries contribute, so background rows
    (all-zero targets) add nothing.

    Args:
      x (tensor): logits, size [N, D] with D == self.num_classes
      y (tensor): class indices, size [N, ]; 0 is background

    Returns:
      (tensor): focal loss, summed over all elements
    """
    t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background -> [N, num_classes]
    t = t.to(x)  # float target on x's device instead of .cuda() + .float()
    # dim=1 is now explicit: implicit-dim softmax is deprecated, and for 2-D
    # input it resolved to dim=1 anyway. Clamping keeps log() finite.
    prob = F.softmax(x, dim=1).clamp(1e-7, 1.-1e-7)
    conf_loss = -1 * t * torch.log(prob)
    conf_loss = self.balance_alpha * conf_loss * (1-prob)**self.gamma
    return conf_loss.sum()
def focal_loss_alt(self, x, y):
    '''Focal loss alternative: sum of -w * log(sigmoid(2 * x_t + 1)) / 2.

    This is the FL* variant (gamma=2, beta=1) from Appendix A of "Focal Loss
    for Dense Object Detection", weighted here by alpha_t.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    import torch.nn.functional as F

    alpha = 0.25
    t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    xt = x*(2*t-1)  # xt = x if t > 0 else -x
    w = alpha*t + (1-alpha)*(1-t)
    # logsigmoid avoids log(0) = -inf (NaN gradients) when sigmoid saturates.
    loss = -w*F.logsigmoid(2*xt+1) / 2
    return loss.sum()
def focal_loss_new(self, x, y):
    '''Softmax focal loss over foreground classes, summed over rows.

    Fixes two crashes in the previous version:
      - `alpha.type()` was called on a Python float.
      - `gather` was given a float one-hot matrix (it needs int64 indices).
    The target-class log-probability is now selected with the one-hot mask.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) scalar focal loss. Background rows (all-zero one-hot) select
      log p = 0, so p_t = 1 and they contribute zero.
    '''
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background -> [N, num_classes]
    t = t.to(x)
    logpt = F.log_softmax(x, dim=1)
    logpt = (logpt * t).sum(1)  # log p of the target class, [N,]
    # p_t is treated as a constant in the modulating factor, as in the original.
    pt = logpt.detach().exp()
    loss = -alpha * (1 - pt) ** gamma * logpt
    return loss.sum()
def focal_loss_alt(self, x, y):
    '''Focal loss alternative: sum of -w * log(sigmoid(2 * x_t + 1)) / 2.

    This is the FL* variant (gamma=2, beta=1) from Appendix A of "Focal Loss
    for Dense Object Detection", weighted here by alpha_t.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    import torch.nn.functional as F

    alpha = 0.25
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
    t = t[:, 1:]  # exclude background
    # Follow x's device/dtype rather than relying on a module-level `device`.
    t = t.to(x)
    xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
    w = alpha * t + (1 - alpha) * (1 - t)
    # logsigmoid avoids log(0) = -inf (NaN gradients) when sigmoid saturates.
    loss = -w * F.logsigmoid(2 * xt + 1) / 2
    return loss.sum()
def focal_loss(self, x, y):
    '''Sigmoid focal loss over foreground classes.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background -> [N, num_classes]
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p*t + (1-p)*(1-t)  # pt = p if t > 0 else 1-p
    w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1-pt).pow(gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss(self, x, y):
    '''Sigmoid focal loss computed explicitly.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss: sum of -w * (1 - pt)**gamma * log(pt).
    '''
    import torch.nn.functional as F

    alpha = 0.25
    gamma = 2.0
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background -> [N, num_classes]
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    # -log(pt) computed from logits. It has the same value for 0/1 targets,
    # but never evaluates log(0) (-inf, NaN gradients) when the sigmoid saturates.
    neg_log_pt = F.binary_cross_entropy_with_logits(x, t, reduction='none')
    FL = w * (1 - pt).pow(gamma) * neg_log_pt
    return FL.sum()
def focal_loss(self, x, y):
    '''Sigmoid focal loss over all D columns (no background column is dropped).

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,].

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu(), self.num_classes)  # [N,D]
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    # pt = 1-p where t == 0, p elsewhere (no clone + masked in-place write).
    pt = (1 - p).where(t == 0, p)
    w = alpha * t + (1 - alpha) * (1 - t)
    w = w * (1 - pt).pow(gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss_alt(self, x, y):
    """
    Focal loss alternative: sum of -w * log(sigmoid(2 * x_t + 1)) / 2.

    This is the FL* variant (gamma=2, beta=1) from Appendix A of "Focal Loss
    for Dense Object Detection", weighted here by alpha_t.

    :param x: (tensor) logits, sized [N, D] with D == self.num_classes
    :param y: (tensor) class indices, sized [N, ]; 0 is background.
    :return: (tensor) focal loss, summed over all elements.
    """
    import torch.nn.functional as F

    alpha = 0.25
    t = one_hot_embedding(y.data.cpu().long(), 1 + self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    xt = x * (2 * t - 1)  # xt = x if t>0 else -x
    w = alpha * t + (1 - alpha) * (1 - t)
    # logsigmoid avoids log(0) = -inf (NaN gradients) when sigmoid saturates.
    loss = -w * F.logsigmoid(2 * xt + 1) / 2
    return loss.sum()
def focal_loss(self, x, y):
    """
    Sigmoid focal loss over foreground classes.

    :param x: (tensor) logits, sized [N, D] with D == self.num_classes
    :param y: (tensor) class indices, sized [N, ]; 0 is background.
    :return: (tensor) focal loss, summed over all elements.
    """
    alpha = 0.25
    gamma = 2
    t = one_hot_embedding(y.data.cpu().long(), 1 + self.num_classes)  # [N, 1 + num_classes]
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    # reduction='sum' replaces the deprecated size_average=False.
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
def focal_loss_alt(self, x, y):
    '''Focal loss alternative: sum of -w * log(sigmoid(2 * x_t + 1)) / 2.

    This is the FL* variant (gamma=2, beta=1) from Appendix A of "Focal Loss
    for Dense Object Detection", weighted here by alpha_t.

    The log-sigmoid is computed directly (numerically stable). The previous
    `+ 0.0001` / `.data.clamp_()` workaround went through `.data`, which cut
    pt out of the autograd graph. That loss carried no gradient to x.

    Args:
      x: (tensor) logits, sized [N,D] with D == self.num_classes.
      y: (tensor) class indices, sized [N,]; 0 is background.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    import torch.nn.functional as F

    alpha = 0.25
    t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
    t = t[:, 1:]  # exclude background
    t = t.to(x)  # follow x's device/dtype instead of hard-coding .cuda()
    xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
    w = alpha * t + (1 - alpha) * (1 - t)
    loss = -w * F.logsigmoid(2 * xt + 1) / 2
    return loss.sum()
def optimize(self, agent, database, n_actions, n_tasks):
    """Run one gradient step on ``agent.classifier`` from a sampled batch.

    Soft (task, concept, action) co-occurrence counts from the batch update an
    exponential moving average ``self.PAST``. Entropy and mutual-information
    terms are derived from it and from the agent's outputs, and
    ``classifier_loss = IAs_ST + self.beta * ISs_T`` is minimized.

    Naming appears to follow ``P<vars>_<conditioning>`` (e.g. ``PS_s`` as
    P(S | s), ``PA_ST`` as P(A | S, T)), with ``g`` marking vectors indexed
    by task. This is inferred from the code; confirm.

    Args:
      agent: callable returning ``(PS_s, log_PS_s)`` for a batch of states.
        Its ``classifier`` owns the optimizer that is stepped.
      database: buffer supporting ``len()`` and ``sample(batch_size)``, which
        returns 8 items (inner/outer states, actions, rewards, dones,
        next inner/outer states, tasks).
      n_actions: one-hot width for ``actions``.
      n_tasks: one-hot width for ``tasks``.

    Returns:
      None when the buffer holds fewer than ``self.batch_size`` entries.
      Otherwise a dict with joint metrics plus per-task metrics (key suffix =
      task index).
    """
    if len(database) < self.batch_size:
        return None

    # Sample batch
    inner_states, outer_states, actions, rewards, dones, \
        next_inner_states, next_outer_states, tasks = \
        database.sample(self.batch_size)

    PS_s, log_PS_s = agent(inner_states, outer_states)
    A_one_hot = one_hot_embedding(actions, n_actions)
    T_one_hot = one_hot_embedding(tasks, n_tasks)

    # Soft counts -> [n_tasks, PS_s.shape[1], n_actions]. They are detached,
    # so the running table carries no gradient; 1e-8 keeps the logs below finite.
    PAS_batch = PS_s.unsqueeze(2) * A_one_hot.unsqueeze(1)
    NAST = torch.einsum('ijk,ih->hjk', PAS_batch, T_one_hot).detach() + 1e-8
    PAST_batch = NAST / NAST.sum()

    # Exponential moving average of the normalized joint table across calls.
    if self.PAST is None:
        self.PAST = PAST_batch
    else:
        self.PAST = self.PAST * (1.-self.update_rate) + PAST_batch * self.update_rate

    # Marginals / conditionals of the running joint table.
    PT = self.PAST.sum((1, 2))
    PST = self.PAST.sum(2)
    PS_T = PST / PT.view(-1, 1)
    PA_ST = self.PAST / PST.unsqueeze(2)
    PAT = self.PAST.sum(1)
    PA_T = PAT / PT.view(-1, 1)
    PAS_T = self.PAST / PT.view(-1, 1, 1)
    log_PS_T = torch.log(PS_T)
    log_PA_T = torch.log(PA_T)
    log_PA_ST = torch.log(PA_ST)

    # Concept terms (these carry gradient through PS_s / log_PS_s).
    HS_gT = torch.einsum('ij,hj->ih', PS_s, -log_PS_T).mean(0)
    HS_s = -(PS_s * log_PS_s).sum(1).mean()
    ISs_gT = HS_gT - HS_s
    ISs_T = (PT * ISs_gT).sum()

    # Action terms.
    HA_gT = -(PA_T * log_PA_T).sum(1)
    HA_T = (PT * HA_gT).sum()
    # NOTE(review): hard-coded fraction of the uniform action entropy; the
    # origin of 0.03 is not visible here.
    HA_sT = 0.03*np.log(n_actions)
    HA_ST = -(PS_s * log_PA_ST[tasks, :, actions]).sum(1).mean()
    HA_SgT = -(PAS_T * log_PA_ST).sum((1, 2))
    IAs_gT = HA_gT - HA_sT
    IAS_gT = HA_gT - HA_SgT
    IAs_SgT = IAs_gT - IAS_gT
    IAs_T = (PT * IAs_gT).sum()
    IAS_T = HA_T - HA_ST
    IAs_ST = IAs_T - IAS_T

    classifier_loss = IAs_ST + self.beta * ISs_T
    agent.classifier.optimizer.zero_grad()
    classifier_loss.backward()
    clip_grad_norm_(agent.classifier.parameters(), self.clip_value)
    agent.classifier.optimizer.step()

    joint_metrics = {
        'HS_T': HS_gT.mean().item(),
        'HS_s': HS_s.item(),
        'HA_T': HA_T.item(),
        'HA_sT': HA_sT,
        'HA_ST': HA_ST.item(),
        'ISs_T': ISs_T.item(),
        'IAs_T': IAs_T.item(),
        'IAS_T': IAS_T.item(),
        'IAs_ST': IAs_ST.item(),
        'loss': classifier_loss.item(),
    }

    metrics_per_task = {}
    for task in range(n_tasks):
        metrics_per_task['HS_T'+str(task)] = HS_gT[task].item()
        metrics_per_task['HA_T'+str(task)] = HA_gT[task].item()
        metrics_per_task['HA_ST'+str(task)] = HA_SgT[task].item()
        metrics_per_task['ISs_T'+str(task)] = ISs_gT[task].item()
        metrics_per_task['IAs_T'+str(task)] = IAs_gT[task].item()
        metrics_per_task['IAS_T'+str(task)] = IAS_gT[task].item()
        metrics_per_task['IAs_ST'+str(task)] = IAs_SgT[task].item()

    metrics = {**joint_metrics, **metrics_per_task}
    return metrics
def main(batch_size, continue_training, exp_name, learning_rate, num_epochs, print_freq, run_colab):
    """Train a discriminator to tell ground-truth boxes from detector output.

    Per batch:
      - One random ground-truth (box, one-hot label) per image is scored as
        real (target 1).
      - One random detection per image, from ``detection_network`` run under
        ``torch.no_grad()``, is scored as fake (target 0).
    Both calls also receive an embedding from a ``VGGBase`` image encoder.
    ``GANLoss('vanilla')`` is summed over both, and ``optimizer`` is stepped.
    A checkpoint is saved after every epoch.

    Args:
      batch_size: DataLoader batch size.
      continue_training: resume discriminator/optimizer/epoch from a checkpoint.
      exp_name: experiment directory, used with ``/`` so presumably a
        ``pathlib.Path``.
      learning_rate: Adam learning rate (fresh runs only).
      num_epochs: last epoch (exclusive).
      print_freq: print the running loss every ``print_freq`` batches.
      run_colab: forwarded to ``create_data_lists``.

    NOTE(review): the items below are issues visible in this block; confirm
    before relying on them.
      - The training loader uses ``split='test'``. Confirm this is intended.
      - When resuming, ``checkpoint`` is the detection-checkpoint dict loaded
        above, so ``exp_name / checkpoint`` divides a path by a dict.
        Presumably a file name was intended.
      - When resuming, ``image_encoder`` is never assigned, so
        ``image_encoder.to(device)`` would raise NameError. It presumably
        should be restored from ``adversarial_checkpoint``.
      - ``np.random.randint(len(box))`` raises if an image has no boxes or
        no detections.
    """
    # Data
    data_folder = create_data_lists(run_colab)
    train_dataset = PascalVOCDataset(data_folder, split='test', keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               collate_fn=train_dataset.collate_fn, num_workers=workers,
                                               pin_memory=True)  # note that we're passing the collate function here

    # Networks
    checkpoint = torch.load(exp_name / "checkpoint_ssd300.pth.tar", map_location=device)
    print(f"Number of training epochs for detection network: {checkpoint['epoch']}")
    detection_network = checkpoint['model']

    if continue_training:
        # NOTE(review): see docstring — `checkpoint` is a dict here, and
        # image_encoder is not restored in this branch.
        adversarial_checkpoint = torch.load(exp_name / checkpoint, map_location=device)
        discriminator = adversarial_checkpoint['adversarial_model']
        optimizer = adversarial_checkpoint['optimizer']
        start_epoch = adversarial_checkpoint['epoch']
        print(f"Continue training of adversarial network from epoch {start_epoch}")
    else:
        start_epoch = 0
        image_encoder = VGGBase()
        discriminator = Discriminator(num_classes)
        # Discriminator and image encoder are optimized jointly; the detection
        # network's parameters are not in this optimizer.
        optimizer = torch.optim.Adam(list(discriminator.parameters()) + list(image_encoder.parameters()),
                                     lr=learning_rate, weight_decay=1e-5)

    discriminator, image_encoder = discriminator.to(device), image_encoder.to(device)
    loss_function = GANLoss('vanilla').to(device)
    losses = AverageMeter()  # tracks total_loss; never reset, so it spans all epochs

    for epoch in range(start_epoch, num_epochs):
        for j, (images, boxes, labels, _) in enumerate(train_loader):
            images = images.to(device)
            _, image_embedding = image_encoder(images)

            # "Real" sample: one random ground-truth box/label per image.
            random_box_indices = [np.random.randint(len(box)) for box in boxes]
            random_boxes = torch.stack([box[random_box_indices[i]] for i, box in enumerate(boxes)]).to(device)
            random_labels = torch.stack([one_hot_embedding(label[random_box_indices[i]], num_classes)
                                         for i, label in enumerate(labels)]).to(device)
            pred_real = discriminator(random_boxes, random_labels, image_embedding)
            loss_real = loss_function(pred_real, 1)

            # "Fake" sample: one random detection per image. The detector runs
            # without gradient tracking; the discriminator call below stays
            # outside no_grad so loss_fake still back-propagates.
            with torch.no_grad():
                predicted_locs, predicted_scores = detection_network.forward(images)
                pred_boxes, pred_labels, _ = detection_network.detect_objects(predicted_locs, predicted_scores,
                                                                              min_score=0.2, max_overlap=0.45,
                                                                              top_k=200)
            random_box_indices = [np.random.randint(len(box)) for box in pred_boxes]
            random_fake_boxes = torch.stack([box[random_box_indices[i]] for i, box in enumerate(pred_boxes)]).to(device)
            random_fake_labels = torch.stack([one_hot_embedding(label[random_box_indices[i]], num_classes)
                                              for i, label in enumerate(pred_labels)]).to(device)
            pred_fake = discriminator(random_fake_boxes, random_fake_labels, image_embedding)
            loss_fake = loss_function(pred_fake, 0)

            total_loss = loss_fake + loss_real
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            losses.update(total_loss.item(), images.size(0))
            if j % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, j, len(train_loader), loss=losses))

        save_adversarial_checkpoint(epoch, discriminator, image_encoder, optimizer, exp_name)
nhid=args.hidden, nclass=labels.max().item() + 1, adj=adj, dropout_rate=args.dropout) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) if args.cuda: model.to(device) features = features.to(device) adj = adj.to(device) labels = labels.to(device) idx_train = idx_train.to(device) idx_val = idx_val.to(device) idx_test = idx_test.to(device) labels_for_lpa = one_hot_embedding(labels, labels.max().item() + 1).type(torch.FloatTensor).to(device) def train(epoch): t = time.time() model.train() optimizer.zero_grad() output, y_hat = model(features, adj, labels_for_lpa) loss_gcn = F.nll_loss(output[idx_train], labels[idx_train]) loss_lpa = F.nll_loss(y_hat, labels) acc_train = accuracy(output[idx_train], labels[idx_train]) loss_train = loss_gcn + args.Lambda * loss_lpa loss_train.backward(retain_graph=True) optimizer.step()
def optimize(self, agents, database, n_step_td=1):
    """One soft actor-critic update with a concept-conditioned prior loss.

    Steps:
      1. Targets (q, log-softmax, action entropy) from ``agents[:-1]`` are
         averaged incrementally with those of a deep copy of ``agents[-1]``.
      2. The copy's critic is updated on squared TD errors.
      3. The actor is updated on a KL-style loss plus
         ``self.prior_weight * prior_loss``.
      4. The temperature is optionally updated, and ``epsilon`` is annealed.
      5. The updated copy is appended to ``agents``, so the caller's list
         grows by one per call.

    Args:
      agents: non-empty list of agents; only ``agents[-1]`` is copied and trained.
      database: buffer supporting ``__len__`` and ``sample(batch_size)``,
        which returns 7 items.
      n_step_td: exponent applied to ``self.discount_factor`` in the Bellman
        target.

    Returns:
      None when the buffer holds fewer than ``self.batch_size`` entries,
      otherwise a dict of scalar metrics.

    Raises:
      AssertionError('Explosion!') if any actor gradient is non-finite after
      clipping.
    """
    if database.__len__() < self.batch_size:
        return None

    # Sample batch
    inner_states, outer_states, actions, rewards, \
        dones, next_inner_states, next_outer_states = \
        database.sample(self.batch_size)

    # Running means over the targets of all previous agents (incremental mean).
    n_agents = len(agents)
    q_target = 0.0
    log_softmax_target = 0.0
    HA_s_mean = 0.0
    for i in range(0, n_agents - 1):
        q_target_i, log_softmax_target_i, HA_s_mean_i = \
            self.calculate_targets(agents[i], n_step_td, inner_states, outer_states, actions, rewards, dones,
                                   next_inner_states, next_outer_states)
        q_target += (q_target_i - q_target) / (i + 1)
        log_softmax_target += (log_softmax_target_i - log_softmax_target) / (i + 1)
        HA_s_mean += (HA_s_mean_i - HA_s_mean) / (i + 1)

    # Train a copy of the newest agent; it is appended to `agents` at the end.
    agent = copy.deepcopy(agents[-1])

    # Alias for actor-critic module
    actor_critic = agent.second_level_architecture

    # Calculate q-values and action likelihoods
    q, next_q, next_PA_s, next_log_PA_s, log_alpha = \
        actor_critic.evaluate_critic(inner_states, outer_states,
                                     next_inner_states, next_outer_states)

    alpha = log_alpha.exp().item()

    # Calculate entropy of the action distributions
    HA_s = -(next_PA_s * next_log_PA_s).sum(1, keepdim=True)
    HA_s_mean_last = HA_s.detach().mean()
    # Fold the newest agent into the running mean (now over n_agents agents).
    HA_s_mean += (HA_s_mean_last - HA_s_mean) / n_agents

    # Update mean entropy (exponential moving average across calls)
    if self.H_mean is None:
        self.H_mean = HA_s_mean.item()
    else:
        self.H_mean = HA_s_mean.item() * self.entropy_update_rate + self.H_mean * (
            1.0 - self.entropy_update_rate)

    # Choose minimum next q-value to avoid overestimation of target
    if not actor_critic._parallel:
        next_q_target = torch.min(next_q[0], next_q[1])
    else:
        next_q_target = next_q.min(1)[0]

    # Calculate next v-value, exactly, with the next action distribution
    next_v_target = (next_PA_s * (next_q_target - alpha * (next_log_PA_s + self.H_mean))).sum(
        1, keepdim=True)

    # Estimate q-value target by sampling Bellman expectation
    q_target_last = rewards + self.discount_factor**n_step_td * next_v_target * (
        1. - dones)
    q_target += (q_target_last - q_target) / n_agents

    if not actor_critic._parallel:
        # Select q-values corresponding to the action taken
        q1_A = q[0][np.arange(self.batch_size), actions].view(-1, 1)
        q2_A = q[1][np.arange(self.batch_size), actions].view(-1, 1)

        # Calculate losses for both critics as the quadratic TD errors
        q1_loss = (q1_A - q_target.detach()).pow(2).mean()
        q2_loss = (q2_A - q_target.detach()).pow(2).mean()
        q_loss = q1_loss + q2_loss
    else:
        # Select q-values corresponding to the action taken
        q_A = q[np.arange(self.batch_size), :, actions].view(q.shape[0], q.shape[1])

        # Calculate losses for both critics as the quadratic TD errors
        q_loss = (q_A - q_target.unsqueeze(1).detach()).pow(2).mean()

    # Create critic optimizer and optimize model
    actor_critic.q.optimizer.zero_grad()
    q_loss.backward()
    clip_grad_norm_(actor_critic.q.parameters(), self.clip_value)
    actor_critic.q.optimizer.step()

    # Calculate q-values and action likelihoods after critic SGD
    q, PA_s, log_PA_s = actor_critic.evaluate_actor(
        inner_states, outer_states)

    # Take the minimum over both critics to avoid overestimation
    if not actor_critic._parallel:
        q_dist = torch.min(q[0], q[1])
    else:
        q_dist = q.min(1)[0]

    # Calculate normalizing factors for target softmax distributions
    # (1e-10 guards against division by a zero temperature)
    z = torch.logsumexp(q_dist / (alpha + 1e-10), 1, keepdim=True)

    # Calculate the target log-softmax distribution
    log_softmax_target_last = q_dist / (alpha + 1e-10) - z
    log_softmax_target += (log_softmax_target_last - log_softmax_target) / n_agents

    # Calculate actor losses as the KL divergence between action
    # distributions and softmax target distributions
    difference_ratio = alpha * (log_PA_s - log_softmax_target).detach()
    actor_loss = (PA_s * difference_ratio).sum(1, keepdim=True).mean()

    # Alias for concept module
    concept_net = agent.concept_architecture

    PS_s = (concept_net(inner_states, outer_states)[0]).detach()
    PA_S_target, log_PA_S_target = agent.PA_S()

    # Batch marginal over concepts, smoothed and renormalized, then an EMA
    # across calls in self.PS (used to weight the KL prior loss below).
    new_PS = PS_s.mean(0) + 1e-6
    new_PS = new_PS / new_PS.sum()
    if self.PS is None:
        self.PS = new_PS.detach()
    else:
        self.PS = new_PS.detach() * self.marginal_update_rate + self.PS * (
            1.0 - self.marginal_update_rate)

    # Action distribution per concept: either soft-weighted by PS_s, or
    # assigned by hard argmax concept.
    if self.distributed_contribution:
        PA_S = torch.einsum('ij,ik->jk', PS_s, PA_s) + 1e-6
    else:
        concepts = PS_s.argmax(1).detach().cpu().numpy()
        S_one_hot = one_hot_embedding(concepts, PS_s.shape[1])
        PA_S = torch.einsum('ij,ik->jk', S_one_hot, PA_s) + 1e-6
    # Row-normalize; the normalizer is detached, so gradients flow through
    # the numerator only.
    PA_S = PA_S / PA_S.sum(1, keepdim=True).detach()
    log_PA_S = torch.log(PA_S)

    if self.prior_loss_type == 'MSE':
        prior_loss = (log_PA_S - log_PA_S_target).pow(2).mean()
    else:
        KL_div = (PA_S * (log_PA_S - log_PA_S_target)).sum(1)
        prior_loss = (KL_div * self.PS).sum()

    actor_loss_with_prior = (actor_loss + self.prior_weight * prior_loss)  # previously also: / (1.+self.prior_weight)

    # Create optimizer and optimize model
    actor_critic.actor.optimizer.zero_grad()
    actor_loss_with_prior.backward()
    clip_grad_norm_(actor_critic.actor.parameters(), self.clip_value)
    # Abort before stepping if any clipped gradient is NaN/inf.
    problems = False
    for param in actor_critic.actor.parameters():
        if param.grad is not None:
            problems = problems or not torch.isfinite(param.grad).all()
    assert not problems, 'Explosion!'
    actor_critic.actor.optimizer.step()

    # Calculate loss for temperature parameter alpha
    scaled_min_entropy = self.min_entropy * self.epsilon
    alpha_error = (HA_s_mean - scaled_min_entropy).mean()
    alpha_loss = log_alpha * alpha_error.detach()

    # Optimize temperature (if it is learnable)
    if self.learn_alpha:
        # Create optimizer and optimize model
        actor_critic.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        clip_grad_norm_([actor_critic.log_alpha], self.clip_value)
        actor_critic.alpha_optimizer.step()

    # Update targets of actor-critic and temperature param.
    actor_critic.update()

    # Anneal epsilon (floored at self.min_epsilon)
    self.epsilon = np.max(
        [self.epsilon - self.delta_epsilon, self.min_epsilon])

    # Side effect: the trained copy becomes the newest agent in the caller's list.
    agents.append(agent)

    metrics = {
        'q_loss': q_loss.item(),
        'actor_loss': actor_loss.item(),
        'alpha_loss': alpha_loss.item(),
        'prior_loss': prior_loss.item(),
        'SAC_epsilon': self.epsilon,
        'alpha': alpha,
        'base_entropy': self.H_mean,
    }

    return metrics