Example #1
    def focal_loss2d(self, x, y):
        '''does not work yet'''
        alpha = 0.25
        gamma = 2

        y = y.view(y.size(0), -1)

        if self.using_gpu is True:
            t = one_hot_embedding(y.data.cpu(), 2)
        else:
            t = one_hot_embedding(y.data, 2)
        t = t[:, 1:]  # exclude background

        if self.using_gpu is True:
            t = Variable(t).cuda()
        else:
            t = Variable(t)

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)

        return F.binary_cross_entropy_with_logits(x, t, w)
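The snippets on this page all call a one_hot_embedding helper that is defined elsewhere in their respective projects. As a reference, here is a minimal sketch of the usual index-into-identity-matrix implementation; this is an assumption about the helper, not code taken from any of the projects listed here.

import torch

def one_hot_embedding(labels, num_classes):
    '''Embed integer class labels as one-hot vectors.

    Args:
      labels: (LongTensor) class indices, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (FloatTensor) one-hot encoded labels, sized [N, num_classes].
    '''
    eye = torch.eye(num_classes)  # [num_classes, num_classes] identity matrix
    return eye[labels]            # pick the row for each label -> [N, num_classes]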
Example #2
    def focal_loss2d(self, x, y):
        alpha = 0.25
        gamma = 2

        y = y.view(-1)
        x = x.permute(0, 2, 3, 1).contiguous().view(-1, 2)

        if self.using_gpu is True:
            t = one_hot_embedding(y.data.cpu(), 2)
        else:
            t = one_hot_embedding(y.data, 2)
        # t = t[:,1:]  # exclude background

        if self.using_gpu is True:
            t = Variable(t).cuda()
        else:
            t = Variable(t)

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)

        return F.binary_cross_entropy_with_logits(x, t, w)
Example #3
    def focal_loss(self, x, y):
        '''Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        if self.using_gpu is True:
            t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        else:
            t = one_hot_embedding(y.data, 1 + self.num_classes)
        t = t[:, 1:]  # exclude background

        if self.using_gpu is True:
            t = Variable(t).cuda()
        else:
            t = Variable(t)

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)

        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
Example #4
File: loss.py  Project: hq-liu/retinanet
    def focal_loss_sigmoid(self, x, y):
        """
        Sigmoid version of focal loss.

        This is described in the original paper.
        With BCELoss, the background should not be counted in num_classes.

        Args:
          x: (tensor) predictions, sized [N,D].
          y: (tensor) targets, sized [N,].

        Return:
          (tensor) focal loss.
        """
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        t = t[:, 1:]  # exclude background
        t = Variable(t).cuda()

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
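The examples above target an older PyTorch API: Variable has been a no-op since PyTorch 0.4, and size_average=False has been superseded by reduction='sum'. A minimal sketch of the same sigmoid focal loss written against the current API, as a standalone function; the tensor names follow the examples above, and the detached weight mirrors what several of the later examples do.

import torch.nn.functional as F

def sigmoid_focal_loss(x, t, alpha=0.25, gamma=2.0):
    '''x: logits sized [N, D]; t: one-hot targets sized [N, D], background column removed.'''
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)         # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    return F.binary_cross_entropy_with_logits(x, t, w.detach(), reduction='sum')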
Example #5
    def focal_loss_c(self, pred, targ):
        '''Focal loss

        Args:
            pred: (tensor) sized [#anchors, #num_classes]
            targ: (tensor) sized [#anchors, ]

        Return:
            loss = -1*alpha*(1-p_t)**gamma*log(p_t)
        '''
        alpha = 0.25
        gamma = 2

        targ = one_hot_embedding(targ.data.cpu(), 1 + self.num_classes)
        targ = targ[:, 1:]  # exclude background
        targ = Variable(targ).cuda()  # [#anchors, #num_classes]

        if pred.is_cuda is False:
            pred = pred.cuda()
        p = pred.sigmoid()
        # p_t = p if pred is targ,        p_t = 1-p if pred is not targ
        p_t = p * targ + (1 - p) * (1 - targ)
        coeff1 = alpha * targ + (1 - alpha) * (1 - targ)
        coeff2 = (1 - p_t).pow(gamma) * coeff1

        per_cross_ent = -1 * coeff2 * (p_t.log())

        return torch.sum(per_cross_ent)
Example #6
    def focal_loss(self, pred, targ):
        '''Focal loss

        Args:
            pred: (tensor) sized [#anchors, #num_classes]
            targ: (tensor) sized [#anchors, ]

        Return:
            loss = -1*alpha*(1-p_t)**gamma*log(p_t)
        '''
        alpha = 0.25
        gamma = 2

        targ = one_hot_embedding(targ.data.cpu(), 1 + self.num_classes)
        targ = targ[:, 1:]  # exclude background
        targ = Variable(targ).cuda()  # [#anchors, #num_classes]

        p = pred.sigmoid()
        # p_t = p if pred is targ,        p_t = 1-p if pred is not targ
        p_t = p * targ + (1 - p) * (1 - targ)
        # coeff = alpha if pred is targ,  coeff = 1-alpha if pred is not targ
        coeff = alpha * targ + (1 - alpha) * (1 - targ)
        coeff *= (1 - p_t).pow(gamma)

        # within BCEwithlogits, pred and target go through sigmoid function
        return F.binary_cross_entropy_with_logits(pred,
                                                  targ,
                                                  coeff,
                                                  size_average=False)
Example #7
    def focal_loss_alt(self, x, y):
        '''Focal loss alternative.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        t = t[:, 1:]
        t = Variable(t).cuda()

        xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
        pt = (2 * xt + 1).sigmoid()

        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()
Example #8
    def focal_loss_ku(self, x, y):
        '''Focal loss.
        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].
        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N,21]
        t = t[:, 1:]  # exclude background
        t = Variable(t).cuda()  # [N,20]

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1 - pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(
            x, t, w.detach(), reduction='sum')
Example #9
    def forward(self, x, y, classes=20):
        t = one_hot_embedding(y, classes + 1)  # classes include background
        t = t[:, 1:]  # exclude background
        p = x.sigmoid()
        pt = torch.where(t > 0, p, 1 - p)  # pt = p if t > 0 else 1-p

        w = (1 - pt).pow(self.gamma)
        w = torch.where(t > 0, self.alpha * w, (1 - self.alpha) * w).detach()
        loss = F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')

        return loss
Example #10
    def forward(self, inputs, targets):
        class_onehot = one_hot_embedding(targets.data.cpu().long(),
                                         1 + self.num_classes)  # [N, 81]
        class_onehot = torch.Tensor(
            class_onehot[:, 1:]).cuda()  # exclude background

        prob = self.sigmoid(inputs)
        pt = prob * class_onehot + (1 - prob) * (1 - class_onehot)
        w = self.alpha * class_onehot + (1 - self.alpha) * (1 - class_onehot)
        w = w * (1 - pt).pow(self.gamma)
        return F.binary_cross_entropy_with_logits(inputs,
                                                  class_onehot,
                                                  w,
                                                  size_average=False)
Example #11
    def focal_loss(self, x, y, gamma=2):
        """
        Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        """
        t = one_hot_embedding(y.data.cpu(), self.num_classes)
        t = Variable(t).type(self.FloatTensor)  # [N,D]
        p = F.softmax(x, dim=1)  # [N,D]
        pt = (p * t).sum(1)  # [N,]
        loss = F.cross_entropy(x, y, reduce=False)  # [N,]
        loss = (1 - pt).pow(gamma) * loss

        return loss.sum()
Example #12
File: loss.py  Project: Xlsean/Mask-RCNN
    def focal_loss(self, x, y):
        '''Focal loss.
        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].
        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = utils.one_hot_embedding(y.data.cpu(), 1+self.num_classes) 
        t = t[:,1:]  
        t = Variable(t).cuda() 

        p = x.sigmoid()
        pt = p*t + (1-p)*(1-t)         
        w = alpha*t + (1-alpha)*(1-t)  
        w = w * (1-pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
Example #13
    def focal_loss(self, x, y):
        alpha = 0.25
        gamma = 2

        target = one_hot_embedding(y, 1 + self.num_classes)
        target = target[:, 1:]  # exclude background

        prob = x.sigmoid()
        pred = prob * target + (1 - prob) * (1 - target)  # pt = p if t > 0 else 1-p
        weight = alpha * target + (1 - alpha) * (1 - target)  # w = alpha if t > 0 else 1-alpha
        weight = weight * (1 - pred).pow(gamma)
        weight = weight.detach()

        loss = F.binary_cross_entropy_with_logits(input=x,
                                                  target=target,
                                                  weight=weight,
                                                  reduction='sum')
        return loss
Example #14
    def focal_loss_alt(self, x, y):
        # https://github.com/kuangliu/pytorch-retinanet/issues/52  @miramind
        """Focal loss
        Args:
            x(tensor): size [N, D]
            y(tensor): size [N, ]
        Returns:
            (tensor): focal loss
        """
        t = one_hot_embedding(y.data.cpu(), 1+self.num_classes) # [N,21]
        t = t[:, 1:]  # exclude background
        t = Variable(t).cuda()  # [N,20]

        logit = F.softmax(x, dim=1)
        logit = logit.clamp(1e-7, 1.-1e-7)
        conf_loss_tmp = -1 * t.float() * torch.log(logit)
        conf_loss_tmp = self.balance_alpha * conf_loss_tmp * (1-logit)**self.gamma
        conf_loss = conf_loss_tmp.sum()

        return conf_loss
Example #15
    def focal_loss_alt(self, x, y):
        '''Focal loss alternative.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25

        t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)
        t = t[:,1:]
        t = Variable(t).cuda()

        xt = x*(2*t-1)  # xt = x if t > 0 else -x
        pt = (2*xt+1).sigmoid()

        w = alpha*t + (1-alpha)*(1-t)
        loss = -w*pt.log() / 2
        return loss.sum()
Example #16
    def focal_loss_new(self, x, y):
        # NOTE: assumes x holds logits for background + foreground classes,
        # sized [N, 1+num_classes], and that class 0 is the background class.
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N, 1+num_classes]
        t = Variable(t).cuda()

        logpt = F.log_softmax(x, dim=1)
        logpt = (logpt * t).sum(1)  # log-probability of the target class, [N,]

        pt = Variable(logpt.data.exp())

        # alpha for foreground anchors, 1-alpha for background anchors
        at = alpha * (y > 0).float() + (1 - alpha) * (y == 0).float()
        logpt = logpt * Variable(at.data)

        loss = -1 * (1 - pt)**gamma * logpt

        return loss.sum()
Example #17
    def focal_loss_alt(self, x, y):
        '''Focal loss alternative.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        t = t[:, 1:]
        t = t.to(device)

        xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
        pt = (2 * xt + 1).sigmoid()

        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()
Example #18
    def focal_loss(self, x, y):
        '''Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu(), 1+self.num_classes)  # [N,21]
        t = t[:,1:]  # exclude background
        t = Variable(t).cuda()  # [N,20]

        p = x.sigmoid()
        pt = p*t + (1-p)*(1-t)         # pt = p if t > 0 else 1-p
        w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1-pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
Example #19
    def focal_loss(self, x, y):
        '''Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2.0

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)  # [N,21]
        t = t[:, 1:]  # exclude background
        t = t.cuda()  # [N,20]
        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)  # pt = p if t > 0 else 1-p
        w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
        #w = w * (1 - pt).pow(gamma)
        FL = -w * (1 - pt).pow(gamma) * pt.log()
        return FL.sum()
Example #20
    def focal_loss(self, x, y):
        '''Focal loss.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu(), self.num_classes)  # [N,D]
        t = Variable(t).cuda()

        p = x.sigmoid()
        pt = p.clone()
        pt[t == 0] = 1 - p[t == 0]  # pt = p if t>0 else 1-p

        w = alpha * t + (1 - alpha) * (1 - t)
        w = w * (1 - pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
Example #21
    def focal_loss_alt(self, x, y):
        """
        Focal loss alternative.

        Args:
        :param x: (tensor) sized [N, D]
        :param y: (tensor) sized [N, ].
        :return:
                (tensor) focal loss.
        """
        alpha = 0.25

        t = one_hot_embedding(y.data.cpu().long(),
                              1 + self.num_classes)  # [N, 81]
        t = t[:, 1:]  # exclude background
        t = Variable(t).cuda()

        xt = x * (2 * t - 1)  # xt = x if t>0 else -x
        pt = (2 * xt + 1).sigmoid()

        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()
Example #22
    def focal_loss(self, x, y):
        """
        Focal loss.

        Args:
        :param x: (tensor) sized [N, D]
        :param y: (tensor) sized [N, ].
        :return:
                (tensor) focal loss.
        """
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y.data.cpu().long(),
                              1 + self.num_classes)  # [N, 81]
        t = t[:, 1:]  # exclude background
        t = Variable(t).cuda()

        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)
        w = alpha * t + (1 - alpha) * (1 - t)
        w = w * (1 - pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)
Example #23
    def focal_loss_alt(self, x, y):
        '''Focal loss alternative.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) focal loss.

        This implementation is numerically unstable because of the log function.
        '''
        alpha = 0.25

        t = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        t = t[:, 1:]
        t = Variable(t).cuda()

        xt = x * (2 * t - 1)  # xt = x if t > 0 else -x
        pt = ((2 * xt + 1).sigmoid() + 0.0001).clamp(min=0, max=1)

        w = alpha * t + (1 - alpha) * (1 - t)
        loss = -w * pt.log() / 2
        return loss.sum()
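The docstring above flags pt.log() as numerically unstable. Since pt is sigmoid(2*xt + 1), the logarithm can be computed directly with F.logsigmoid, which is implemented in a numerically stable way. A sketch of that variant as a standalone function that takes the already-built one-hot target t:

import torch.nn.functional as F

def focal_loss_alt_stable(x, t, alpha=0.25):
    '''x: logits sized [N, D]; t: one-hot targets sized [N, D], background column removed.'''
    xt = x * (2 * t - 1)               # xt = x if t > 0 else -x
    log_pt = F.logsigmoid(2 * xt + 1)  # = log(sigmoid(2*xt + 1)), computed stably
    w = alpha * t + (1 - alpha) * (1 - t)
    return (-w * log_pt / 2).sum()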
Example #24
    def optimize(self, agent, database, n_actions, n_tasks): 
        if database.__len__() < self.batch_size:
            return None

        # # Estimate visits
        # NAST = None

        # with torch.no_grad():
        #     for batch in range(0, self.n_batches_estimation*8):
        #         # Sample batch        
        #         inner_states, outer_states, actions, rewards, dones, \
        #             next_inner_states, next_outer_states, tasks = \
        #             database.sample(self.batch_size)
                
        #         PS_s, log_PS_s = agent(inner_states, outer_states)
        #         A_one_hot = one_hot_embedding(actions, n_actions)
        #         T_one_hot = one_hot_embedding(tasks, n_tasks)

        #         PAS = PS_s.unsqueeze(2) * A_one_hot.unsqueeze(1)
        #         NAST_batch = torch.einsum('ijk,ih->hjk', PAS, T_one_hot).detach() + 1e-8

        #         if NAST is None:
        #             NAST = NAST_batch
        #         else:
        #             NAST = NAST + NAST_batch

        # Sample batch        
        inner_states, outer_states, actions, rewards, dones, \
            next_inner_states, next_outer_states, tasks = \
            database.sample(self.batch_size)
        
        PS_s, log_PS_s = agent(inner_states, outer_states)
        A_one_hot = one_hot_embedding(actions, n_actions)
        T_one_hot = one_hot_embedding(tasks, n_tasks)

        PAS_batch = PS_s.unsqueeze(2) * A_one_hot.unsqueeze(1)
        NAST = torch.einsum('ijk,ih->hjk', PAS_batch, T_one_hot).detach() + 1e-8
        PAST_batch = NAST / NAST.sum()

        if self.PAST is None:
            self.PAST = PAST_batch
        else:
            self.PAST = self.PAST * (1.-self.update_rate) + PAST_batch * self.update_rate

        # PAST = NAST / NAST.sum()
        PT = self.PAST.sum((1,2))
        PST = self.PAST.sum(2)
        PS_T = PST / PT.view(-1,1)
        PA_ST = self.PAST / PST.unsqueeze(2)
        PAT = self.PAST.sum(1)
        PA_T = PAT / PT.view(-1,1)
        PAS_T = self.PAST / PT.view(-1,1,1)

        log_PS_T = torch.log(PS_T)
        log_PA_T = torch.log(PA_T)
        log_PA_ST = torch.log(PA_ST)         
        
        HS_gT = torch.einsum('ij,hj->ih', PS_s, -log_PS_T).mean(0)
        HS_s = -(PS_s * log_PS_s).sum(1).mean()
        ISs_gT = HS_gT - HS_s
        ISs_T = (PT * ISs_gT).sum()
        
        HA_gT = -(PA_T * log_PA_T).sum(1)
        HA_T = (PT * HA_gT).sum()
        HA_sT = 0.03*np.log(n_actions)
        HA_ST = -(PS_s * log_PA_ST[tasks,:,actions]).sum(1).mean()
        HA_SgT = -(PAS_T * log_PA_ST).sum((1,2))  

        IAs_gT = HA_gT - HA_sT   
        IAS_gT = HA_gT - HA_SgT
        IAs_SgT = IAs_gT - IAS_gT

        IAs_T = (PT * IAs_gT).sum()
        IAS_T = HA_T - HA_ST
        IAs_ST = IAs_T - IAS_T
         
        
        n_concepts = PS_s.shape[1]
        H_max = np.log(n_concepts)
        classifier_loss = IAs_ST + self.beta * ISs_T            

        agent.classifier.optimizer.zero_grad()
        classifier_loss.backward()
        clip_grad_norm_(agent.classifier.parameters(), self.clip_value)
        agent.classifier.optimizer.step()

        joint_metrics = {
            'HS_T': HS_gT.mean().item(),
            'HS_s': HS_s.item(),
            'HA_T': HA_T.item(),
            'HA_sT': HA_sT,
            'HA_ST': HA_ST.item(),
            'ISs_T': ISs_T.item(),
            'IAs_T': IAs_T.item(),
            'IAS_T': IAS_T.item(),
            'IAs_ST': IAs_ST.item(),
            'loss': classifier_loss.item(),
        }

        metrics_per_task = {}
        for task in range(0, n_tasks):
            metrics_per_task['HS_T'+str(task)] = HS_gT[task].item()
            metrics_per_task['HA_T'+str(task)] = HA_gT[task].item()
            metrics_per_task['HA_ST'+str(task)] = HA_SgT[task].item()
            metrics_per_task['ISs_T'+str(task)] = ISs_gT[task].item()
            metrics_per_task['IAs_T'+str(task)] = IAs_gT[task].item()
            metrics_per_task['IAS_T'+str(task)] = IAS_gT[task].item()
            metrics_per_task['IAs_ST'+str(task)] = IAs_SgT[task].item()

        metrics = {**joint_metrics, **metrics_per_task}
        
        return metrics
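Example #24 above uses one_hot_embedding for something other than a focal loss: it builds empirical joint counts over (task, concept, action) with an einsum over one-hot action and task embeddings. A small self-contained sketch of that pattern; the shapes and random values are made up for illustration only.

import torch

def one_hot_embedding(labels, num_classes):
    return torch.eye(num_classes)[labels]

batch, n_concepts, n_actions, n_tasks = 4, 3, 5, 2
PS_s = torch.softmax(torch.randn(batch, n_concepts), dim=1)  # concept probabilities per sample
actions = torch.randint(0, n_actions, (batch,))
tasks = torch.randint(0, n_tasks, (batch,))

A_one_hot = one_hot_embedding(actions, n_actions)            # [batch, n_actions]
T_one_hot = one_hot_embedding(tasks, n_tasks)                # [batch, n_tasks]

# joint mass over (concept, action) per sample, then accumulated per task
PAS = PS_s.unsqueeze(2) * A_one_hot.unsqueeze(1)             # [batch, n_concepts, n_actions]
NAST = torch.einsum('ijk,ih->hjk', PAS, T_one_hot) + 1e-8    # [n_tasks, n_concepts, n_actions]
PAST = NAST / NAST.sum()                                     # normalized joint distribution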
Example #25
def main(batch_size, continue_training, exp_name, learning_rate, num_epochs, print_freq, run_colab):
    # Data
    data_folder = create_data_lists(run_colab)
    train_dataset = PascalVOCDataset(data_folder,
                                     split='test',
                                     keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               collate_fn=train_dataset.collate_fn, num_workers=workers,
                                               pin_memory=True)  # note that we're passing the collate function here

    # Networks
    checkpoint = torch.load(exp_name / "checkpoint_ssd300.pth.tar", map_location=device)
    print(f"Number of training epochs for detection network: {checkpoint['epoch']}")
    detection_network = checkpoint['model']

    if continue_training:
        adversarial_checkpoint = torch.load(exp_name / checkpoint, map_location=device)
        discriminator = adversarial_checkpoint['adversarial_model']
        optimizer = adversarial_checkpoint['optimizer']
        start_epoch = adversarial_checkpoint['epoch']
        print(f"Continue training of adversarial network from epoch {start_epoch}")
    else:
        start_epoch = 0
        image_encoder = VGGBase()
        discriminator = Discriminator(num_classes)
        optimizer = torch.optim.Adam(list(discriminator.parameters()) + list(image_encoder.parameters()),
                                     lr=learning_rate,
                                     weight_decay=1e-5)
    discriminator, image_encoder = discriminator.to(device), image_encoder.to(device)
    loss_function = GANLoss('vanilla').to(device)
    losses = AverageMeter()  # loss


    for epoch in range(start_epoch, num_epochs):
        for j, (images, boxes, labels, _) in enumerate(train_loader):
            images = images.to(device)
            _, image_embedding = image_encoder(images)
            random_box_indices = [np.random.randint(len(box)) for box in boxes]
            random_boxes = torch.stack([box[random_box_indices[i]] for i, box in enumerate(boxes)]).to(device)
            random_labels = torch.stack([one_hot_embedding(label[random_box_indices[i]], num_classes) for i, label in enumerate(labels)]).to(device)
            pred_real = discriminator(random_boxes, random_labels, image_embedding)
            loss_real = loss_function(pred_real, 1)

            with torch.no_grad():
                predicted_locs, predicted_scores = detection_network.forward(images)
                pred_boxes, pred_labels, _ = detection_network.detect_objects(
                    predicted_locs, predicted_scores, min_score=0.2, max_overlap=0.45, top_k=200)
            random_box_indices = [np.random.randint(len(box)) for box in pred_boxes]
            random_fake_boxes = torch.stack([box[random_box_indices[i]] for i, box in enumerate(pred_boxes)]).to(device)
            random_fake_labels = torch.stack([one_hot_embedding(label[random_box_indices[i]], num_classes) for i, label in enumerate(pred_labels)]).to(device)
            pred_fake = discriminator(random_fake_boxes, random_fake_labels, image_embedding)
            loss_fake = loss_function(pred_fake, 0)

            total_loss = loss_fake + loss_real
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            losses.update(total_loss.item(), images.size(0))
            if j % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, j, len(train_loader), loss=losses))
        save_adversarial_checkpoint(epoch, discriminator, image_encoder, optimizer, exp_name)
Example #26
            nhid=args.hidden,
            nclass=labels.max().item() + 1,
            adj=adj,
            dropout_rate=args.dropout)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)

if args.cuda:
    model.to(device)
    features = features.to(device)
    adj = adj.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)
    labels_for_lpa = one_hot_embedding(labels, labels.max().item() + 1).type(torch.FloatTensor).to(device)



def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output, y_hat = model(features, adj, labels_for_lpa)
    loss_gcn = F.nll_loss(output[idx_train], labels[idx_train])
    loss_lpa = F.nll_loss(y_hat, labels)
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train = loss_gcn + args.Lambda * loss_lpa
    loss_train.backward(retain_graph=True)
    optimizer.step()
Example #27
    def optimize(self, agents, database, n_step_td=1):
        if database.__len__() < self.batch_size:
            return None

        # Sample batch
        inner_states, outer_states, actions, rewards, \
            dones, next_inner_states, next_outer_states = \
            database.sample(self.batch_size)

        n_agents = len(agents)
        q_target = 0.0
        log_softmax_target = 0.0
        HA_s_mean = 0.0
        for i in range(0, n_agents - 1):
            q_target_i, log_softmax_target_i, HA_s_mean_i = \
                self.calculate_targets(agents[i], n_step_td, inner_states,
                                    outer_states, actions, rewards, dones,
                                    next_inner_states, next_outer_states)
            q_target += (q_target_i - q_target) / (i + 1)
            log_softmax_target += (log_softmax_target_i -
                                   log_softmax_target) / (i + 1)
            HA_s_mean += (HA_s_mean_i - HA_s_mean) / (i + 1)

        agent = copy.deepcopy(agents[-1])

        # Alias for actor-critic module
        actor_critic = agent.second_level_architecture

        # Calculate q-values and action likelihoods
        q, next_q, next_PA_s, next_log_PA_s, log_alpha = \
            actor_critic.evaluate_critic(inner_states, outer_states,
                                next_inner_states, next_outer_states)

        alpha = log_alpha.exp().item()

        # Calculate entropy of the action distributions
        HA_s = -(next_PA_s * next_log_PA_s).sum(1, keepdim=True)
        HA_s_mean_last = HA_s.detach().mean()
        HA_s_mean += (HA_s_mean_last - HA_s_mean) / n_agents

        # Update mean entropy
        if self.H_mean is None:
            self.H_mean = HA_s_mean.item()
        else:
            self.H_mean = (HA_s_mean.item() * self.entropy_update_rate +
                           self.H_mean * (1.0 - self.entropy_update_rate))

        # Choose minimum next q-value to avoid overestimation of target
        if not actor_critic._parallel:
            next_q_target = torch.min(next_q[0], next_q[1])
        else:
            next_q_target = next_q.min(1)[0]

        # Calculate next v-value, exactly, with the next action distribution
        next_v_target = (next_PA_s * (next_q_target - alpha *
                                      (next_log_PA_s + self.H_mean))).sum(
                                          1, keepdim=True)

        # Estimate q-value target by sampling Bellman expectation
        q_target_last = rewards + self.discount_factor**n_step_td * next_v_target * (
            1. - dones)
        q_target += (q_target_last - q_target) / n_agents

        if not actor_critic._parallel:
            # Select q-values corresponding to the action taken
            q1_A = q[0][np.arange(self.batch_size), actions].view(-1, 1)
            q2_A = q[1][np.arange(self.batch_size), actions].view(-1, 1)

            # Calculate losses for both critics as the quadratic TD errors
            q1_loss = (q1_A - q_target.detach()).pow(2).mean()
            q2_loss = (q2_A - q_target.detach()).pow(2).mean()
            q_loss = q1_loss + q2_loss
        else:
            # Select q-values corresponding to the action taken
            q_A = q[np.arange(self.batch_size), :,
                    actions].view(q.shape[0], q.shape[1])

            # Calculate losses for both critics as the quadratic TD errors
            q_loss = (q_A - q_target.unsqueeze(1).detach()).pow(2).mean()

        # Create critic optimizer and optimize model
        actor_critic.q.optimizer.zero_grad()
        q_loss.backward()
        clip_grad_norm_(actor_critic.q.parameters(), self.clip_value)
        actor_critic.q.optimizer.step()

        # Calculate q-values and action likelihoods after critic SGD
        q, PA_s, log_PA_s = actor_critic.evaluate_actor(
            inner_states, outer_states)

        # Choose minimum q-value to avoid overestimation
        if not actor_critic._parallel:
            q_dist = torch.min(q[0], q[1])
        else:
            q_dist = q.min(1)[0]

        # Calculate normalizing factors for target softmax distributions
        z = torch.logsumexp(q_dist / (alpha + 1e-10), 1, keepdim=True)

        # Calculate the target log-softmax distribution
        log_softmax_target_last = q_dist / (alpha + 1e-10) - z
        log_softmax_target += (log_softmax_target_last -
                               log_softmax_target) / n_agents

        # Calculate actor losses as the KL divergence between action
        # distributions and softmax target distributions
        difference_ratio = alpha * (log_PA_s - log_softmax_target).detach()
        actor_loss = (PA_s * difference_ratio).sum(1, keepdim=True).mean()

        # Alias for concept module
        concept_net = agent.concept_architecture
        PS_s = (concept_net(inner_states, outer_states)[0]).detach()
        PA_S_target, log_PA_S_target = agent.PA_S()

        new_PS = PS_s.mean(0) + 1e-6
        new_PS = new_PS / new_PS.sum()
        if self.PS is None:
            self.PS = new_PS.detach()
        else:
            self.PS = new_PS.detach() * self.marginal_update_rate + self.PS * (
                1.0 - self.marginal_update_rate)

        if self.distributed_contribution:
            PA_S = torch.einsum('ij,ik->jk', PS_s, PA_s) + 1e-6
        else:
            concepts = PS_s.argmax(1).detach().cpu().numpy()
            S_one_hot = one_hot_embedding(concepts, PS_s.shape[1])
            PA_S = torch.einsum('ij,ik->jk', S_one_hot, PA_s) + 1e-6
        PA_S = PA_S / PA_S.sum(1, keepdim=True).detach()
        log_PA_S = torch.log(PA_S)
        if self.prior_loss_type == 'MSE':
            prior_loss = (log_PA_S - log_PA_S_target).pow(2).mean()
        else:
            KL_div = (PA_S * (log_PA_S - log_PA_S_target)).sum(1)
            prior_loss = (KL_div * self.PS).sum()

        actor_loss_with_prior = (actor_loss + self.prior_weight * prior_loss
                                 )  #/ (1.+self.prior_weight)

        # Create optimizer and optimize model
        actor_critic.actor.optimizer.zero_grad()
        actor_loss_with_prior.backward()
        clip_grad_norm_(actor_critic.actor.parameters(), self.clip_value)

        problems = False
        for param in actor_critic.actor.parameters():
            if param.grad is not None:
                problems = problems or not torch.isfinite(param.grad).all()

        assert not problems, 'Explosion!'

        actor_critic.actor.optimizer.step()

        # Calculate loss for temperature parameter alpha
        scaled_min_entropy = self.min_entropy * self.epsilon
        alpha_error = (HA_s_mean - scaled_min_entropy).mean()
        alpha_loss = log_alpha * alpha_error.detach()

        # Optimize temperature (if it is learnable)
        if self.learn_alpha:
            # Create optimizer and optimize model
            actor_critic.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            clip_grad_norm_([actor_critic.log_alpha], self.clip_value)
            actor_critic.alpha_optimizer.step()

        # Update targets of actor-critic and temperature param.
        actor_critic.update()

        # Anneal epsilon
        self.epsilon = np.max(
            [self.epsilon - self.delta_epsilon, self.min_epsilon])

        agents.append(agent)

        metrics = {
            'q_loss': q_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'prior_loss': prior_loss.item(),
            'SAC_epsilon': self.epsilon,
            'alpha': alpha,
            'base_entropy': self.H_mean,
        }

        return metrics