Ejemplo n.º 1
0
 def forward(self, input, target):
     """Binary cross-entropy with an elastic-net (L2 + L1) penalty.

     The base BCE term (weighted per ``self.weight``, reduced per
     ``self.reduction``) is augmented with ``self.l2 / 2`` times the squared
     L2 norm and ``self.l1`` times the L1 norm of all ``self.model``
     parameters.
     """
     loss = F.binary_cross_entropy(
         input, target, weight=self.weight, reduction=self.reduction)
     params = list(self.model.parameters())
     # Ridge term: (l2 / 2) * sum ||p||_2^2
     loss = loss + self.l2 / 2 * sum(p.norm(2) ** 2 for p in params)
     # Lasso term: l1 * sum ||p||_1
     loss = loss + self.l1 * sum(p.norm(1) for p in params)
     return loss
Ejemplo n.º 2
0
    def mapping_step(self, stats):
        """
        Fooling discriminator training step.

        Updates the mapping so the (frozen) discriminator misclassifies its
        inputs: the BCE target is the flipped label ``1 - y``, scaled by
        ``dis_lambda``.  Returns ``2 * batch_size`` (items processed), or 0
        when adversarial training is disabled (``dis_lambda == 0``).
        """
        if self.params.dis_lambda == 0:
            return 0

        # Discriminator is only evaluated here — it is not updated.
        self.discriminator.eval()

        # loss
        x, y = self.get_dis_xy(volatile=False)
        preds = self.discriminator(x)
        # Flipped labels (1 - y): the mapping is rewarded for fooling D.
        loss = F.binary_cross_entropy(preds, 1 - y)
        loss = self.params.dis_lambda * loss

        # check NaN: a value is NaN exactly where it compares unequal to itself
        if (loss != loss).data.any():
            logger.error("NaN detected (fool discriminator)")
            exit()

        # optim
        self.map_optimizer.zero_grad()
        loss.backward()
        self.map_optimizer.step()
        # presumably re-projects the mapping toward orthogonality — see
        # self.orthogonalize for the exact constraint
        self.orthogonalize()

        return 2 * self.params.batch_size
Ejemplo n.º 3
0
    def forward(self, true_binary, rule_masks, raw_logits):
        """Grammar-decoder loss, dispatched on ``cmd_args.loss_type``.

        'binary'     — BCE between the masked softmax of the logits and the
                       one-hot targets, scaled by ``max_decode_steps``.
        'perplexity' — delegated to ``my_perp_loss``.
        'vanilla'    — negative masked log-likelihood, averaged over dim 1.
        """
        loss_type = cmd_args.loss_type
        if loss_type == 'binary':
            # Softmax restricted to grammar-legal rules via rule_masks.
            masked = torch.exp(raw_logits) * rule_masks
            prob = masked / torch.sum(masked, 2, keepdim=True)
            return F.binary_cross_entropy(prob, true_binary) * cmd_args.max_decode_steps

        if loss_type == 'perplexity':
            return my_perp_loss(true_binary, rule_masks, raw_logits)

        if loss_type == 'vanilla':
            # 1e-30 keeps the normalizer (and the log below) finite.
            masked = torch.exp(raw_logits) * rule_masks + 1e-30
            prob = masked / torch.sum(masked, 2, keepdim=True)
            ll = torch.abs(torch.sum(true_binary * prob, 2))
            # Last mask channel marks padding; exclude it from the likelihood.
            valid = 1 - rule_masks[:, :, -1]
            logll = valid * torch.log(ll)
            return -torch.sum(logll) / true_binary.size()[1]

        print('unknown loss type %s' % loss_type)
        raise NotImplementedError
Ejemplo n.º 4
0
    def calc_dis_loss(self, input_fake, input_real):
        """Discriminator loss over all scales returned by self.forward.

        'lsgan' targets 0 for fakes and 1 for reals with squared error;
        'nsgan' uses sigmoid + BCE against all-zeros / all-ones labels.
        """
        fake_outs = self.forward(input_fake)
        real_outs = self.forward(input_real)
        total = 0

        for fake, real in zip(fake_outs, real_outs):
            if self.gan_type == 'lsgan':
                total += torch.mean((fake - 0)**2) + torch.mean((real - 1)**2)
            elif self.gan_type == 'nsgan':
                zeros = Variable(torch.zeros_like(fake.data).cuda(), requires_grad=False)
                ones = Variable(torch.ones_like(real.data).cuda(), requires_grad=False)
                total += torch.mean(F.binary_cross_entropy(F.sigmoid(fake), zeros) +
                                    F.binary_cross_entropy(F.sigmoid(real), ones))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return total
Ejemplo n.º 5
0
def loss_function(recon_x, x, mu, logvar):
    """VAE objective: summed BCE reconstruction term plus KL divergence."""
    recon = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # Analytical KL of N(mu, sigma^2) from N(0, I); see Appendix B of
    # Kingma & Welling, "Auto-Encoding Variational Bayes"
    # (https://arxiv.org/abs/1312.6114):
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon + kld
Ejemplo n.º 6
0
def loss_function(recon_x, x, mu, logvar):
    """VAE loss: mean-reduced BCE reconstruction plus normalized KL term."""
    rec = F.binary_cross_entropy(recon_x, x.view(-1, 784))

    # Analytical KL of N(mu, sigma^2) from N(0, I); see Appendix B of
    # Kingma & Welling, "Auto-Encoding Variational Bayes"
    # (https://arxiv.org/abs/1312.6114):
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as the (mean-reduced) BCE.
    kld = kld / (args.batch_size * 784)

    return rec + kld
Ejemplo n.º 7
0
    def forward(self, inputs, targets):
        """Focal loss (Lin et al., https://arxiv.org/abs/1708.02002).

        inputs are logits when ``self.logits`` is set, probabilities
        otherwise; returns the mean focal loss when ``self.reduce`` is
        truthy, the sum otherwise.
        """
        # Per-element BCE. `reduction='none'` replaces the long-deprecated
        # `reduce=False` keyword argument.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # pt is the model's probability of the true class (exp(-BCE)).
        pt = torch.exp(-BCE_loss)
        # Down-weight easy examples by (1 - pt)^gamma.
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return torch.sum(F_loss)
Ejemplo n.º 8
0
 def calc_gen_loss(self, input_fake):
     """Generator-side GAN loss: push D's outputs on fakes toward 'real'.

     'lsgan' uses squared error against target 1; 'nsgan' uses
     sigmoid + BCE against an all-ones label tensor.
     """
     loss = 0
     for out in self.forward(input_fake):
         if self.gan_type == 'lsgan':
             loss += torch.mean((out - 1)**2)
         elif self.gan_type == 'nsgan':
             target = Variable(torch.ones_like(out.data).cuda(), requires_grad=False)
             loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out), target))
         else:
             assert 0, "Unsupported GAN type: {}".format(self.gan_type)
     return loss
Ejemplo n.º 9
0
    def forward(ctx, true_binary, rule_masks, input_logits):
        """Masked-softmax BCE forward pass for a custom autograd Function.

        Saves the inputs for the backward pass, builds a numerically stable
        masked softmax over dim 2, and returns the BCE against the one-hot
        targets.
        """
        ctx.save_for_backward(true_binary, rule_masks, input_logits)

        # Subtract the per-position max before exp for numerical stability.
        shifted = input_logits - torch.max(input_logits, 2, keepdim=True)[0]
        masked = torch.exp(shifted) * rule_masks
        prob = masked / torch.sum(masked, 2, keepdim=True)

        return F.binary_cross_entropy(prob, true_binary)
Ejemplo n.º 10
0
    def step(self, real_data, verbose: bool = False):
        """Run one beta-VAE optimization step on a batch.

        The reconstruction target is the input binarized at 0.5; the total
        objective is summed BCE plus ``self.beta`` times the analytical KL
        divergence.  Optionally prints both terms.
        """
        mean, logvar, latent, fake_data = self.model(real_data)

        # `reduction='sum'` replaces the deprecated `size_average=False`.
        rec_loss = F.binary_cross_entropy(fake_data, (real_data > .5).float(), reduction='sum')
        # rec_loss = F.binary_cross_entropy(fake_data, real_data, reduction='sum')
        # Analytical KL of N(mean, exp(logvar)) from N(0, I).
        kl_div = -.5 * (1. + logvar - mean ** 2 - logvar.exp()).sum()

        self.opt.zero_grad()
        (rec_loss + self.beta * kl_div).backward()
        self.opt.step()

        if verbose:
            print(f"rec_loss = {rec_loss.item():6g}, KL_div = {kl_div.item():6g}, ")
Ejemplo n.º 11
0
    def forward(self, true_binary, rule_masks, raw_logits):
        """Masked-softmax BCE ('binary') or perplexity loss, per cmd_args.loss_type."""
        if cmd_args.loss_type == 'binary':
            # Softmax restricted to grammar-legal rules via rule_masks.
            masked = torch.exp(raw_logits) * rule_masks
            prob = masked / torch.sum(masked, 2, keepdim=True)
            return F.binary_cross_entropy(prob, true_binary) * cmd_args.max_decode_steps

        if cmd_args.loss_type == 'perplexity':
            return my_perp_loss(true_binary, rule_masks, raw_logits)

        print('unknown loss type %s' % cmd_args.loss_type)
        raise NotImplementedError
Ejemplo n.º 12
0
def val_casenet(epoch,model,data_loader,args):
    """Run one validation pass of the case-level nodule classifier.

    For every batch, computes the case BCE loss, a 'miss' penalty on true
    nodules predicted below ``args.miss_thresh``, and accuracy / TP / FP /
    FN counts, then prints length-weighted epoch averages.

    NOTE(review): uses pre-0.4 PyTorch idioms (``Variable(..., volatile=True)``,
    ``loss.data[0]``) — confirm the target torch version.
    """
    model.eval()
    starttime = time.time()
    # Per-batch histories; loss1Hist and lossHist are created but never
    # filled in this function.
    loss1Hist = []
    loss2Hist = []
    lossHist = []
    missHist = []
    accHist = []
    lenHist = []
    # Running true-positive / false-positive / false-negative counts.
    tpn = 0
    fpn = 0
    fnn = 0

    for i,(x,coord,isnod,y) in enumerate(data_loader):

        coord = Variable(coord,volatile=True).cuda()
        x = Variable(x,volatile=True).cuda()
        xsize = x.size()
        # Keep the labels on CPU as numpy for the metric bookkeeping below.
        ydata = y.numpy()[:,0]
        y = Variable(y).float().cuda()
        isnod = Variable(isnod).float().cuda()

        nodulePred,casePred,casePred_each = model(x,coord)

        # Case-level BCE on the predicted probability.
        loss2 = binary_cross_entropy(casePred,y[:,0])
        # Penalize true nodules whose per-crop prediction falls below the
        # miss threshold; +0.001 keeps log() finite near zero.
        missMask = (casePred_each<args.miss_thresh).float()
        missLoss = -torch.sum(missMask*isnod*torch.log(casePred_each+0.001))/xsize[0]/xsize[1]

        #loss2 = binary_cross_entropy(sigmoid(casePred),y[:,0])
        loss2Hist.append(loss2.data[0])
        missHist.append(missLoss.data[0])
        lenHist.append(len(x))
        outdata = casePred.data.cpu().numpy()
        #print([i,data_loader.dataset.split[i,1],sigmoid(casePred).data.cpu().numpy()])
        # Hard decision at probability 0.5.
        pred = outdata>0.5
        tpn += np.sum(1==pred[ydata==1])
        fpn += np.sum(1==pred[ydata==0])
        fnn += np.sum(0==pred[ydata==1])
        acc = np.mean(ydata==pred)
        accHist.append(acc)
    endtime = time.time()
    lenHist = np.array(lenHist)
    loss2Hist = np.array(loss2Hist)
    accHist = np.array(accHist)
    # Batch-size-weighted averages over the epoch.
    mean_loss2 = np.sum(loss2Hist*lenHist)/np.sum(lenHist)
    mean_missloss = np.sum(missHist*lenHist)/np.sum(lenHist)
    mean_acc = np.sum(accHist*lenHist)/np.sum(lenHist)
    print('Valid, epoch %d, loss2 %.4f, miss loss %.4f, acc %.4f, tpn %d, fpn %d, fnn %d,  time %3.2f'
          %(epoch,mean_loss2,mean_missloss,mean_acc,tpn,fpn, fnn, endtime-starttime))
Ejemplo n.º 13
0
 def compute_loss_and_gradient(self, x):
     """One VAE evaluation step; also optimizes when self.mode == TRAIN.

     Returns the scalar loss (mean BCE reconstruction + normalized KL).
     """
     self.optimizer.zero_grad()
     recon_x, z_mean, z_var = self.model_eval(x)
     rec_term = functional.binary_cross_entropy(recon_x, x.reshape(-1, 784))
     # Analytical KL divergence D_kl(q(z|x) || p(z)); see Appendix B of
     # Kingma & Welling, "Auto-Encoding Variational Bayes"
     # (https://arxiv.org/abs/1312.6114).
     kl_term = -0.5 * torch.sum(1 + z_var.log() - z_mean.pow(2) - z_var)
     # Normalize the KL by the number of reconstructed pixels per batch.
     kl_term = kl_term / (self.args.batch_size * 784)
     loss = rec_term + kl_term
     if self.mode == TRAIN:
         loss.backward()
         self.optimizer.step()
     return loss.item()
Ejemplo n.º 14
0
def optimize_cnt(worm_img, eig_prev, skel_width, segment_length,  n_epochs = 1000):
    """Refine eigenworm coefficients by SGD so the rendered skeleton map
    matches ``worm_img``.

    Bug fix: the original read the undefined name ``eigen_prev``; the
    parameter is ``eig_prev``.

    Returns (optimized coefficients, last skeleton, last skeleton map).
    """
    # Trainable copies of the previous eigenworm coefficients.
    eig_r = [torch.nn.Parameter(x.data) for x in eig_prev]

    optimizer = optim.SGD(eig_r, lr=0.01)
    for ii in range(n_epochs):
        # Rebuild the skeleton from the coefficients and rasterize it.
        skel_r = _h_eigenworms_inv_T(*eig_r)

        skel_map = get_skel_map(skel_r, skel_width)
        skel_map += 1e-3  # keep probabilities strictly positive for the BCE

        p_w = (skel_map*worm_img)

        skel_map_inv = (-skel_map).add_(1)
        worm_img_inv = (-worm_img).add_(1)
        p_bng = (skel_map_inv*worm_img_inv)

        c_loss = F.binary_cross_entropy(p_w, p_bng)

        # First/second differences along the skeleton for smoothness terms.
        ds = skel_r[1:] - skel_r[:-1]
        dds = ds[1:] - ds[:-1]

        cont_loss = ds.norm(p=2)
        curv_loss = dds.norm(p=2)

        # Penalize segment lengths that deviate from the target length.
        seg_sizes = ((ds).pow(2)).sum(1).sqrt()
        seg_loss = (seg_sizes-segment_length).cosh().mean()

        # Only the image term is active; the regularizers are kept for tuning.
        loss = 50*c_loss #+ seg_loss #+ cont_loss/10 +  curv_loss/10
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ii % 250 == 0:
            print(ii,
                  loss.data[0],
                  c_loss.data[0],
                  seg_loss.data[0]
                  )

    return eig_r, skel_r, skel_map
Ejemplo n.º 15
0
    def dis_step(self, stats):
        """
        Train the discriminator.

        Runs a single discriminator update: BCE between the discriminator's
        predictions and the labels from ``get_dis_xy``, appended to
        ``stats['DIS_COSTS']``.  Aborts the process on NaN, then steps the
        optimizer and clips the discriminator's parameters.
        """
        self.discriminator.train()

        # loss
        # Re-wrapping x.data detaches the inputs so gradients do not flow
        # back into the mapping — only the discriminator is updated here.
        x, y = self.get_dis_xy(volatile=True)
        preds = self.discriminator(Variable(x.data))
        loss = F.binary_cross_entropy(preds, y)
        stats['DIS_COSTS'].append(loss.data[0])

        # check NaN: a value is NaN exactly where it compares unequal to itself
        if (loss != loss).data.any():
            logger.error("NaN detected (discriminator)")
            exit()

        # optim
        self.dis_optimizer.zero_grad()
        loss.backward()
        self.dis_optimizer.step()
        # presumably clamps weights to +/- dis_clip_weights — see clip_parameters
        clip_parameters(self.discriminator, self.params.dis_clip_weights)
Ejemplo n.º 16
0
    def train_epoch(self, epoch = 0):
        """Train the encoder for one epoch over ``self.dataloader``.

        Returns a ``Result`` wrapping the wall-clock duration and the LAST
        batch's total loss (summed BCE reconstruction + KL).

        NOTE(review): the decoder is set to eval() while its parameters may
        still receive updates from the shared optimizer — confirm whether the
        decoder is meant to be frozen here.
        NOTE(review): ``time.clock()`` was removed in Python 3.8 and
        ``loss.data[0]`` is pre-0.4 PyTorch — confirm target versions.
        """

        TINY = 1e-15 #BCE is NaN with an input of 0
        before_time = time.clock()

        self.encoder.train()
        self.decoder.eval()

        # loss_sum/total are initialized but never updated below.
        loss_sum = 0
        total = 0
        for img, label in self.dataloader:
            self.optimizer.zero_grad()

            X = Variable(img)
            if self.use_cuda:
                X = X.cuda()

            z_mu, z_var = self.encoder(X)
            # Reparameterized sample from the approximate posterior.
            z = self.encoder.sample(self.params['batch size'], z_mu, z_var)

            X_reconstructed = self.decoder(z)
            # TINY offsets keep both BCE arguments strictly positive.
            reconstruction_loss = F.binary_cross_entropy(X_reconstructed + TINY, X + TINY, size_average=False)
            # In-place chain computing 1 + log(var) - mu^2 - var, then the
            # standard -0.5 * sum(...) KL term.
            KL_loss = z_mu.pow(2).add_(z_var.exp()).mul_(-1).add_(1).add_(z_var)
            KL_loss = torch.sum(KL_loss).mul_(-0.5)
            total_loss = reconstruction_loss + KL_loss
            total_loss.backward()
            self.optimizer.step()

        duration = time.clock() - before_time

        # Formatter passed to Result for rendering the loss value.
        def loss_reporting(loss):
            return "Loss {}".format(loss.data[0])

        report = Result(duration, total_loss, epoch, loss_reporting)
        #TODO: Structured way to return training results
        return report
Ejemplo n.º 17
0
def recon_loss(x_recon, x):
    """Per-sample reconstruction loss: summed BCE divided by batch size.

    ``reduction='sum'`` replaces the deprecated ``size_average=False``
    keyword argument; the result is unchanged.
    """
    n = x.size(0)
    loss = F.binary_cross_entropy(x_recon, x, reduction='sum').div(n)
    return loss
Ejemplo n.º 18
0
        # audio_data = F.avg_pool1d(audio_data, kernel_size=2, padding=1)
        text_emb = text_embedding(text_data)
        text_pos_emb = pos_embedding_(text_pos)
        enc_out, att_heads_enc = encoder(text_emb, text_mask, text_pos_emb)

        mel_pos_emb = pos_embedding(mel_pos)
        # [B, T, C], [B, T, C], [B, T, 1], [B, T, T_text]
        mels_out, mels_out_post, gates_out, att_heads_dec, att_heads = decoder(
            mel_data, enc_out, mel_mask, text_mask, mel_pos_emb)
        text_len = text_pos.max(1)[0]
        mel_len = mel_pos.max(1)[0]
        loss_mel = torch.sum(
            (mels_out - mel_data)**2) / torch.sum(mel_len * text_len)
        loss_mel_post = torch.sum(
            (mels_out_post - mel_data)**2) / torch.sum(mel_len * text_len)
        loss_gate = F.binary_cross_entropy(gates_out, gate)
        loss = loss_mel + loss_mel_post + loss_gate

        optimizer.zero_grad()
        loss.backward()

        grad_norm_enc = nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        grad_norm_dec = nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)

        optimizer.step()

        # -----------------------------------------

        global_idx += 1
        summ_counter += 1
        mean_losses += [
Ejemplo n.º 19
0
        model.cuda()
    print('Model:', model)
    optimizer = optim.Adam(model.parameters(), lr=.0003)

    for epoch in range(200):
        print("Epoch", epoch)
        model.train()
        train_losses = []
        for batch_idx, (image, _) in enumerate(train_loader):
            image = Variable(image, requires_grad=False)
            if use_cuda:
                image = image.cuda()
            optimizer.zero_grad()
            flat_image = image.view(-1, int(np.prod(image.size()[1:])))
            recon, _ = model(flat_image)
            loss = F.binary_cross_entropy(input=recon, target=flat_image)
            loss.backward()
            optimizer.step()
            train_losses.extend(loss.data)
        print("Train Loss", np.average(train_losses))
        example, _ = train_loader.dataset[0]
        example = Variable(example, requires_grad=False)
        if use_cuda:
            example = example.cuda()
        flat_example = example.view(1, 784)
        example_recon, example_noisy = model(flat_example)
        im = transforms.ToPILImage()(flat_example.view(1,28,28).cpu().data)
        im.save('{!s}_image.png'.format(epoch))
        noisy = transforms.ToPILImage()(example_noisy.view(1, 28, 28).cpu().data)
        noisy.save('{!s}_noisy.png'.format(epoch))
        r_im = transforms.ToPILImage()(example_recon.view(1,28,28).cpu().data)
Ejemplo n.º 20
0
    tic = time.time()
    n_words_proc = 0
    stats = {'SUPER_COSTS': [], 'REBUILD_LOSS': [], 'ENC&DEC_LOSS': []}

    src_enc, tgt_enc, dis_out, src_re_emb, tgt_re_emb, A1, B1 = model(
        src_emb.weight.data, src_adj, tgt_emb.weight.data, tgt_adj, dico,
        params)

    src_vocab = src_emb.weight.shape[0]
    tgt_vocab = tgt_emb.weight.shape[0]
    y = torch.FloatTensor(src_vocab + tgt_vocab).zero_()
    y[:src_vocab] = 1 - params.dis_smooth
    y[src_vocab:] = params.dis_smooth
    y = Variable(y.cuda() if params.cuda else y)

    dis_loss = F.binary_cross_entropy(dis_out, 1 - y)
    rebuild_loss = F.mse_loss(src_emb.weight, src_re_emb) + F.mse_loss(
        tgt_emb.weight, tgt_re_emb)
    super_loss = F.mse_loss(A1, B1)
    # super_loss += F.mse_loss(src_emb.weight, src_re_emb)
    # super_loss += F.mse_loss(tgt_emb.weight, tgt_re_emb)
    stats['SUPER_COSTS'].append(super_loss.item())
    stats['REBUILD_LOSS'].append(rebuild_loss.item())
    stats['ENC&DEC_LOSS'].append(dis_loss.item())

    src_encoder_optimizer.zero_grad()
    tgt_encoder_optimizer.zero_grad()
    src_decoder_optimizer.zero_grad()
    tgt_decoder_optimizer.zero_grad()
    dis_optimizer.zero_grad()
    dis_loss.backward(retain_graph=True)
Ejemplo n.º 21
0
def explainability_regularization_loss(masks):
    """Sum of BCE(mask, 1) over all masks.

    Pushes every explainability mask toward 1 (fully explainable), which
    prevents the trivial all-zero solution.
    """
    total = 0
    for mask in masks:
        target = torch.ones_like(mask)
        total = total + F.binary_cross_entropy(mask, target)
    return total
Ejemplo n.º 22
0
 def __call__(self, x, y):
     """Binary cross-entropy between ``x`` and ``y``.

     Both inputs are coerced to tensors; the result is cached on
     ``self.val`` and returned.
     """
     pred = Metric.convert_to_tensor(x)
     target = Metric.convert_to_tensor(y)
     self.val = F.binary_cross_entropy(pred, target)
     return self.val
# Fixed target labels for the discriminator: ones = real, zeros = fake.
ones_label = Variable(torch.ones(mb_size, 1))
zeros_label = Variable(torch.zeros(mb_size, 1))


for it in range(100000):
    # Sample data
    z = Variable(torch.randn(mb_size, Z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Discriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    # NOTE(review): `nn.binary_cross_entropy` does not exist in torch.nn —
    # this presumably should be `F.binary_cross_entropy`
    # (torch.nn.functional); confirm what `nn` is bound to in this script.
    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    # Generator is trained to make D label its samples as real (ones).
    z = Variable(torch.randn(mb_size, Z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = nn.binary_cross_entropy(D_fake, ones_label)
Ejemplo n.º 24
0
 def _forward(self, pred, target, weight):
     return F.binary_cross_entropy(pred, target, weight=weight)
 def forward(self, y_pred, y_true, beta):
     """Bootstrapped BCE: blends the targets with the predictions.

     The effective target is ``(beta * y_true + (1 - beta) * y_pred) *
     y_true`` — positives are softened toward the prediction, negatives stay
     exactly 0.  Computed under no_grad so the target is treated as a
     constant.  Returns the element-wise (unreduced) BCE.
     """
     with torch.no_grad():
         soft_target = (beta * y_true + (1 - beta) * y_pred) * y_true
     return F.binary_cross_entropy(y_pred, soft_target, reduction='none')
Ejemplo n.º 26
0
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var


model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):

        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print(
                "Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                        reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        z = torch.randn(batch_size, z_dim).to(device)
Ejemplo n.º 27
0
    def pretrain(self,
                 dataloader,
                 pre_epoch=50,
                 retrain=False,
                 metric='cross_entropy'):
        """Pretrain the VaDE autoencoder, then initialize the GMM prior.

        Trains the encoder/decoder halves (plus fc_mu/fc1) for `pre_epoch`
        epochs with a plain reconstruction loss selected by `metric`, stores
        the weights under `.pretrain/vade_pretrain.wght`, and fits the GMM on
        the resulting latent codes.  If pretrained weights already exist and
        `retrain` is falsy, they are simply loaded.

        Fix: the GMM fallback used a bare `except:`, which would also swallow
        KeyboardInterrupt/SystemExit; narrowed to `except Exception:`.
        """
        if (not os.path.exists('.pretrain/vade_pretrain.wght')
            ) or retrain:
            if not os.path.exists('.pretrain/'):
                os.mkdir('.pretrain')
            # Only the autoencoder halves participate in pretraining.
            optimizer = torch.optim.Adam(itertools.chain(self.vade.encoder.parameters(),\
                self.vade.fc_mu.parameters(),\
                    self.vade.fc1.parameters(),\
                        self.vade.decoder.parameters()))

            print('Start pretraining ...')
            self.vade.train()
            for epoch in tqdm(range(pre_epoch)):
                total_loss = []
                n_instances = 0
                for data in dataloader:
                    optimizer.zero_grad()
                    txts, bows = data
                    bows = bows.to(self.device)
                    bows_recon, _mus, _log_vars = self.vade(
                        bows,
                        collate_fn=lambda x: F.softmax(x, dim=1),
                        isPretrain=True)
                    #bows_recon,_mus,_log_vars = self.vade(bows,collate_fn=None,isPretrain=True)
                    if metric == 'cross_entropy':
                        logsoftmax = torch.log_softmax(bows_recon, dim=1)
                        rec_loss = -1.0 * torch.sum(bows * logsoftmax)
                        rec_loss /= len(bows)
                    elif metric == 'bce_softmax':
                        rec_loss = F.binary_cross_entropy(torch.softmax(
                            bows_recon, dim=1),
                                                          bows,
                                                          reduction='sum')
                    elif metric == 'bce_sigmoid':
                        rec_loss = F.binary_cross_entropy(
                            torch.sigmoid(bows_recon), bows, reduction='sum')
                    else:
                        # Fallback: plain mean-squared error.
                        rec_loss = nn.MSELoss()(bows_recon, bows)

                    rec_loss.backward()
                    optimizer.step()
                    total_loss.append(rec_loss.item())
                    n_instances += len(bows)
                print(
                    f'Pretrain: epoch:{epoch:03d}\taverage_loss:{sum(total_loss)/n_instances}'
                )
            # Start the log-variance head from the mean head's weights.
            self.vade.fc_logvar.load_state_dict(self.vade.fc_mu.state_dict())
            print('Initialize GMM parameters ...')
            z_latents = torch.cat([
                self.vade.get_latent(bows.to(self.device))
                for txts, bows in tqdm(dataloader)
            ],
                                  dim=0).detach().cpu().numpy()
            # TBD_corvarance_type
            try:
                self.vade.gmm.fit(z_latents)

                self.vade.pi.data = torch.from_numpy(
                    self.vade.gmm.weights_).to(self.device).float()
                self.vade.mu_c.data = torch.from_numpy(
                    self.vade.gmm.means_).to(self.device).float()
                self.vade.logvar_c.data = torch.log(
                    torch.from_numpy(self.vade.gmm.covariances_)).to(
                        self.device).float()
            except Exception:
                # GMM fitting can fail (e.g. degenerate latents); fall back
                # to a random Dirichlet initialization of the cluster means
                # and unit log-variances.
                self.vade.mu_c.data = torch.from_numpy(
                    np.random.dirichlet(
                        alpha=1.0 * np.ones(self.vade.n_clusters) /
                        self.vade.n_clusters,
                        size=(self.vade.n_clusters,
                              self.vade.latent_dim))).float().to(self.device)
                self.vade.logvar_c.data = torch.ones(
                    self.vade.n_clusters,
                    self.vade.latent_dim).float().to(self.device)

            torch.save(self.vade.state_dict(), '.pretrain/vade_pretrain.wght')
            print(
                'Store the pretrain weights at dir .pretrain/vade_pretrain.wght'
            )

        else:
            self.vade.load_state_dict(
                torch.load('.pretrain/vade_pretrain.wght'))
Ejemplo n.º 28
0
    def train(self,
              train_data,
              batch_size=256,
              learning_rate=2e-3,
              test_data=None,
              num_epochs=100,
              is_evaluate=False,
              log_every=5,
              beta=1.0,
              gamma=1e7,
              criterion='cross_entropy'):
        """Train the VaDE topic model end to end.

        Pretrains the autoencoder, then optimizes reconstruction loss
        (selected by `criterion`) + `beta` * GMM KL divergence + `gamma` *
        mutual distance between cluster centers.  Every `log_every` epochs,
        logs topic words, plots the smoothed training loss, and (if
        `test_data` is given) records coherence/diversity metrics, which are
        plotted at the end.

        NOTE(review): `is_evaluate` is accepted but never used in this body.
        """
        self.vade.train()
        # Reverse vocabulary map (token id -> token) used for topic display.
        self.id2token = {
            v: k
            for k, v in train_data.dictionary.token2id.items()
        }
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=train_data.collate_fn)

        #self.pretrain(data_loader,pre_epoch=30,retrain=True,metric='cross_entropy')
        self.pretrain(data_loader,
                      pre_epoch=30,
                      retrain=True,
                      metric='bce_softmax')

        optimizer = torch.optim.Adam(self.vade.parameters(), lr=learning_rate)
        #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        trainloss_lst, valloss_lst = [], []
        # Per-evaluation metric histories (coherence and topic diversity).
        c_v_lst, c_w2v_lst, c_uci_lst, c_npmi_lst, mimno_tc_lst, td_lst = [], [], [], [], [], []
        for epoch in range(num_epochs):
            epochloss_lst = []
            for iter, data in enumerate(data_loader):
                #optimizer.zero_grad()

                txts, bows = data
                bows = bows.to(self.device)

                bows_recon, mus, log_vars = self.vade(
                    bows,
                    collate_fn=lambda x: F.softmax(x, dim=1),
                    isPretrain=False)
                #bows_recon, mus, log_vars = self.vade(bows,collate_fn=None,isPretrain=False)

                # Reconstruction term, per the chosen criterion.
                if criterion == 'cross_entropy':
                    logsoftmax = torch.log_softmax(bows_recon, dim=1)
                    rec_loss = -1.0 * torch.sum(bows * logsoftmax)
                    rec_loss /= len(bows)
                elif criterion == 'bce_softmax':
                    rec_loss = F.binary_cross_entropy(torch.softmax(bows_recon,
                                                                    dim=1),
                                                      bows,
                                                      reduction='sum')
                elif criterion == 'bce_sigmoid':
                    rec_loss = F.binary_cross_entropy(
                        torch.sigmoid(bows_recon), bows, reduction='sum')

                kl_div = self.vade.gmm_kl_div(mus, log_vars)
                center_mut_dists = self.vade.mus_mutual_distance()

                loss = rec_loss + kl_div * beta + center_mut_dists * gamma

                optimizer.zero_grad()
                loss.backward()
                #nn.utils.clip_grad_norm_(self.vade.parameters(), max_norm=20, norm_type=2)
                optimizer.step()

                trainloss_lst.append(loss.item() / len(bows))
                epochloss_lst.append(loss.item() / len(bows))

                if (iter + 1) % 10 == 0:
                    print(
                        f'Epoch {(epoch+1):>3d}\tIter {(iter+1):>4d}\tLoss:{loss.item()/len(bows):<.7f}\tRec Loss:{rec_loss.item()/len(bows):<.7f}\tGMM_KL_Div:{kl_div.item()/len(bows):<.7f}\tCenter_Mutual_Distance:{center_mut_dists/(len(bows)*(len(bows)-1))}'
                    )
            #scheduler.step()
            if (epoch + 1) % log_every == 0:
                print(
                    f'Epoch {(epoch+1):>3d}\tLoss:{sum(epochloss_lst)/len(epochloss_lst):<.7f}'
                )
                print('\n'.join([str(lst) for lst in self.show_topic_words()]))
                print('=' * 30)
                smth_pts = smooth_curve(trainloss_lst)
                plt.plot(np.array(range(len(smth_pts))) * log_every, smth_pts)
                plt.xlabel('epochs')
                plt.title('Train Loss')
                plt.savefig('gmntm_trainloss.png')
                if test_data != None:
                    c_v, c_w2v, c_uci, c_npmi, mimno_tc, td = self.evaluate(
                        test_data, calc4each=False)
                    c_v_lst.append(c_v), c_w2v_lst.append(
                        c_w2v), c_uci_lst.append(c_uci), c_npmi_lst.append(
                            c_npmi), mimno_tc_lst.append(
                                mimno_tc), td_lst.append(td)
        scrs = {
            'c_v': c_v_lst,
            'c_w2v': c_w2v_lst,
            'c_uci': c_uci_lst,
            'c_npmi': c_npmi_lst,
            'mimno_tc': mimno_tc_lst,
            'td': td_lst
        }
        '''
        for scr_name,scr_lst in scrs.items():
            plt.cla()
            plt.plot(np.array(range(len(scr_lst)))*log_every,scr_lst)
            plt.savefig(f'wlda_{scr_name}.png')
        '''
        plt.cla()
        # Plot only the headline metrics on one shared axis.
        for scr_name, scr_lst in scrs.items():
            if scr_name in ['c_v', 'c_w2v', 'td']:
                plt.plot(np.array(range(len(scr_lst))) * log_every,
                         scr_lst,
                         label=scr_name)
        plt.title('Topic Coherence')
        plt.xlabel('epochs')
        plt.legend()
        plt.savefig(f'gmntm_tc_scores.png')
Ejemplo n.º 29
0
def Train():
	"""
	Train netG and netD alternately (standard DCGAN scheme).

	Returns:
		(G_losses, D_losses, img_list): per-iteration generator and
		discriminator losses, plus image grids produced from a fixed
		latent batch for qualitative progress tracking.
	"""
	# Lists to keep track of progress
	img_list = []
	G_losses = []
	D_losses = []
	# A fixed noise batch lets us watch the same latent points evolve.
	fixed_latent_vectors = torch.randn(64, config.nz, 1, 1, device=device)
	iters = 0
	# alias
	real_label = 1
	fake_label = 0

	print("Starting Training Loop...")
	# For each epoch
	for epoch in range(config.num_epochs):
		# For each batch in the dataloader
		for i, data in enumerate(dataloader, 0):
			##################################
			# Update Discriminator
			# Discriminator Loss: (maximize)
			# log(D(x)) + log(1 - D(G(z)))
			##################################
			netD.zero_grad()
			# Calculate log(D(x))
			real_images = data[0].to(device)  # x_train
			batch_size = real_images.size(0)
			label_real = torch.full((batch_size, ), real_label,
									device=device, dtype=torch.float32)
			pred_real = netD(real_images).view(-1)  # to 1-d tensor
			loss_D_real = F.binary_cross_entropy(pred_real, label_real)
			D_x = pred_real.mean().item()  # for logging stuff
			# Note that BCE Loss is -(y * log(x) + (1-y) * log(1-x))
			# so minimize it is to maximize log(x) when y = 1
			loss_D_real.backward()
			# Calculate log(1 - D(G(z)))
			latent_vectors = torch.randn(
				batch_size, config.nz, 1, 1, device=device)  # 1, 1 for Conv2d
			fake_images = netG(latent_vectors)
			label_fake = torch.full((batch_size, ), fake_label,
									device=device, dtype=torch.float32)
			# BUGFIX: detach the fake batch so the discriminator's backward
			# pass does not build/propagate gradients through the generator.
			pred_fake = netD(fake_images.detach()).view(-1)
			D_G_z1 = pred_fake.mean().item()
			loss_D_fake = F.binary_cross_entropy(pred_fake, label_fake)
			loss_D_fake.backward()
			# apply gradients
			loss_D = loss_D_real + loss_D_fake
			netD.optimizer.step()

			##################################
			# Update Generator
			# Generator Loss: (maximize)
			# 		log(D(G(z)))
			# not minimizing log(1-D(G(z))) because of
			# providing no sufficient gradients.
			##################################
			netG.zero_grad()
			latent_vectors = torch.randn(
				batch_size, config.nz, 1, 1, device=device)
			fake_images = netG(latent_vectors)
			# we use real label because we want to maximize log(D(G(z)))
			label_real = torch.full((batch_size, ), real_label,
									device=device, dtype=torch.float32)
			pred_fake = netD(fake_images).view(-1)
			D_G_z2 = pred_fake.mean().item()
			loss_G = F.binary_cross_entropy(pred_fake, label_real)
			loss_G.backward()
			# apply gradients
			netG.optimizer.step()

			# Output training stats
			if i % 50 == 0:
				print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
					  'D(x): %.4f\tD(G(z)): %.4f / %.4f'
					  % (epoch, config.num_epochs, i, len(dataloader),
						 loss_D.item(), loss_G.item(), D_x, D_G_z1, D_G_z2))

			# Save losses for plotting later
			G_losses.append(loss_G.item())
			D_losses.append(loss_D.item())

			# Check how the generator is doing by saving G's
			# output on fixed latent vector
			if (iters % 500 == 0) or \
					((epoch == config.num_epochs - 1) and (i == len(dataloader) - 1)):
				with torch.no_grad():
					fake_images = netG(fixed_latent_vectors).detach().cpu()
					img_list.append(
						vutils.make_grid(
							fake_images,
							padding=2,
							normalize=True))

			iters += 1

	return G_losses, D_losses, img_list
Ejemplo n.º 30
0
 def adj_recon_loss(self, adj_truth, adj_pred):
     """Mean BCE between the predicted and ground-truth adjacency matrices."""
     # BUGFIX: F.binary_cross_entropy takes (input, target); the prediction
     # must be the first argument, the ground truth the second.
     return F.binary_cross_entropy(adj_pred, adj_truth)
Ejemplo n.º 31
0
 def forward(self, pred, label):
     """Weighted sum of per-output BCE losses.

     Args:
         pred: sequence of prediction tensors (e.g. deep-supervision heads),
             each compared against the same ``label``.
         label: target tensor shared by all outputs.

     Returns:
         Scalar tensor: sum of self.weights[k] * BCE(pred[k], label).
     """
     loss = self.weights[0] * F.binary_cross_entropy(pred[0], label)
     # BUGFIX: enumerate over pred[1:] starts at 0, so the original code
     # re-used weights[0] for pred[1]; start=1 aligns weight k with pred[k].
     for i, x in enumerate(pred[1:], start=1):
         loss += self.weights[i] * F.binary_cross_entropy(x, label)
     return loss
Ejemplo n.º 32
0
# Conditional-GAN training loop fragment.
# NOTE: mb_size, Z_dim, ones_label, G, D, D_solver, reset_grad and mnist
# are defined earlier in the original script.
zeros_label = Variable(torch.zeros(mb_size, 1))


for it in range(100000):
    # Sample data
    z = Variable(torch.randn(mb_size, Z_dim))
    X, c = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))
    c = Variable(torch.from_numpy(c.astype('float32')))

    # Discriminator forward-loss-backward-update
    G_sample = G(z, c)
    D_real = D(X, c)
    D_fake = D(G_sample, c)

    # BUGFIX: binary_cross_entropy lives in torch.nn.functional (F), not in
    # torch.nn; nn.binary_cross_entropy raises AttributeError at runtime.
    D_loss_real = F.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = Variable(torch.randn(mb_size, Z_dim))
    G_sample = G(z, c)
    D_fake = D(G_sample, c)

    G_loss = F.binary_cross_entropy(D_fake, ones_label)
Ejemplo n.º 33
0
    def _assign(self,
                pred_scores,
                priors,
                decoded_bboxes,
                gt_bboxes,
                gt_labels,
                gt_bboxes_ignore=None,
                eps=1e-7):
        """Assign gt to priors using SimOTA.
        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes]
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, xy, stride_w, stride_y] format.
            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            eps (float): A value added to the denominator for numerical
                stability. Default 1e-7.
        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        # Cost added to prior/gt pairs whose prior center is not inside both
        # the gt box and its center region, making them effectively unmatchable.
        INF = 100000000
        num_gt = gt_bboxes.size(0)
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 by default (0 means background / not assigned to any gt)
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
                                                   0,
                                                   dtype=torch.long)
        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                # -1 marks priors whose label is unknown/unset
                assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
                                                          -1,
                                                          dtype=torch.long)
            return AssignResult(num_gt,
                                assigned_gt_inds,
                                max_overlaps,
                                labels=assigned_labels)

        # valid_mask: priors that are candidates at all;
        # is_in_boxes_and_center: per prior/gt pair, inside box AND center area.
        valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
            priors, gt_bboxes)

        # Only candidate priors take part in the (expensive) cost computation.
        valid_decoded_bbox = decoded_bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        num_valid = valid_decoded_bbox.size(0)

        # Localization cost: -log(IoU); eps guards against log(0).
        pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)
        iou_cost = -torch.log(pairwise_ious + eps)

        # One-hot gt labels broadcast to every valid prior:
        # shape (num_valid, num_gt, num_classes).
        gt_onehot_label = (F.one_hot(gt_labels.to(
            torch.int64), pred_scores.shape[-1]).float().unsqueeze(0).repeat(
                num_valid, 1, 1))

        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
        # NOTE: sqrt_() is in-place — valid_pred_scores is modified here
        # before the BCE classification cost is computed.
        cls_cost = F.binary_cross_entropy(valid_pred_scores.sqrt_(),
                                          gt_onehot_label,
                                          reduction='none').sum(-1)

        # Combined cost; pairs failing the box+center test are pushed to +INF.
        cost_matrix = (cls_cost * self.cls_weight +
                       iou_cost * self.iou_weight +
                       (~is_in_boxes_and_center) * INF)

        matched_pred_ious, matched_gt_inds = \
            self.dynamic_k_matching(
                cost_matrix, pairwise_ious, num_gt, valid_mask)

        # convert to AssignResult format (gt indices are 1-based; 0 = bg)
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(num_gt,
                            assigned_gt_inds,
                            max_overlaps,
                            labels=assigned_labels)
Ejemplo n.º 34
0
def get_loss(task: Task, logits, label_ids, config, extra_arg=None, input_head=None, **kwargs):
  """Dispatch and compute the training loss for ``task``.

  Args:
    task: task descriptor; ``task.name`` selects the loss branch.
    logits: model outputs; shape/structure depends on the task (a tuple for
      joint_srl and the default span branch).
    label_ids: targets; structure depends on the task.
    config: model/run configuration (e.g. ``config.regression``).
    extra_arg: optional dict of extra arguments (currently unused here).
    input_head: optional (batch, seq) 0/1 tensor marking head tokens for
      token-level tasks.
    **kwargs: may carry ``_label_ids``, ``_input_token_mask``, ``mask``,
      ``target`` depending on the task.

  Returns:
    Scalar loss tensor.
  """
  # BUGFIX: the default used to be the mutable ``dict({})`` shared across
  # calls; a ``None`` sentinel avoids the mutable-default pitfall.
  if extra_arg is None:
    extra_arg = {}
  if task.name.startswith(IR_TASK):
    if label_ids is None:
      label_ids = kwargs.pop("_label_ids")
    if config.regression:
      return F.binary_cross_entropy_with_logits(logits.squeeze(1), label_ids)
    else:
      return F.cross_entropy(logits, label_ids)
  if task.name.startswith(DOCIR_TASK):
    if config.regression:
      return F.binary_cross_entropy_with_logits(logits.squeeze(1), label_ids[0].unsqueeze(0))
    else:
      return F.cross_entropy(logits, label_ids[0].unsqueeze(0))
  if task.name in ["joint_srl"]:
    if isinstance(label_ids, tuple):
      label_ids, pred_span_label, arg_span_label, pos_tag_ids = label_ids
    batch_size = label_ids.size(0)
    # Sparse labels: first 4 columns are (pred_start, pred_end, arg_start,
    # arg_end) span keys, column 4 is the relation label value.
    key = label_ids[:, :, :4]
    batch_id = torch.arange(0, batch_size).unsqueeze(1).unsqueeze(1).repeat(1, key.size(1), 1).to(key.device)
    expanded_key = torch.cat([batch_id, key], dim=-1)
    v = label_ids[:, :, 4]
    # batch_id = torch.arange(0, v.size(0)).unsqueeze(1).repeat(1, v.size(1)).to(key.device)
    # expanded_v = torch.cat([batch_id.unsqueeze(-1), v.unsqueeze(-1)], dim=-1)
    flatten_key = expanded_key.view(-1, 5)
    flatten_v = v.view(-1)

    (srl_scores, top_pred_spans, top_arg_spans, top_pred_span_mask, top_arg_span_mask,
     pred_span_mention_full_scores, arg_span_mention_full_scores, pos_tag_logits) = logits
    # (batch_size, max_pred_num, 2), (batch_size, max_arg_num, 2)
    max_pred_num = top_pred_spans.size(1)
    max_arg_num = top_arg_spans.size(1)
    # Cartesian product of predicted predicate spans and argument spans.
    expanded_top_pred_spans = top_pred_spans.unsqueeze(2).repeat(1, 1, max_arg_num, 1)
    expanded_top_arg_spans = top_arg_spans.unsqueeze(1).repeat(1, max_pred_num, 1, 1)
    indices = torch.cat([expanded_top_pred_spans, expanded_top_arg_spans], dim=-1)
    batch_id = torch.arange(0, batch_size).unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, *indices.size()[1:3], 1).to(
      indices.device)
    expanded_indices = torch.cat([batch_id, indices], dim=-1)
    flatten_expanded_indices = expanded_indices.view(-1, 5)

    # Generate Loss Mask
    flatten_expanded_top_pred_span_mask = top_pred_span_mask.unsqueeze(2).repeat(1, 1, max_arg_num).view(-1)
    # (batch_size, max_pred_to_keep)
    flatten_expanded_top_arg_span_mask = top_arg_span_mask.unsqueeze(2).repeat(1, max_pred_num, 1).view(-1)
    # (batch_size, max_arg_to_keep)
    merged_mask = flatten_expanded_top_pred_span_mask & flatten_expanded_top_arg_span_mask

    # build dictionary: (batch, pred_span, arg_span) key -> gold label
    d = {}
    for key, value in zip(flatten_key.cpu().numpy(), flatten_v.cpu().numpy()):
      d[tuple(key)] = value

    # Gold label for each predicted pair; 0 (no relation) when unseen.
    label_list = []
    for index in flatten_expanded_indices.cpu().numpy():
      label_list.append(d.get(tuple(index), 0))

    # arg_boundary = max(torch.max(top_arg_spans).item(), torch.max(key[:,:,2:]).item()) + 1
    # pred_boundary= max(torch.max(top_pred_spans).item(), torch.max(key[:,:,:2]).item()) + 1
    # size = (batch_size, pred_boundary, pred_boundary, arg_boundary, arg_boundary)
    #
    # dense_label = torch.sparse.LongTensor(flatten_key.t(), flatten_v, size).to(key.device)
    #
    # selected_label = dense_label.masked_select(expanded_indices)

    selected_label = torch.LongTensor(label_list).to(label_ids.device)

    label_loss = F.cross_entropy(srl_scores.view(-1, srl_scores.size(-1))[merged_mask == 1], selected_label[merged_mask == 1])

    # Compute the unary scorer loss
    if hasattr(task.config, "srl_candidate_loss") and task.config.srl_candidate_loss:
      # torch.sigmoid replaces the deprecated F.sigmoid (same computation).
      flatten_pred_span_mention_full_scores = torch.sigmoid(pred_span_mention_full_scores).view(-1)
      flatten_arg_span_mention_full_scores = torch.sigmoid(arg_span_mention_full_scores).view(-1)
      flatten_pred_span_label = pred_span_label.view(-1).float()
      flatten_arg_span_label = arg_span_label.view(-1).float()
      srl_pred_candidate_loss = F.binary_cross_entropy(flatten_pred_span_mention_full_scores, flatten_pred_span_label)
      srl_arg_candidate_loss = F.binary_cross_entropy(flatten_arg_span_mention_full_scores, flatten_arg_span_label)
      candidate_loss = srl_pred_candidate_loss + srl_arg_candidate_loss
      return candidate_loss + label_loss

    if hasattr(task.config, "srl_compute_pos_tag_loss") and task.config.srl_compute_pos_tag_loss:
      active_loss = input_head[:, :pos_tag_logits.size(1)].contiguous().view(-1) == 1
      # I use pos_tag_logits.size(1) to get the label length. It is OK to filter extra things
      active_pos_tag_logits = pos_tag_logits.view(-1, pos_tag_logits.size(-1))[active_loss]
      active_labels = pos_tag_ids[:, :pos_tag_logits.size(1)].contiguous().view(-1)[active_loss]
      loss = F.cross_entropy(active_pos_tag_logits, active_labels)
      label_loss += loss

    return label_loss
  elif task.name in [NER_TASK, POS_TASK, PIPE_SRL_TASK, PREDICATE_DETECTION_TASK]:
    if input_head is not None:
      # Use the BERT based model
      active_loss = input_head[:, :logits.size(1)].contiguous().view(-1) == 1
      # I use logits.size(1) to get the label length. It is OK to filter extra things
    else:
      # create a mask based on the sequence length
      active_loss = kwargs.pop("_input_token_mask").view(-1)
      # We need to assign the value of _label_ids to label_ids
      label_ids = kwargs.pop("_label_ids")
    active_logits = logits.view(-1, logits.size(-1))[active_loss]
    active_labels = label_ids[:, :logits.size(1)].contiguous().view(-1)[active_loss]
    loss = F.cross_entropy(active_logits, active_labels)

    return loss
  elif task.name in [ENTITY_TYPE_CLASSIFICATION]:
    # Multi-label classification: one independent sigmoid per type.
    loss = F.binary_cross_entropy_with_logits(logits, label_ids.float())
    return loss

  elif task.name in [PARALLEL_TEACHER_STUDENT_TASK]:
    active_loss = kwargs.pop("mask")
    if active_loss is not None:
      active_loss = kwargs.pop("mask").view(-1)
      target = kwargs.pop("target")
      # KL divergence between student and teacher token distributions.
      active_logits = F.softmax(logits.view(-1, logits.size(-1)), -1)[active_loss]
      active_target = F.softmax(target.view(-1, target.size(-1)), -1)[active_loss]
      loss = F.kl_div(active_logits.log(), active_target)
    else:
      target = kwargs.pop("target")
      if task.config.use_cosine_loss:
        loss = (1 - F.cosine_similarity(logits, target)).sum(-1) / logits.size(0)
      else:
        loss = F.l1_loss(logits, target)
    return loss
  elif task.name in [MIXSENT_TASK]:
    target = kwargs.pop("target")
    # loss = F.mse_loss(logits, target)
    active_logits = F.softmax(logits.view(-1, logits.size(-1)), -1)
    active_target = F.softmax(target.view(-1, target.size(-1)), -1)
    loss = F.kl_div(active_logits.log(), active_target, reduction="batchmean")
    return loss
  else:
    # Default: span extraction — logits come paired with a span boundary.
    span_boundary, logits = logits
    return F.cross_entropy(logits.view(-1, logits.size(-1)), label_ids.view(-1))
Ejemplo n.º 35
0
# Binary logistic regression: a single linear layer + sigmoid trained with
# full-batch SGD on a toy 2-feature dataset.
# BUGFIX: the snippet used `torch` and `nn` without importing them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)  # reproducible weight initialization
x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]]
y_data = [[0], [0], [0], [1], [1], [1]]
x_train = torch.FloatTensor(x_data)
y_train = torch.FloatTensor(y_data)

# H(x) = sigmoid(x1*w1 + x2*w2 + b)
model = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())

optimizer = optim.SGD(model.parameters(), lr=1)

epochs = 1000
for epoch in range(epochs + 1):
    h = model(x_train)
    cost = F.binary_cross_entropy(h, y_train)

    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

    if epoch % 10 == 0:
        # Threshold the probabilities at 0.5 to get hard predictions.
        prediction = h >= torch.FloatTensor([0.5])
        correct_prediction = prediction.float() == y_train
        accuracy = correct_prediction.sum().item() / len(correct_prediction)
        print('Epoch: {:4d}/{}, Cost: {:.6f} Accuracy {:2.2f}%'.format(
            epoch, epochs, cost.item(), accuracy * 100))

# Logistic regression can be regarded as a (single-layer) neural network.
# As a formula, logistic regression is H(x) = sigmoid(x1*w1 + x2*w2 + b).
def train(model, env, args):
    """Train ``model`` on batches streamed from ``env``, plotting to visdom.

    The environment yields ((states_S, states_G, rewards), require_init)
    tuples; the model predicts per-step values which are fit to the binary
    rewards with BCE. Per-replay prediction accuracy over 10 temporal
    segments is tracked with an exponential moving average and plotted.

    NOTE(review): ``vis.updateTrace`` was removed in newer visdom releases
    (replaced by ``vis.line(..., update='append')``) — confirm the pinned
    visdom version before upgrading.
    """
    #################################### PLOT ###################################################
    STEPS = 10          # number of temporal segments for the accuracy curves
    LAMBDA = 0.99       # EMA decay for the per-segment accuracy estimate
    vis = visdom.Visdom(env=args.name+'[{}]'.format(args.phrase))
    pre_per_replay = [[] for _ in range(args.n_replays)]   # predictions per replay
    gt_per_replay = [[] for _ in range(args.n_replays)]    # ground truth per replay
    acc = None
    win = vis.line(X=np.zeros(1), Y=np.zeros(1))
    loss_win = vis.line(X=np.zeros(1), Y=np.zeros(1))

    #################################### TRAIN ######################################################
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    gpu_id = args.gpu_id
    with torch.cuda.device(gpu_id):
        model = model.cuda() if gpu_id >= 0 else model
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    epoch = 0
    save = args.save_intervel
    env_return = env.step()
    # NOTE(review): if the very first env.step() returns None, states_S below
    # is referenced before assignment (NameError) — confirm env always yields
    # at least one batch before entering this function.
    if env_return is not None:
        (states_S, states_G, rewards), require_init = env_return
    with torch.cuda.device(gpu_id):
        states_S = torch.from_numpy(states_S).float()
        states_G = torch.from_numpy(states_G).float()
        rewards = torch.from_numpy(rewards).float()
        if gpu_id >= 0:
            states_S = states_S.cuda()
            states_G = states_G.cuda()
            rewards = rewards.cuda()

    while True:
        values = model(Variable(states_S), Variable(states_G), require_init)

        # Sum of per-step BCE losses between predicted values and rewards.
        value_loss = 0
        for value, reward in zip(values, rewards):
            value_loss = value_loss + F.binary_cross_entropy(value, Variable(reward))

        model.zero_grad()
        value_loss.backward()
        optimizer.step()
        model.detach()

        # Halve the learning rate once per environment epoch.
        if env.epoch > epoch:
            epoch = env.epoch
            for p in optimizer.param_groups:
                p['lr'] *= 0.5

        ############################ PLOT ##########################################
        vis.updateTrace(X=np.asarray([env.step_count()]),
                        Y=np.asarray(value_loss.data.cpu().numpy()),
                        win=loss_win,
                        name='value')

        # Reorder to (replay, time) for per-replay bookkeeping.
        values_np = np.swapaxes(np.asarray([value.data.cpu().numpy() for value in values]), 0, 1)
        rewards_np = np.swapaxes(rewards.cpu().numpy(), 0, 1)

        for idx, (value, reward, init) in enumerate(zip(values_np, rewards_np, require_init)):
            # A replay just finished: score its accumulated predictions.
            if init and len(pre_per_replay[idx]) > 0:
                pre_per_replay[idx] = np.asarray(pre_per_replay[idx], dtype=np.uint8)
                gt_per_replay[idx] = np.asarray(gt_per_replay[idx], dtype=np.uint8)

                # Split the replay into STEPS equal segments and compute the
                # prediction accuracy within each segment.
                step = len(pre_per_replay[idx]) // STEPS
                if step > 0:
                    acc_tmp = []
                    for s in range(STEPS):
                        value_pre = pre_per_replay[idx][s*step:(s+1)*step]
                        value_gt = gt_per_replay[idx][s*step:(s+1)*step]
                        acc_tmp.append(np.mean(value_pre == value_gt))

                    acc_tmp = np.asarray(acc_tmp)
                    if acc is None:
                        acc = acc_tmp
                    else:
                        acc = LAMBDA * acc + (1-LAMBDA) * acc_tmp

                    if acc is None:
                        continue
                    for s in range(STEPS):
                        vis.updateTrace(X=np.asarray([env.step_count()]),
                                        Y=np.asarray([acc[s]]),
                                        win=win,
                                        name='{}[{}%~{}%]'.format('value', s*10, (s+1)*10))
                    vis.updateTrace(X=np.asarray([env.step_count()]),
                                    Y=np.asarray([np.mean(acc)]),
                                    win=win,
                                    name='value[TOTAL]')

                pre_per_replay[idx] = []
                gt_per_replay[idx] = []

            # Record the binarized prediction and the ground-truth reward.
            pre_per_replay[idx].append(int(value[-1] >= 0.5))
            gt_per_replay[idx].append(int(reward[-1]))

        ####################### NEXT BATCH ###################################
        env_return = env.step()
        if env_return is not None:
            (raw_states_S, raw_states_G, raw_rewards), require_init = env_return
            # copy_ reuses the already-allocated (possibly GPU) buffers.
            states_S = states_S.copy_(torch.from_numpy(raw_states_S).float())
            states_G = states_G.copy_(torch.from_numpy(raw_states_G).float())
            rewards = rewards.copy_(torch.from_numpy(raw_rewards).float())

        # Periodic checkpoint plus a rolling 'latest' snapshot.
        if env.step_count() > save or env_return is None:
            save = env.step_count()+args.save_intervel
            torch.save(model.state_dict(),
                       os.path.join(args.model_path, 'model_iter_{}.pth'.format(env.step_count())))
            torch.save(model.state_dict(), os.path.join(args.model_path, 'model_latest.pth'))
        if env_return is None:
            env.close()
            break
Ejemplo n.º 37
0
 def cross_entropy(self, x, y):
     """Return the binary cross-entropy between ``x`` and ``y``, summed
     over all elements (no averaging)."""
     bce_sum = F.binary_cross_entropy(x, y, reduction='sum')
     return bce_sum
Ejemplo n.º 38
0
    def loss(self, outputs, labels, phase):
        """Aggregate the batch loss and update running loss statistics.

        Combines (a) a tau-discounted, validity-weighted L2 loss on predicted
        boxes, (b) an optional mask BCE term, (c) an optional sequence-score
        BCE term, (d) an optional VAE KL term, and (e) an indicator
        cross-entropy loss; per-step sums are accumulated into self.losses.

        Args:
            outputs: dict with 'boxes' and, depending on config, 'masks',
                'score', 'kl', 'pred_indicators'.
            labels: dict with 'boxes', 'valid', and matching 'masks',
                'seq_l', 'gt_indicators'.
            phase: 'train' or 'test'; selects self.ptrain_size /
                self.ptest_size as the prediction horizon.

        Returns:
            Scalar loss tensor.
        """
        self.loss_cnt += labels['boxes'].shape[0]
        # NOTE(review): eval() on an f-string reads self.ptrain_size or
        # self.ptest_size; getattr(self, f'p{phase}_size') would be safer,
        # but changing it is out of scope for a doc-only pass.
        pred_size = eval(f'self.p{phase}_size')
        # calculate bbox loss
        # of shape (batch, time, #obj, 4)
        loss = (outputs['boxes'] - labels['boxes'])**2
        # take weighted sum over axis 2 (objs dim) since some index are not valid
        valid = labels['valid'][:, None, :, None]
        loss = loss * valid
        loss = loss.sum(2) / valid.sum(2)
        loss *= self.position_loss_weight

        # Accumulate position (first 2 coords) / size (last 2) per-step sums.
        for i in range(pred_size):
            self.box_p_step_losses[i] += loss[:, i, :2].sum().item()
            self.box_s_step_losses[i] += loss[:, i, 2:].sum().item()

        # p_1/s_1: within the training horizon; p_2/s_2: beyond it (rollout).
        self.losses['p_1'] = float(
            np.mean(self.box_p_step_losses[:self.ptrain_size]))
        self.losses['p_2'] = float(np.mean(self.box_p_step_losses[self.ptrain_size:])) \
            if self.ptrain_size < self.ptest_size else 0
        self.losses['s_1'] = float(
            np.mean(self.box_s_step_losses[:self.ptrain_size]))
        self.losses['s_2'] = float(np.mean(self.box_s_step_losses[self.ptrain_size:])) \
            if self.ptrain_size < self.ptest_size else 0

        mask_loss = 0
        if C.RIN.MASK_LOSS_WEIGHT > 0:
            # of shape (batch, time, #obj, m_sz, m_sz)
            mask_loss_ = F.binary_cross_entropy(outputs['masks'],
                                                labels['masks'],
                                                reduction='none')
            mask_loss = mask_loss_.mean((3, 4))
            valid = labels['valid'][:, None, :]
            mask_loss = mask_loss * valid
            mask_loss = mask_loss.sum(2) / valid.sum(2)

            for i in range(pred_size):
                self.masks_step_losses[i] += mask_loss[:, i].sum().item()

            m1_loss = self.masks_step_losses[:self.ptrain_size]
            m2_loss = self.masks_step_losses[self.ptrain_size:]
            self.losses['m_1'] = np.mean(m1_loss)
            self.losses['m_2'] = np.mean(
                m2_loss) if self.ptrain_size < self.ptest_size else 0

            # Discount later timesteps: tau anneals from DISCOUNT_TAU^(1/T)
            # toward 1 over the course of training.
            mask_loss = mask_loss.mean(0)
            init_tau = C.RIN.DISCOUNT_TAU**(1 / self.ptrain_size)
            tau = init_tau + (self.iterations / self.max_iters) * (1 -
                                                                   init_tau)
            tau = torch.pow(
                tau, torch.arange(pred_size,
                                  out=torch.FloatTensor()))[:, None].to('cuda')
            mask_loss = ((mask_loss * tau) /
                         tau.sum(axis=0, keepdims=True)).sum()
            mask_loss = mask_loss * C.RIN.MASK_LOSS_WEIGHT

        seq_loss = 0
        if C.RIN.SEQ_CLS_LOSS_WEIGHT > 0:
            seq_loss = F.binary_cross_entropy(outputs['score'],
                                              labels['seq_l'],
                                              reduction='none')
            self.losses['seq'] += seq_loss.sum().item()
            seq_loss = seq_loss.mean() * C.RIN.SEQ_CLS_LOSS_WEIGHT

        kl_loss = 0
        if C.RIN.VAE and phase == 'train':
            kl_loss = outputs['kl']
            self.losses['kl'] += kl_loss.sum().item()
            kl_loss = C.RIN.VAE_KL_LOSS_WEIGHT * kl_loss.sum()

        # no need to do precise batch statistics, just do mean for backward gradient
        loss = loss.mean(0)
        init_tau = C.RIN.DISCOUNT_TAU**(1 / self.ptrain_size)
        tau = init_tau + (self.iterations / self.max_iters) * (1 - init_tau)
        tau = torch.pow(tau,
                        torch.arange(pred_size,
                                     out=torch.FloatTensor()))[:,
                                                               None].to('cuda')
        loss = ((loss * tau) / tau.sum(axis=0, keepdims=True)).sum()
        loss = loss + mask_loss + kl_loss + seq_loss

        # **************************************************************************************************************************** #
        # INDICATOR LOSS
        # **************************************************************************************************************************** #
        gt, pred = labels['gt_indicators'], outputs['pred_indicators']

        # # bce loss
        # valid1 = labels['valid'][:, None, :, None]    # (b, 1, 6, 1)
        # valid2 = labels['valid'][:, None, None, :]    # (b, 1, 1, 6)
        # ind_loss = self.indicator_criterion_bce(pred, gt)    # (b, 5, 6, 6)
        # ind_loss = ind_loss * valid1 * valid2    # mask out invalid rows & columns
        # ind_loss = ind_loss.sum(2) / valid1.sum(2)        # loss is mean of all valid objects (b, 5)

        # cross entropy loss
        valid = labels['valid'][:, None, :]  # (b, 1, 6)
        ind_loss = self.indicator_criterion_cross(
            pred.reshape(-1, 6),
            torch.argmax(gt, 3).flatten()).reshape(-1, pred_size,
                                                   gt.shape[2])  # (b, 5, 6)
        ind_loss = ind_loss * valid  # mask out invalid rows (objects)
        ind_loss = ind_loss.sum(2) / valid.sum(
            2)  # loss is mean of all valid objects (b, 5)

        # Same tau discounting as the box loss, then weight by self.ind.
        self.loss_ind = ((ind_loss.mean(0) * tau) /
                         tau.sum(axis=0, keepdims=True)).sum()
        loss += self.ind * self.loss_ind
        # **************************************************************************************************************************** #

        return loss
Ejemplo n.º 39
0
def bce_loss(y_pred, y_true):
    """Summed binary cross-entropy reconstruction loss.

    Args:
        y_pred: predicted pixel probabilities in [0, 1], shape (N, 784).
        y_true: target images; flattened to (N, 784) before comparison.

    Returns:
        Scalar tensor: BCE summed over all elements (no averaging).
    """
    # BUGFIX: `size_average=False` is deprecated/removed in modern PyTorch;
    # `reduction='sum'` is the equivalent replacement.
    BCE = F.binary_cross_entropy(y_pred, y_true.view(-1, 784), reduction='sum')
    return BCE
Ejemplo n.º 40
0
    def test_focalloss(self):
        """
        Test some predefined focal loss values.

        Checks that the delira focal losses (with and without logits)
        reduce to plain BCE when gamma=0/alpha=None, match manually
        pre-computed weighted-BCE values, and support backward().
        """

        from delira.training.losses import BCEFocalLossLogitPyTorch, \
            BCEFocalLossPyTorch
        import torch.nn as nn
        import torch
        import torch.nn.functional as F

        # examples
        ########################################################################
        # binary values
        p = torch.Tensor([[0, 0.2, 0.5, 1.0], [0, 0.2, 0.5, 1.0]])
        t = torch.Tensor([[0, 0, 0, 0], [1, 1, 1, 1]])
        p_l = torch.Tensor([[-2, -1, 0, 2], [-2, -1, 0, 1]])

        ########################################################################
        # params
        gamma = 2
        alpha = 0.25
        eps = 1e-8

        ########################################################################
        # compute targets
        # target for focal loss: FL = alpha_t * (1 - p_t)^gamma * BCE
        p_t = p * t + (1 - p) * (1 - t)
        alpha_t = torch.Tensor([alpha]).expand_as(t) * t + \
            (1 - t) * (1 - torch.Tensor([alpha]).expand_as(t))
        w = alpha_t * (1 - p_t).pow(torch.Tensor([gamma]))
        fc_value = F.binary_cross_entropy(p, t, w, reduction='none')

        # target for focal loss with logit
        p_tmp = torch.sigmoid(p_l)
        p_t = p_tmp * t + (1 - p_tmp) * (1 - t)
        alpha_t = torch.Tensor([alpha]).expand_as(t) * t + \
            (1 - t) * (1 - torch.Tensor([alpha]).expand_as(t))
        w = alpha_t * (1 - p_t).pow(torch.Tensor([gamma]))

        fc_value_logit = \
            F.binary_cross_entropy_with_logits(p_l, t, w, reduction='none')

        ########################################################################
        # test against BCE and CE =>focal loss with gamma=0, alpha=None
        # test against binary_cross_entropy
        bce = nn.BCELoss(reduction='none')
        focal = BCEFocalLossPyTorch(alpha=None, gamma=0, reduction='none')
        bce_loss = bce(p, t)
        focal_loss = focal(p, t)

        self.assertTrue((torch.abs(bce_loss - focal_loss) < eps).all())

        # test against binary_cross_entropy with logit
        bce = nn.BCEWithLogitsLoss()
        focal = BCEFocalLossLogitPyTorch(alpha=None, gamma=0)
        bce_loss = bce(p_l, t)
        focal_loss = focal(p_l, t)
        self.assertTrue((torch.abs(bce_loss - focal_loss) < eps).all())

        ########################################################################
        # test focal loss with pre computed values
        # test focal loss binary (values manually pre computed)
        focal = BCEFocalLossPyTorch(gamma=gamma, alpha=alpha, reduction='none')
        focal_loss = focal(p, t)
        self.assertTrue((torch.abs(fc_value - focal_loss) < eps).all())

        # test focal loss binary with logit (values manually pre computed)
        # Note that now p_l is used as prediction
        focal = BCEFocalLossLogitPyTorch(gamma=gamma,
                                         alpha=alpha,
                                         reduction='none')
        focal_loss = focal(p_l, t)
        self.assertTrue((torch.abs(fc_value_logit - focal_loss) < eps).all())

        ########################################################################
        # test if backward function works
        p.requires_grad = True
        focal = BCEFocalLossPyTorch(gamma=gamma, alpha=alpha)
        focal_loss = focal(p, t)
        try:
            focal_loss.backward()
        # BUGFIX: a bare `except:` also swallows SystemExit/KeyboardInterrupt;
        # catch Exception so test interruption still works.
        except Exception:
            self.assertTrue(False, "Backward function failed for focal loss")

        p_l.requires_grad = True
        focal = BCEFocalLossLogitPyTorch(gamma=gamma, alpha=alpha)
        focal_loss = focal(p_l, t)
        try:
            focal_loss.backward()
        except Exception:
            self.assertTrue(
                False, "Backward function failed for focal loss with logits")
Ejemplo n.º 41
0
 def training_step(self, batch, batch_idx):
     """Run one training step: forward pass plus BCE loss on the batch."""
     inputs, targets = batch
     predictions = self(inputs)
     return F.binary_cross_entropy(predictions, targets)
Ejemplo n.º 42
0
def train(train_loader, networks, optimizers, epoch, args, is_main=False):
    """Run one epoch (``args.steps`` iterations) of adversarial training.

    The discriminator is updated ``args.ttur_d`` times per outer step while
    the generator is updated once (on the first inner step) — a two-timescale
    update rule.

    Args:
        train_loader: iterable yielding batches of real samples.
        networks: dict with 'G' (generator) and 'D' (discriminator); on the
            main process also 'G_running' (running-average copy of G).
        optimizers: dict with 'G' and 'D' optimizers.
        epoch: current epoch number (used for logging only).
        args: namespace providing batch_size, latent_size, steps, ttur_d,
            gpu, log_step and epochs.
        is_main: True on the main process; enables the progress bar,
            logging and the G_running accumulation.
    """
    # Running averages of the losses and of D's mean output on real/fake.
    am_loss_g = AverageMeter()
    am_loss_d = AverageMeter()

    am_mean_r = AverageMeter()
    am_mean_f = AverageMeter()

    networks['G'].train()
    networks['D'].train()

    # Fixed real (1) / fake (0) label tensors, reused for every batch.
    ones = torch.ones(args.batch_size, 1)
    zeros = torch.zeros(args.batch_size, 1)

    if args.gpu is not None:
        ones = ones.cuda(args.gpu, non_blocking=True)
        zeros = zeros.cuda(args.gpu, non_blocking=True)

    else:
        ones = ones.cuda()
        zeros = zeros.cuda()

    print("", end="", flush=True)
    train_it = iter(train_loader)
    t_train = tqdm.trange(0, args.steps, disable=not is_main)
    for t in t_train:
        am_loss_g.reset()
        am_loss_d.reset()
        am_mean_r.reset()
        am_mean_f.reset()
        for i in range(args.ttur_d):
            # Restart the loader when exhausted so every step gets a full batch.
            try:
                x_real = next(train_it)
            except StopIteration:
                train_it = iter(train_loader)
                x_real = next(train_it)

            z_input = torch.randn(args.batch_size, args.latent_size)
            if args.gpu is not None:
                x_real = x_real.cuda(args.gpu, non_blocking=True)
                z_input = z_input.cuda(args.gpu, non_blocking=True)
            else:
                x_real = x_real.cuda()
                z_input = z_input.cuda()

            x_fake = networks['G'](z_input)

            if i == 0:
                # G update (once per outer step; D gets ttur_d updates)
                optimizers['G'].zero_grad()

                # G forward: non-saturating loss — push D's output on fakes
                # towards the "real" label.
                logit_d_fake = networks['D'](x_fake)
                loss_g = F.binary_cross_entropy(logit_d_fake, ones)

                # G backward
                loss_g.backward()
                optimizers['G'].step()
                if is_main:
                    accumulate(networks['G_running'], networks['G'].module)

                # AM update
                am_loss_g.update(loss_g.item(), x_real.size(0))

            # D update
            optimizers['D'].zero_grad()
            # Gradients w.r.t. the real input are needed for the 0-GP term.
            x_real.requires_grad_()

            # D real forward
            logit_d_real = networks['D'](x_real)
            loss_d_real = F.binary_cross_entropy(logit_d_real, ones)

            # D real regularization - 0-GP (zero-centered gradient penalty,
            # fixed weight 5.0)
            loss_gp = 5.0 * compute_zero_gp(logit_d_real, x_real).mean()

            # D fake forward (detach: no gradient back into G here)
            logit_d_fake = networks['D'](x_fake.detach())
            loss_d_fake = F.binary_cross_entropy(logit_d_fake, zeros)

            # D step
            loss_d = loss_d_real + loss_d_fake + loss_gp
            loss_d.backward()
            optimizers['D'].step()

            # AM update
            am_loss_d.update(loss_d.item(), x_real.size(0))
            am_mean_r.update(logit_d_real.mean().item(), x_real.size(0))
            am_mean_f.update(logit_d_fake.mean().item(), x_real.size(0))

        if t % args.log_step == 0 and is_main:
            t_train.set_description('Epoch: [{}/{}], '
                                    'Loss: '
                                    'D[{loss_d.avg:.3f}] '
                                    'G[{loss_g.avg:.3f}] '
                                    'Fm[{mean_f.avg:.3f}] '
                                    'Rm[{mean_r.avg:.3f}]'
                                    ''.format(epoch,
                                              args.epochs,
                                              loss_d=am_loss_d,
                                              loss_g=am_loss_g,
                                              mean_f=am_mean_f,
                                              mean_r=am_mean_r))
Ejemplo n.º 43
0
def loss_function(recon_x, x):
    """Summed reconstruction loss for a 328x256 RGB autoencoder.

    Args:
        recon_x: reconstructed batch; flattened to (N, 328*256*3). Values
            must lie in [0, 1] — ``binary_cross_entropy`` expects
            probabilities, not logits.
        x: target batch, flattened the same way.

    Returns:
        Scalar tensor holding the binary cross-entropy summed over all
        elements (no averaging).
    """
    # `size_average=False` was deprecated and later removed from torch;
    # reduction='sum' is the exact modern equivalent.
    BCE = F.binary_cross_entropy(recon_x.view(-1, 328 * 256 * 3),
                                 x.view(-1, 328 * 256 * 3), reduction='sum')
    return BCE
Ejemplo n.º 44
0
 def adversarial_loss(self, y_hat, y):
     """Binary cross-entropy between predicted probabilities and real/fake labels."""
     loss = F.binary_cross_entropy(input=y_hat, target=y)
     return loss
Ejemplo n.º 45
0
    def train_model(self, X_pos, X_neg=None):
        """Train the network in three phases and return the loss history.

        Phase 1 ("burn-in") trains the autoencoder loss alone; phase 2 blends
        in the scoring (binary cross-entropy) loss on every other batch;
        phase 3 trains only the score layer, checkpointing each epoch and
        finally reloading the checkpoint with the lowest phase-3 loss.

        Args:
            X_pos: array of positive samples, shape [num_samples, ...].
            X_neg: array of negative samples, shape
                [num_samples, num_neg_samples, ...].
                NOTE(review): despite the ``=None`` default this argument is
                effectively required — ``X_neg.shape`` is read
                unconditionally below.

        Returns:
            (losses, epoch_losses_phase_3): per-batch loss values over the
            whole run, and per-epoch mean score losses for phase 3.
        """
        self.network_module.ae_module.mode = 'train'
        self.network_module.ae_module.train()
        self.network_module.ae_loss_module.train()
        self.network_module.mode = 'train'
        learning_rate = self.LR

        parameters = list(self.network_module.parameters())
        self.optimizer = torch.optim.Adam(parameters, lr=learning_rate)

        log_interval = self.log_interval
        num_neg_samples = X_neg.shape[1]
        losses = []
        bs = self.batch_size

        burn_in_epochs = self.burn_in_epochs
        self.num_epochs = self.burn_in_epochs + self.phase_2_epochs + self.phase_3_epochs
        t_max = self.num_epochs
        t_start = burn_in_epochs
        last_phase = False
        epoch_phase = 0
        epoch_losses_phase_3 = []
        import pandas as pd

        df_phase_3_losses = pd.DataFrame(columns=['epoch', 'loss'])

        for epoch in tqdm(range(1, self.num_epochs + 1)):
            t = epoch
            # NOTE(review): epoch == burn_in_epochs satisfies neither
            # `epoch < burn_in_epochs` nor `burn_in_epochs < epoch`, so that
            # single boundary epoch falls into phase 3 before later epochs
            # return to phase 2 — confirm this is intended.
            if epoch < burn_in_epochs:
                epoch_phase = 1

            elif burn_in_epochs < epoch <= self.burn_in_epochs + self.phase_2_epochs:
                epoch_phase = 2
            else:
                epoch_phase = 3

            # Loss weights per phase: lambda_1 scales the AE loss, lambda_2
            # the score loss, gamma up-weights the positive-sample BCE term.
            if epoch_phase == 1:
                lambda_1 = 1
                gamma = 1
                lambda_2 = 1
            elif epoch_phase == 2:
                lambda_1 = np.exp(-t_start * (t - t_start) / t_start)
                lambda_2 = 1
            elif epoch_phase == 3:
                lambda_1 = 0.001
                lambda_2 = 1

            # gamma grows towards max_gamma once past burn-in; before that it
            # keeps the value set in the phase-1 branch above.
            if epoch > burn_in_epochs:
                gamma = min(1 + np.exp((t - t_start) / (t_max - t_start) + 1),
                            self.max_gamma)

            # At start of new phase reset optimizer
            # (phase 3 onward only the score layer is optimized)
            if epoch == self.burn_in_epochs + self.phase_2_epochs:
                parameters = list(self.network_module.score_layer.parameters())
                self.optimizer = torch.optim.Adam(parameters, lr=learning_rate)

            epoch_losses = []
            num_batches = X_pos.shape[0] // bs + 1
            # Shuffle positives and negatives with the same permutation so
            # each positive keeps its paired negative samples.
            idx = np.arange(X_pos.shape[0])
            np.random.shuffle(idx)
            X_P = X_pos[idx]
            X_N = X_neg[idx]

            X_P = FT(X_P).to(self.device)
            X_N = FT(X_N).to(self.device)
            b_epoch_losses_phase_3 = []
            for b in range(num_batches):

                # self.network_module.zero_grad()
                self.optimizer.zero_grad()

                _x_p = X_P[b * bs:(b + 1) * bs]
                _x_n = X_N[b * bs:(b + 1) * bs]

                # Positive sample
                batch_loss_pos, sample_score_pos = self.network_module(
                    _x_p, sample_type='pos')
                batch_loss_neg = []
                sample_scores_neg = []

                # Split _x_n into num_neg_samples parts along dim 1
                #  ns  * [ batch, 1, _ ]
                x_neg = torch.chunk(_x_n, num_neg_samples, dim=1)
                # for negative samples at index i

                # NOTE(review): n_sample_loss is discarded and batch_loss_neg
                # stays empty — only the scores of negatives are used.
                for ns in x_neg:
                    ns = ns.squeeze(1)
                    n_sample_loss, n_sample_score = self.network_module(
                        ns, sample_type='neg')

                    sample_scores_neg.append(n_sample_score)

                # Shape : [ batch, num_neg_samples ]
                sample_scores_neg = torch.cat(sample_scores_neg, dim=1)
                sample_scores_neg = sample_scores_neg.squeeze(1)

                # ========================
                # Loss 2 should be the scoring function
                # sample_score is of value between 0 and 1
                # Since we model last layer as logistic reg
                # ========================

                data_size = _x_p.shape[0]
                num_neg_samples = sample_scores_neg.shape[1]

                # Column 0 holds the positive score, the rest the negatives;
                # targets are 1 for positive, 0 for negatives.
                _scores = torch.cat([sample_score_pos, sample_scores_neg],
                                    dim=1)

                targets = torch.cat([
                    torch.ones([data_size, 1]),
                    torch.zeros([data_size, num_neg_samples])
                ],
                                    dim=-1).to(self.device)

                loss_2 = F.binary_cross_entropy(_scores,
                                                targets,
                                                reduction='none')
                loss_2_1 = gamma * loss_2[:, 0]  # positive
                loss_2_0 = loss_2[:, 1:]  # negatives
                loss_2 = loss_2_1 + torch.mean(loss_2_0, dim=1, keepdim=False)
                loss_2 = torch.mean(loss_2, dim=0, keepdims=False)

                # Standard AE loss

                loss_1 = torch.mean(batch_loss_pos, dim=0, keepdim=False)

                # batch_loss_neg = torch.clamp(batch_loss_neg, 0.0001,1)
                # loss_3 = torch.sum(batch_loss_neg, dim=1, keepdim=False)
                # loss_3 = torch.mean(loss_3, dim=0, keepdim=False)
                score_loss = lambda_2 * loss_2

                # Phase 1: AE loss only; phase 2: AE loss, plus score loss on
                # every even batch; phase 3: score loss only.
                if epoch_phase == 1:
                    batch_loss = lambda_1 * loss_1

                elif epoch_phase == 2:
                    batch_loss = lambda_1 * loss_1
                    if b % 2 == 0:
                        batch_loss = batch_loss + score_loss
                elif epoch_phase == 3:
                    batch_loss = score_loss

                # ------------------------------------------
                # Record the estimator loss in last phase
                # -------------------------------------------
                if epoch_phase == 3:
                    b_epoch_losses_phase_3.append(
                        score_loss.clone().cpu().data.numpy())
                # ====================
                # Clip Gradient
                # ====================

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.network_module.parameters(), 2)
                self.optimizer.step()

                loss_value = batch_loss.clone().cpu().data.numpy()
                losses.append(loss_value)
                if b % log_interval == 0:
                    print(
                        ' Epoch {} Batch {} Loss {:.4f} || AE {:.4f} {:.4f} '.
                        format(epoch, b, batch_loss, loss_1, loss_2))

                epoch_losses.append(loss_value)

            mean_epoch_loss = np.mean(epoch_losses)
            print('Epoch loss ::', mean_epoch_loss)

            # ------------------
            # Save checkpoint
            # ------------------
            if epoch_phase == 3:
                epoch_losses_phase_3.append(np.mean(b_epoch_losses_phase_3))
                _path = os.path.join(self.chkpt_folder,
                                     'epoch_{}'.format(epoch))
                torch.save(self.network_module.state_dict(), _path)
                # NOTE(review): DataFrame.append was removed in pandas >= 2.0;
                # this requires pandas < 2 (pd.concat is the replacement).
                df_phase_3_losses = df_phase_3_losses.append(
                    {
                        'epoch': epoch,
                        'loss': np.mean(b_epoch_losses_phase_3)
                    },
                    ignore_index=True)

        # ===========
        # Find epoch with lowest loss
        # ===========
        best_epoch = self.find_lowest_loss_epoch(df_phase_3_losses)
        _path = os.path.join(self.chkpt_folder,
                             'epoch_{}'.format(int(best_epoch)))

        self.network_module.load_state_dict(torch.load(_path))
        self.network_module.mode = 'test'
        return losses, epoch_losses_phase_3
Ejemplo n.º 46
0
# fixed inputs for debugging
# fixed_z: fixed latent codes for sampling; fixed_x: a fixed real batch,
# saved once as a reference image and kept (flattened) for reconstructions.
fixed_z = to_var(torch.randn(100, 20))
fixed_x, _ = next(data_iter)
torchvision.utils.save_image(fixed_x.data.cpu(), './data/real_images.png')
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))

# NOTE(review): this snippet uses pre-0.4 PyTorch idioms (`size_average`,
# `loss.data[0]`); on modern torch use reduction='sum' and `.item()`.
for epoch in range(50):
    for i, (images, _) in enumerate(data_loader):
        
        # Flatten each image to a vector before feeding the VAE.
        images = to_var(images.view(images.size(0), -1))
        out, mu, log_var = vae(images)
        
        # Compute reconstruction loss and kl divergence
        # For kl_divergence, see Appendix B in the paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(out, images, size_average=False)
        kl_divergence = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var -1))
        
        # Backprop + Optimize
        total_loss = reconst_loss + kl_divergence
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print ("Epoch[%d/%d], Step [%d/%d], Total Loss: %.4f, "
                   "Reconst Loss: %.4f, KL Div: %.7f" 
                   %(epoch+1, 50, i+1, iter_per_epoch, total_loss.data[0], 
                     reconst_loss.data[0], kl_divergence.data[0]))
    
    # Save the reconstructed images
Ejemplo n.º 47
0
def loss_vae(input, output, mu, log_var):
    """Standard VAE objective for 784-dim (MNIST) inputs.

    Computes the summed binary cross-entropy between the reconstruction and
    the flattened input, plus the analytic KL divergence of N(mu, exp(log_var))
    from the standard normal prior.

    Returns:
        (total, reconstruction, kl) — three scalar tensors.
    """
    recon_term = F.binary_cross_entropy(output, input.view(-1, 784), reduction='sum')
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_term + kl_term, recon_term, kl_term
            sequence_batch = Variable(sequence_batch, requires_grad=False)
            if use_cuda:
                sequence_batch = sequence_batch.cuda()

            batch_size = sequence_batch.size()[0]
            sequence_len = sequence_batch.size()[1]
            rest = int(np.prod(sequence_batch.size()[2:]))
            flat_sequence_batch = sequence_batch.view(batch_size * sequence_len, rest)
            # break up this potentially large batch into nicer small ones for gsn
            for batch_idx in range(int(flat_sequence_batch.size()[0] / 32)):
                x = flat_sequence_batch[batch_idx * 32:(batch_idx + 1) * 32]
                # train the gsn!
                gsn_optimizer.zero_grad()
                regression_optimizer.zero_grad()
                recons, _, _ = model.gsn(x)
                losses = [F.binary_cross_entropy(input=recon, target=x) for recon in recons]
                loss = sum(losses)
                loss.backward()
                torch.nn.utils.clip_grad_norm(model.parameters(), .25)
                gsn_optimizer.step()
                gsn_train_losses.append(losses[-1].data.cpu().numpy()[0])
                accuracies = [F.mse_loss(input=recon, target=x) for recon in recons]
                gsn_train_accuracies.append(np.mean([acc.data.cpu().numpy() for acc in accuracies]))

        print("GSN Train Loss", np.mean(gsn_train_losses))
        print("GSN Train Accuracy", np.mean(gsn_train_accuracies))
        print("GSN Train time", make_time_units_string(time.time() - gsn_start_time))

        ####
        # train the regression step
        ####
Ejemplo n.º 49
0
    times = []
    epochs = 300
    for epoch in range(epochs):
        print("Epoch", epoch)
        model.train()
        train_losses = []
        epoch_start = time.time()
        for batch_idx, (image_batch, _) in enumerate(train_loader):
            image_batch = Variable(image_batch, requires_grad=False)
            if use_cuda:
                image_batch = image_batch.cuda()
            optimizer.zero_grad()
            flat_image_batch = image_batch.view(-1, int(np.prod(image_batch.size()[1:])))
            recons, _, _ = model(flat_image_batch)
            losses = [F.binary_cross_entropy(input=recon, target=flat_image_batch) for recon in recons]
            loss = sum(losses)
            loss.backward()
            optimizer.step()
            train_losses.append(losses[-1].data.numpy())
        print("Train Loss", np.average(train_losses))
        example, _ = train_loader.dataset[0]
        example = Variable(example, requires_grad=False)
        if use_cuda:
            example = example.cuda()
        flat_example = example.view(1, 784)
        example_recons, _, _ = model(flat_example)
        example_recon = example_recons[-1]
        im = transforms.ToPILImage()(flat_example.view(1,28,28).cpu().data)
        im.save('{!s}_image.png'.format(epoch))
        r_im = transforms.ToPILImage()(example_recon.view(1,28,28).cpu().data)
Ejemplo n.º 50
0
        for batch_idx, sequence_batch in enumerate(train_loader):
            sequence_batch = Variable(sequence_batch, requires_grad=False)
            if use_cuda:
                sequence_batch = sequence_batch.cuda()

            batch_size = sequence_batch.size()[0]
            sequence_len = sequence_batch.size()[1]
            rest = int(np.prod(sequence_batch.size()[2:]))
            flat_sequence_batch = sequence_batch.view(batch_size * sequence_len, rest)
            # break up this potentially large batch into nicer small ones for gsn
            for batch_idx in range(int(flat_sequence_batch.size()[0] / 32)):
                x = flat_sequence_batch[batch_idx*32:(batch_idx+1)*32]
                # train the gsn!
                gsn_optimizer.zero_grad()
                recons, _, _ = model.gsn(x)
                losses = [F.binary_cross_entropy(input=recon, target=x) for recon in recons]
                loss = sum(losses)
                loss.backward()
                gsn_optimizer.step()
                gsn_train_losses.append(losses[-1].data.cpu().numpy())

        print("GSN Train Loss", np.average(gsn_train_losses), "took {!s}".format(make_time_units_string(time.time()-gsn_start_time)))
        ####
        # train the regression step
        ####
        model.train()
        regression_train_losses = []
        for batch_idx, sequence_batch in enumerate(train_loader):
            _start = time.time()
            sequence_batch = Variable(sequence_batch, requires_grad=False)
            if use_cuda:
Ejemplo n.º 51
0
 def forward(ctx, y, y_pred, sum_cr, eta, gbest):
     """Forward pass of a custom autograd Function: BCE between prediction and label.

     Stashes the tensors and PSO hyper-parameters (sum_cr, eta, gbest) on
     ``ctx`` for the custom backward pass defined elsewhere.

     Args:
         ctx: autograd context.
         y: ground-truth labels.
         y_pred: predicted probabilities in [0, 1].
         sum_cr, eta, gbest: extra state saved for backward.

     Returns:
         Scalar mean binary cross-entropy.
     """
     ctx.save_for_backward(y, y_pred)
     ctx.sum_cr = sum_cr
     ctx.eta = eta
     ctx.gbest = gbest
     # BUG FIX: binary_cross_entropy takes (input=prediction, target=label);
     # the original call passed (y, y_pred), feeding the ground truth in as
     # the prediction. Swap so the predicted probabilities are the input.
     return F.binary_cross_entropy(y_pred, y)
Ejemplo n.º 52
0
 def dissonance(self, h2_support_output_sigmoid, target_labels):
     """Mean binary cross-entropy between support-set sigmoid outputs and labels."""
     return F.binary_cross_entropy(input=h2_support_output_sigmoid,
                                   target=target_labels)
        if is_gpu_mode:
            inputs = Variable(torch.from_numpy(input_img).float().cuda())
            noise_z = Variable(noise_z.cuda())
        else:
            inputs = Variable(torch.from_numpy(input_img).float())
            noise_z = Variable(noise_z)

        # feedforward the inputs. generator
        outputs_gen = gen_model(noise_z)

        # feedforward the inputs. discriminator
        output_disc_real = disc_model(inputs)
        output_disc_fake = disc_model(outputs_gen)

        # loss functions
        loss_real_d = F.binary_cross_entropy(output_disc_real, ones_label)
        loss_fake_d = F.binary_cross_entropy(output_disc_fake, zeros_label)
        loss_disc_total = loss_real_d + loss_fake_d

        loss_gen = F.binary_cross_entropy(output_disc_fake, ones_label)
        
        #loss_disc_total = -torch.mean(torch.log(output_disc_real) + torch.log(1. - output_disc_fake))
        #loss_gen = -torch.mean(torch.log(output_disc_fake))

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer_disc.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model parameters
        loss_disc_total.backward(retain_graph = True)
Ejemplo n.º 54
0
def variational_loss(output, data, mean, logvar, beta):
    """Beta-VAE objective: summed BCE reconstruction plus weighted KL term.

    Args:
        output: reconstructed probabilities in [0, 1].
        data: target tensor, same shape as output.
        mean, logvar: posterior mean and log-variance.
        beta: weight on the KL divergence.

    Returns:
        (reconstruction + beta * divergence, reconstruction) as scalar tensors.
    """
    bce_sum = F.binary_cross_entropy(output, data, reduction='sum')
    kl_div = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - logvar - 1)
    return bce_sum + beta * kl_div, bce_sum
Ejemplo n.º 55
0
# All learnable parameters of the hand-rolled VAE: encoder (x->h->z mu/var)
# and decoder (z->h->x) weights and biases.
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)

for it in range(100000):
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Forward
    # Q: encoder -> posterior params; sample_z: reparameterized sample;
    # P: decoder -> reconstruction probabilities.
    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss
    # NOTE(review): `nn` here is presumably `torch.nn.functional` imported
    # under the alias `nn` (wiseodd/generative-models convention) — confirm;
    # `torch.nn` itself has no `binary_cross_entropy` function. `z_var` is
    # treated as a log-variance in the KL term below.
    recon_loss = nn.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward
    loss.backward()

    # Update
    solver.step()

    # Housekeeping
    # Manual gradient reset in lieu of solver.zero_grad(): rebind each
    # parameter's grad to a fresh zero tensor of the same shape.
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())
Ejemplo n.º 56
0
    def compute_loss(predictions, labels, loss_wts=None):
        """Compute the Net's losses, one term per enabled head.

        Args:
            predictions: dictionary of head outputs from the Net ('malware',
                'count', 'similarity' as applicable).
            labels: dictionary of ground-truth labels; a head's loss is
                computed only if its key is present here.
            loss_wts: weights to assign to each head of the network; defaults
                to {'malware': 1.0, 'count': 0.1, 'tags': 1.0}. A missing key
                falls back to weight 1.0.

        Returns:
            Dict with each head's detached loss value and a differentiable
            weighted sum under 'total'.
        """
        # Fall back to the default per-head weights when none were given.
        if loss_wts is None:
            loss_wts = {'malware': 1.0, 'count': 0.1, 'tags': 1.0}

        loss_dict = {'total': 0.}

        # Malware head: BCE between predicted probability and binary label.
        if 'malware' in labels:
            malware_labels = labels['malware'].float().to(device)
            # Reshape the prediction to the label's shape before the BCE.
            malware_loss = F.binary_cross_entropy(
                predictions['malware'].reshape(malware_labels.shape),
                malware_labels)
            # Store a detached copy; keep the tensor in the weighted total.
            loss_dict['malware'] = deepcopy(malware_loss.item())
            loss_dict['total'] += malware_loss * loss_wts.get('malware', 1.0)

        # Count head: Poisson negative log-likelihood on detection counts.
        if 'count' in labels:
            count_labels = labels['count'].float().to(device)
            count_loss = torch.nn.PoissonNLLLoss()(
                predictions['count'].reshape(count_labels.shape), count_labels)
            loss_dict['count'] = deepcopy(count_loss.item())
            loss_dict['total'] += count_loss * loss_wts.get('count', 1.0)

        # Tags (Joint Embedding) head: per-tag BCE summed over tags,
        # averaged over the batch.
        if 'tags' in labels:
            tag_labels = labels['tags'].float().to(device)
            similarity_loss = F.binary_cross_entropy(
                predictions['similarity'], tag_labels,
                reduction='none').sum(dim=1).mean(dim=0)
            loss_dict['jointEmbedding'] = deepcopy(similarity_loss.item())
            loss_dict['total'] += similarity_loss * loss_wts.get('tags', 1.0)

        return loss_dict
Ejemplo n.º 57
0
def train_casenet(epoch,model,data_loader,optimizer,args):
    """Train the case-level classifier for one epoch and print summary stats.

    NOTE(review): written against pre-0.4 PyTorch (`Variable`, `loss.data[0]`);
    `binary_cross_entropy` is a bare name presumably imported from
    torch.nn.functional at file level — confirm.

    Args:
        epoch: current epoch index (drives the learning-rate schedule).
        model: network returning (nodulePred, casePred, casePred_each).
        data_loader: yields (x, coord, isnod, y) batches.
        optimizer: optimizer whose lr is overwritten from get_lr(epoch, args).
        args: namespace with freeze_batchnorm, debug, miss_thresh, miss_ratio.
    """
    model.train()
    # Optionally keep BatchNorm statistics frozen while training the rest.
    if args.freeze_batchnorm:    
        for m in model.modules():
            if isinstance(m, nn.BatchNorm3d):
                m.eval()

    starttime = time.time()
    # Manual learning-rate schedule: overwrite every param group's lr.
    lr = get_lr(epoch,args)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Per-batch histories; lenHist holds batch sizes for weighted averaging.
    # NOTE(review): lossHist is never appended to and its mean is never
    # printed — it appears vestigial.
    loss1Hist = []
    loss2Hist = []
    missHist = []
    lossHist = []
    accHist = []
    lenHist = []
    # Running true-positive / false-positive / false-negative counts.
    tpn = 0
    fpn = 0
    fnn = 0
#     weight = torch.from_numpy(np.ones_like(y).float().cuda()
    for i,(x,coord,isnod,y) in enumerate(data_loader):
        # Debug mode: only run a handful of batches.
        if args.debug:
            if i >4:
                break
        coord = Variable(coord).cuda()
        x = Variable(x).cuda()
        xsize = x.size()
        isnod = Variable(isnod).float().cuda()
        ydata = y.numpy()[:,0]
        y = Variable(y).float().cuda()
#         weight = 3*torch.ones(y.size()).float().cuda()
        optimizer.zero_grad()
        nodulePred,casePred,casePred_each = model(x,coord)
        # Case-level BCE plus a "miss" penalty on nodules the per-nodule
        # head scores below miss_thresh (epsilon 0.001 guards log(0)).
        loss2 = binary_cross_entropy(casePred,y[:,0])
        missMask = (casePred_each<args.miss_thresh).float()
        missLoss = -torch.sum(missMask*isnod*torch.log(casePred_each+0.001))/xsize[0]/xsize[1]
        loss = loss2+args.miss_ratio*missLoss
        loss.backward()
        #torch.nn.utils.clip_grad_norm(model.parameters(), 1)

        optimizer.step()
        loss2Hist.append(loss2.data[0])
        missHist.append(missLoss.data[0])
        lenHist.append(len(x))
        outdata = casePred.data.cpu().numpy()

        # Threshold at 0.5 and accumulate the confusion-matrix counts.
        pred = outdata>0.5
        tpn += np.sum(1==pred[ydata==1])
        fpn += np.sum(1==pred[ydata==0])
        fnn += np.sum(0==pred[ydata==1])
        acc = np.mean(ydata==pred)
        accHist.append(acc)
        
    endtime = time.time()
    lenHist = np.array(lenHist)
    loss2Hist = np.array(loss2Hist)
    lossHist = np.array(lossHist)
    accHist = np.array(accHist)
    
    # Batch-size-weighted means (missHist, a list, broadcasts against the
    # lenHist ndarray elementwise).
    mean_loss2 = np.sum(loss2Hist*lenHist)/np.sum(lenHist)
    mean_missloss = np.sum(missHist*lenHist)/np.sum(lenHist)
    mean_acc = np.sum(accHist*lenHist)/np.sum(lenHist)
    print('Train, epoch %d, loss2 %.4f, miss loss %.4f, acc %.4f, tpn %d, fpn %d, fnn %d, time %3.2f, lr % .5f '
          %(epoch,mean_loss2,mean_missloss,mean_acc,tpn,fpn, fnn, endtime-starttime,lr))
Ejemplo n.º 58
0
    def forward(self, predictions, targets, masks, num_crowds):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            mask preds, and prior boxes from SSD net.
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                masks shape: torch.size(batch_size,num_priors,mask_dim)
                priors shape: torch.size(num_priors,4)
                proto* shape: torch.size(batch_size,mask_h,mask_w,mask_dim)

            targets (list<tensor>): Ground truth boxes and labels for a batch,
                shape: [batch_size][num_objs,5] (last idx is the label).

            masks (list<tensor>): Ground truth masks for each object in each image,
                shape: [batch_size][num_objs,im_height,im_width]

            num_crowds (list<int>): Number of crowd annotations per batch. The crowd
                annotations should be the last num_crowds elements of targets and masks.

            * Only if mask_type == lincomb
        """

        loc_data = predictions['loc']
        conf_data = predictions['conf']
        mask_data = predictions['mask']
        priors = predictions['priors']

        if cfg.mask_type == mask_type.lincomb:
            proto_data = predictions['proto']

        score_data = predictions['score'] if cfg.use_mask_scoring else None
        inst_data = predictions['inst'] if cfg.use_instance_coeff else None

        labels = [None] * len(targets)  # Used in sem segm loss

        batch_size = loc_data.size(0)
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # Match priors (default boxes) and ground truth boxes
        # These tensors will be created with the same device as loc_data
        loc_t = loc_data.new(batch_size, num_priors, 4)
        gt_box_t = loc_data.new(batch_size, num_priors, 4)
        conf_t = loc_data.new(batch_size, num_priors).long()
        idx_t = loc_data.new(batch_size, num_priors).long()

        if cfg.use_class_existence_loss:
            class_existence_t = loc_data.new(batch_size, num_classes - 1)

        for idx in range(batch_size):
            truths = targets[idx][:, :-1].data
            labels[idx] = targets[idx][:, -1].data.long()
            if cfg.use_class_existence_loss:
                # Construct a one-hot vector for each object and collapse it into an existence vector with max
                # Also it's fine to include the crowd annotations here
                class_existence_t[idx, :] = torch.eye(
                    num_classes - 1,
                    device=conf_t.get_device())[labels[idx]].max(dim=0)[0]

            # Split the crowd annotations because they come bundled in
            cur_crowds = num_crowds[idx]
            if cur_crowds > 0:
                split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
                crowd_boxes, truths = split(truths)

                # We don't use the crowd labels or masks
                _, labels[idx] = split(labels[idx])
                _, masks[idx] = split(masks[idx])
            else:
                crowd_boxes = None

            match(self.pos_threshold, self.neg_threshold, truths, priors.data,
                  labels[idx], crowd_boxes, loc_t, conf_t, idx_t, idx,
                  loc_data[idx])

            gt_box_t[idx, :, :] = truths[idx_t[idx]]

        # wrap targets
        loc_t = Variable(loc_t, requires_grad=False)
        conf_t = Variable(conf_t, requires_grad=False)
        idx_t = Variable(idx_t, requires_grad=False)

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdim=True)

        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

        losses = {}

        # Localization Loss (Smooth L1)
        if cfg.train_boxes:
            loc_p = loc_data[pos_idx].view(-1, 4)
            loc_t = loc_t[pos_idx].view(-1, 4)
            losses['B'] = F.smooth_l1_loss(loc_p, loc_t,
                                           reduction='sum') * cfg.bbox_alpha

        if cfg.train_masks:
            if cfg.mask_type == mask_type.direct:
                if cfg.use_gt_bboxes:
                    pos_masks = []
                    for idx in range(batch_size):
                        pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                    masks_t = torch.cat(pos_masks, 0)
                    masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                    losses['M'] = F.binary_cross_entropy(
                        torch.clamp(masks_p, 0, 1), masks_t,
                        reduction='sum') * cfg.mask_alpha
                else:
                    losses['M'] = self.direct_mask_loss(
                        pos_idx, idx_t, loc_data, mask_data, priors, masks)
            elif cfg.mask_type == mask_type.lincomb:
                losses.update(
                    self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data,
                                           priors, proto_data, masks, gt_box_t,
                                           score_data, inst_data))

                if cfg.mask_proto_loss is not None:
                    if cfg.mask_proto_loss == 'l1':
                        losses['P'] = torch.mean(
                            torch.abs(proto_data)
                        ) / self.l1_expected_area * self.l1_alpha
                    elif cfg.mask_proto_loss == 'disj':
                        losses['P'] = -torch.mean(
                            torch.max(F.log_softmax(proto_data, dim=-1),
                                      dim=-1)[0])

        # Confidence loss
        if cfg.use_focal_loss:
            if cfg.use_sigmoid_focal_loss:
                losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
            elif cfg.use_objectness_score:
                losses['C'] = self.focal_conf_objectness_loss(
                    conf_data, conf_t)
            else:
                losses['C'] = self.focal_conf_loss(conf_data, conf_t)
        else:
            if cfg.use_objectness_score:
                losses['C'] = self.conf_objectness_loss(
                    conf_data, conf_t, batch_size, loc_p, loc_t, priors)
            else:
                losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos,
                                                  batch_size)

        # These losses also don't depend on anchors
        if cfg.use_class_existence_loss:
            losses['E'] = self.class_existence_loss(predictions['classes'],
                                                    class_existence_t)
        if cfg.use_semantic_segmentation_loss:
            losses['S'] = self.semantic_segmentation_loss(
                predictions['segm'], masks, labels)

        # Divide all losses by the number of positives.
        # Don't do it for loss[P] because that doesn't depend on the anchors.
        total_num_pos = num_pos.data.sum().float()
        for k in losses:
            if k not in ('P', 'E', 'S'):
                losses[k] /= total_num_pos
            else:
                losses[k] /= batch_size

        # Loss Key:
        #  - B: Box Localization Loss
        #  - C: Class Confidence Loss
        #  - M: Mask Loss
        #  - P: Prototype Loss
        #  - D: Coefficient Diversity Loss
        #  - E: Class Existence Loss
        #  - S: Semantic Segmentation Loss
        return losses
        # NOTE(review): everything from here to the end of this method follows
        # the unconditional `return losses` above and is UNREACHABLE dead code.
        # It appears to be an unrelated RNN training loop pasted in by mistake:
        # it references names defined nowhere in this file's visible scope
        # (`train_loader`, `model`, `optimizer`, `use_cuda`, `np`, `rest`,
        # `train_losses`, `train_accuracies`). Recommend deleting this block;
        # kept byte-identical here pending confirmation.
        for batch_idx, sequence_batch in enumerate(train_loader):
            sequence_batch = Variable(sequence_batch, requires_grad=False)
            if use_cuda:
                sequence_batch = sequence_batch.cuda()
            sequence = sequence_batch.squeeze(dim=0)
            # NOTE(review): `split_size=` is the legacy kwarg name; modern
            # torch uses `split_size_or_sections` — would fail if ever run.
            subsequences = torch.split(sequence, split_size=100)
            for seq in subsequences:
                batch_size = 1
                seq_len = seq.size()[0]
                seq = seq.view(seq_len, -1).contiguous()
                seq = seq.unsqueeze(dim=1)
                # Next-step prediction targets: input sequence shifted by one.
                targets = seq[1:]

                optimizer.zero_grad()
                predictions = model(seq)
                losses = [F.binary_cross_entropy(input=pred, target=targets[step]) for step, pred in enumerate(predictions[:-1])]
                loss = sum(losses)
                loss.backward()
                # NOTE(review): `clip_grad_norm` is the deprecated (removed)
                # spelling of `clip_grad_norm_`.
                torch.nn.utils.clip_grad_norm(model.parameters(), .25)
                optimizer.step()
                train_losses.append(np.mean([l.data.cpu().numpy() for l in losses]))

                # MSE used as a proxy "accuracy" metric here (lower is better).
                accuracies = [F.mse_loss(input=pred, target=targets[step]) for step, pred in enumerate(predictions[:-1])]
                train_accuracies.append(np.mean([acc.data.cpu().numpy() for acc in accuracies]))

                # Per-sequence squared-error accumulation; `acc` is discarded
                # (loop body ends without using it) — presumably leftover code.
                acc = []
                p = torch.cat(predictions[:-1]).view(batch_size, seq_len - 1, rest).contiguous()
                t = targets.view(batch_size, seq_len - 1, rest).contiguous()
                for i, px in enumerate(p):
                    tx = t[i]
                    acc.append(torch.sum((tx - px) ** 2) / len(px))
Ejemplo n.º 60
0
    def lincomb_mask_loss(self,
                          pos,
                          idx_t,
                          loc_data,
                          mask_data,
                          priors,
                          proto_data,
                          masks,
                          gt_box_t,
                          score_data,
                          inst_data,
                          interpolation_mode='bilinear'):
        """Compute the linear-combination ("lincomb") mask loss.

        Each positive anchor's mask is predicted as a linear combination of
        prototype masks (``proto_masks @ proto_coef.t()``), passed through
        ``cfg.mask_proto_mask_activation`` and optionally cropped to a box,
        then compared against the ground-truth masks downsampled to the
        prototype resolution. Many alternative weightings/normalizations are
        gated on ``cfg.*`` flags.

        Args:
            pos: boolean tensor of positive anchors, indexed as
                ``pos[batch_idx, prior_idx]``. Cloned and edited in place when
                ``cfg.mask_proto_remove_empty_masks`` is set.
            idx_t: matched ground-truth object index per anchor, indexed as
                ``idx_t[batch_idx, prior_idx]``.
            loc_data: predicted box regressions; decoded via ``decode`` when
                ``cfg.mask_proto_crop_with_pred_box`` is set.
            mask_data: per-anchor mask coefficients; ``mask_data[idx, cur_pos, :]``
                selects coefficients for the positives of one image.
            priors: prior boxes, passed to ``decode``.
            proto_data: prototype masks with spatial dims at positions 1 and 2
                (``mask_h``/``mask_w`` read from them below).
            masks: per-image ground-truth masks at image resolution.
            gt_box_t: ground-truth box matched to each prior, in point form.
            score_data: mask-scoring predictions; indexed only when
                ``cfg.use_mask_scoring`` is set (selected but not otherwise
                used inside this function).
            inst_data: optional alternative coefficients for the diversity
                loss; falls back to ``proto_coef`` when ``None``.
            interpolation_mode: mode for ``F.interpolate`` when downsampling
                the gt masks.

        Returns:
            dict: ``{'M': mask loss}`` plus ``{'D': coefficient diversity
            loss}`` when ``cfg.mask_proto_coeff_diversity_loss`` is enabled.
        """
        # Prototype spatial resolution; all gt masks are downsampled to this.
        mask_h = proto_data.size(1)
        mask_w = proto_data.size(2)

        # Both ROI-pooling emulation and cropping need a box per positive
        # anchor, so compute boxes if either flag is on.
        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0  # Coefficient diversity loss

        # One iteration per image in the batch.
        for idx in range(mask_data.size(0)):
            # Ground-truth preprocessing carries no gradient.
            with torch.no_grad():
                downsampled_masks = F.interpolate(
                    masks[idx].unsqueeze(0), (mask_h, mask_w),
                    mode=interpolation_mode,
                    align_corners=False).squeeze(0)
                # Move the object dim last so objects can be gathered along
                # dim 2 (see `mask_t` below).
                downsampled_masks = downsampled_masks.permute(1, 2,
                                                              0).contiguous()

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <=
                                        0.0001)
                    for i in range(very_small_masks.size(0)):
                        if very_small_masks[i]:
                            # Demote every anchor matched to this vanished gt.
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = downsampled_masks.gt(0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    # Area-normalized foreground/background indicator maps:
                    # each sums to ~1 over the spatial dims per object.
                    gt_foreground_norm = bin_gt / (
                        torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                    gt_background_norm = (1 - bin_gt) / (torch.sum(
                        1 - bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                    mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    # Scale back up by the pixel count so the weights are not
                    # tiny (the normalized maps above sum to ~1 each).
                    mask_reweighting *= mask_h * mask_w

            cur_pos = pos[idx]
            # gt object index for each positive anchor in this image.
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data,
                                          cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            # No positives left in this image: nothing to contribute.
            if pos_idx_t.size(0) == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef = mask_data[idx, cur_pos, :]
            if cfg.use_mask_scoring:
                mask_scores = score_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                if cfg.use_mask_scoring:
                    mask_scores = mask_scores[select, :]

            num_pos = proto_coef.size(0)
            # Gather the gt mask for each (possibly subsampled) positive.
            mask_t = downsampled_masks[:, :, pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                # Extra, un-cropped sum-reduced loss term added on top of the
                # main (possibly cropped) loss below.
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = F.binary_cross_entropy(torch.clamp(
                        pred_masks, 0, 1),
                                                      mask_t,
                                                      reduction='sum')
                else:
                    pre_loss = F.smooth_l1_loss(pred_masks,
                                                mask_t,
                                                reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                # Zero out predictions outside each positive's box.
                pred_masks = crop(pred_masks, pos_gt_box_t)

            # Main per-pixel loss; left unreduced so it can be reweighted /
            # normalized before summation.
            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(
                    pred_masks, 0, 1),
                                                  mask_t,
                                                  reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks,
                                            mask_t,
                                            reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
                pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                # Normalize by the box area in prototype pixels, emulating a
                # loss computed over an ROI-pooled crop.
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_get_csize = center_size(pos_gt_box_t)
                gt_box_width = pos_get_csize[:, 2] * mask_w
                gt_box_height = pos_get_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(
                    dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

        # Normalize by prototype area; per-positive normalization happens in
        # the caller (division by total_num_pos).
        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        return losses