Ejemplo n.º 1
0
    def forward(self, z_seq, a_seq, term_seq):
        """Roll the recurrent model over a sequence and score its one-step
        predictions.

        At each step t the (action, latent) pair is encoded, the hidden
        state is advanced, and the model predicts the next latent and the
        next terminal flag.  MSE scores the latent prediction and
        BCE-with-logits scores the terminal prediction; both are averaged
        over all steps.

        Returns:
            (total_loss, z_loss, term_loss)
        """
        hidden = torch.zeros(1, self.h_size).cuda()
        step_z_losses, step_term_losses = [], []

        n_steps = len(term_seq) - 1
        for step in range(n_steps):
            encoded = self.encode_az(a_seq[step], z_seq[step])
            hidden = self.update_h(hidden, encoded)
            z_pred, term_pred = self.predict_output(hidden, encoded)

            step_z_losses.append(torch.mean((z_seq[step + 1] - z_pred) ** 2))
            step_term_losses.append(
                F.binary_cross_entropy_with_logits(
                    input=term_pred, target=term_seq[step + 1]))

        z_loss = torch.stack(step_z_losses).mean()
        term_loss = torch.stack(step_term_losses).mean()
        return z_loss + term_loss, z_loss, term_loss
Ejemplo n.º 2
0
def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
    """Compute the MDRNN training losses for one batch of sequences.

    The returned scalar is
        (GMM(latent_next_obs) + BCE(terminal) + MSE(reward)) / (LSIZE + 2),
    where the LSIZE + 2 divisor compensates for the roughly linear growth of
    the GMM loss with the latent size.  Every loss is averaged over both the
    batch and the sequence dimensions.

    :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
    :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
    :args reward: (BSIZE, SEQ_LEN) torch tensor
    :args terminal: (BSIZE, SEQ_LEN) torch tensor
    :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

    :returns: dict with keys 'gmm', 'bce', 'mse' and 'loss' (the average).
    """
    # The MDRNN consumes sequence-first tensors, so swap batch and time axes.
    latent_obs = latent_obs.transpose(1, 0)
    action = action.transpose(1, 0)
    reward = reward.transpose(1, 0)
    terminal = terminal.transpose(1, 0)
    latent_next_obs = latent_next_obs.transpose(1, 0)

    mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
    gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
    bce = f.binary_cross_entropy_with_logits(ds, terminal)
    mse = f.mse_loss(rs, reward)
    scale = LSIZE + 2
    return dict(gmm=gmm, bce=bce, mse=mse, loss=(gmm + bce + mse) / scale)
Ejemplo n.º 3
0
def deep_supervised_criterion(logit, logit_pixel, logit_image, truth_pixel, truth_image, is_average=True):
    """Deep-supervision loss: 0.05 * image-level BCE-with-logits
    + 0.1 * sum of per-scale symmetric Lovasz pixel losses
    + 1.0 * symmetric Lovasz loss on the final mask logits."""
    image_term = F.binary_cross_entropy_with_logits(logit_image, truth_image, reduce=is_average)
    pixel_term = sum(
        symmetric_lovasz_ignore_empty(scale_logit.squeeze(1), truth_pixel, truth_image)
        for scale_logit in logit_pixel)
    final_term = symmetric_lovasz(logit.squeeze(1), truth_pixel)
    return 0.05 * image_term + 0.1 * pixel_term + 1 * final_term
Ejemplo n.º 4
0
def single_scale_rpn_losses(
        rpn_cls_logits, rpn_bbox_pred,
        rpn_labels_int32_wide, rpn_bbox_targets_wide,
        rpn_bbox_inside_weights_wide, rpn_bbox_outside_weights_wide):
    """Add losses for a single scale RPN model (i.e., no FPN).

    Returns (loss_rpn_cls, loss_rpn_bbox).  The "wide" label/target arrays
    are allocated at a maximum spatial size; they are cropped here to the
    spatial size of the corresponding prediction maps.
    """
    h, w = rpn_cls_logits.shape[2:]
    rpn_labels_int32 = rpn_labels_int32_wide[:, :, :h, :w]   # -1 means ignore
    h, w = rpn_bbox_pred.shape[2:]
    rpn_bbox_targets = rpn_bbox_targets_wide[:, :, :h, :w]
    rpn_bbox_inside_weights = rpn_bbox_inside_weights_wide[:, :, :h, :w]
    rpn_bbox_outside_weights = rpn_bbox_outside_weights_wide[:, :, :h, :w]

    if cfg.RPN.CLS_ACTIVATION == 'softmax':
        # Reshape [B, 2A, H, W] logits into [B*A*H*W, 2] so each anchor is a
        # two-way (fg/bg) softmax classification.
        B, C, H, W = rpn_cls_logits.size()
        rpn_cls_logits = rpn_cls_logits.view(
            B, 2, C // 2, H, W).permute(0, 2, 3, 4, 1).contiguous().view(-1, 2)
        rpn_labels_int32 = rpn_labels_int32.contiguous().view(-1).long()
        # the loss is averaged over non-ignored targets
        loss_rpn_cls = F.cross_entropy(
            rpn_cls_logits, rpn_labels_int32, ignore_index=-1)
    else:
        # Sigmoid head: weight out the ignored (-1) anchors, then normalize
        # the summed BCE by the number of non-ignored anchors by hand.
        weight = (rpn_labels_int32 >= 0).float()
        loss_rpn_cls = F.binary_cross_entropy_with_logits(
            rpn_cls_logits, rpn_labels_int32.float(), weight, size_average=False)
        loss_rpn_cls /= weight.sum()

    # Smooth-L1 box regression; inside/outside weights select and normalize
    # the contributing anchors.
    loss_rpn_bbox = net_utils.smooth_l1_loss(
        rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights,
        beta=1/9)

    return loss_rpn_cls, loss_rpn_bbox
Ejemplo n.º 5
0
    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Compute the RPN objectness and box-regression losses.

        Arguments:
            anchors (list[BoxList])
            objectness (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)
        # Subsample positive/negative anchors; nonzero() converts the boolean
        # masks into flat index tensors over the concatenated anchors.
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness_flattened = []
        box_regression_flattened = []
        # for each feature level, permute the outputs to make them be in the
        # same format as the labels. Note that the labels are computed for
        # all feature levels concatenated, so we keep the same representation
        # for the objectness and the box_regression
        for objectness_per_level, box_regression_per_level in zip(
            objectness, box_regression
        ):
            N, A, H, W = objectness_per_level.shape
            objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape(
                N, -1
            )
            box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
            box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
            box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
            objectness_flattened.append(objectness_per_level)
            box_regression_flattened.append(box_regression_per_level)
        # concatenate on the first dimension (representing the feature levels), to
        # take into account the way the labels were generated (with all feature maps
        # being concatenated as well)
        objectness = cat(objectness_flattened, dim=1).reshape(-1)
        box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # Box loss only on positives, but normalized by the total number of
        # sampled anchors (positives + negatives).
        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        # Objectness BCE over the full sampled set.
        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss
Ejemplo n.º 6
0
    def forward(self, frame):
        """Reconstruct ``frame`` and return the BCE-with-logits
        reconstruction loss against the input itself (autoencoder
        objective).

        Side effect: caches the batch size on ``self.B``.
        """
        self.B = frame.size(0)
        reconstruction = self.reconstruct(frame)
        return F.binary_cross_entropy_with_logits(input=reconstruction, target=frame)
Ejemplo n.º 7
0
def mask_rcnn_losses(masks_pred, masks_int32):
    """Mask R-CNN specific losses.

    ``masks_int32`` holds {1, 0, -1} per pixel, where -1 marks pixels to
    ignore.  The summed BCE-with-logits is normalized by the count of
    non-ignored pixels, then scaled by the configured mask-loss weight.
    """
    num_rois = masks_pred.size(0)
    gpu_id = masks_pred.get_device()
    targets = Variable(torch.from_numpy(masks_int32.astype('float32'))).cuda(gpu_id)
    valid = (targets > -1).float()  # 0 wherever the label is the ignore value -1
    flat_pred = masks_pred.view(num_rois, -1)
    total = F.binary_cross_entropy_with_logits(
        flat_pred, targets, valid, size_average=False)
    return total / valid.sum() * cfg.MRCNN.WEIGHT_LOSS_MASK
Ejemplo n.º 8
0
    def forward(self, inputs, targets):
        """Focal loss on top of binary cross-entropy.

        ``self.logits`` selects whether ``inputs`` are raw scores (BCE with
        logits) or probabilities (plain BCE).  ``self.reduce`` selects mean
        reduction; otherwise the per-element focal terms are summed.
        """
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            bce = F.binary_cross_entropy(inputs, targets, reduce=False)

        # exp(-BCE) recovers the probability assigned to the true label.
        p_true = torch.exp(-bce)
        focal = self.alpha * (1 - p_true) ** self.gamma * bce

        return torch.mean(focal) if self.reduce else torch.sum(focal)
Ejemplo n.º 9
0
 def _information_loss(model: InfoGAN, fake_hidden, latent):
     """InfoGAN mutual-information penalty from the recognition head.

     Sums a cross-entropy term for the categorical code, a scaled Gaussian
     term for the continuous code and a scaled BCE term for the binary
     code; codes the model does not use (dim == 0) are skipped.
     """
     cat_logit, cont_mean, cont_logvar, bin_logit = model.rec(fake_hidden)
     info_loss = 0.
     if model.cat_dim > 0:
         # latent stores the categorical code one-hot; argmax recovers the index.
         cat_code = latent[:, model.cat_idx]
         info_loss += F.cross_entropy(cat_logit, cat_code.argmax(1))
     if model.cont_dim > 0:
         cont_code = latent[:, model.cont_idx]
         # NOTE(review): the 0.1 and 2 factors look like hand-tuned term
         # weights -- confirm against the training config.
         info_loss += .1 * _gaussian_loss(cont_code, cont_mean, cont_logvar)
     if model.bin_dim > 0:
         bin_code = latent[:, model.bin_idx]
         info_loss += 2 * F.binary_cross_entropy_with_logits(bin_logit, bin_code)
     return info_loss
Ejemplo n.º 10
0
def mrcnn_mask_loss(target_masks, target_class_ids, pred_masks_logits):
    """Mask binary cross-entropy loss for the masks head.

    target_masks: [batch, num_rois, height, width].
        A float32 tensor of values 0 or 1. Uses zero padding to fill array.
    target_class_ids: [batch, num_rois]. Integer class IDs. Zero padded.
        Currently unused; kept for interface compatibility.
    pred_masks_logits: raw mask logits, element-wise broadcastable against
        `target_masks` (NOTE(review): the original docstring claimed a
        trailing num_classes axis, which the element-wise BCE below does
        not select from -- confirm the caller's shape).

    Returns the mean BCE-with-logits loss over all mask elements.
    """
    # The original version unpacked pred/target shapes and flattened
    # target_class_ids without ever using the results; those dead
    # statements have been removed.
    return F.binary_cross_entropy_with_logits(pred_masks_logits, target_masks)
Ejemplo n.º 11
0
    def forward(self, x):
        """Autoencode ``x`` and return ``(loss, recon_loss)``.

        The two returned values are currently the same tensor; the tuple
        shape mirrors a fuller variant that also reported transition and
        terminal losses.
        """
        latent = self.encode(x)
        reconstruction = self.decode(latent)

        recon_loss = F.binary_cross_entropy_with_logits(
            input=reconstruction, target=x)

        return recon_loss, recon_loss
Ejemplo n.º 12
0
def BCELogit_Loss(score_map, labels):
    """ The Binary Cross-Entropy with Logits Loss.
    Args:
        score_map (torch.Tensor): The score map tensor of shape [B,1,H,W]
        labels (torch.Tensor): The label tensor of shape [B,H,W,2] where the
            fourth dimension is separated in two maps, the first indicates whether
            the pixel is negative (0) or positive (1) and the second one whether
            the pixel is positive/negative (1) or neutral (0) in which case it
            will simply be ignored.
    Return:
        loss (scalar torch.Tensor): The BCE Loss with Logits for the score map and labels.
    """
    # Insert a channel axis so the label maps broadcast against the
    # [B,1,H,W] score map.
    labels = labels.unsqueeze(1)
    # 'elementwise_mean' was deprecated and later removed from PyTorch;
    # 'mean' is the exact equivalent reduction.
    loss = F.binary_cross_entropy_with_logits(score_map, labels[:, :, :, :, 0],
                                              weight=labels[:, :, :, :, 1],
                                              reduction='mean')
    return loss
Ejemplo n.º 13
0
def fpn_rpn_losses(**kwargs):
    """Add RPN on FPN specific losses.

    Expects per-level keyword tensors named ``rpn_cls_logits_fpn<lvl>``,
    ``rpn_bbox_pred_fpn<lvl>`` and the matching ``*_wide_fpn<lvl>`` label,
    target and weight arrays.  Returns two lists (classification losses,
    bbox losses) with one entry per FPN level.
    """
    losses_cls = []
    losses_bbox = []
    for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1):
        slvl = str(lvl)
        # Spatially narrow the full-sized RPN label arrays to match the feature map shape
        b, c, h, w = kwargs['rpn_cls_logits_fpn' + slvl].shape
        rpn_labels_int32_fpn = kwargs['rpn_labels_int32_wide_fpn' + slvl][:, :, :h, :w]
        h, w = kwargs['rpn_bbox_pred_fpn' + slvl].shape[2:]
        rpn_bbox_targets_fpn = kwargs['rpn_bbox_targets_wide_fpn' + slvl][:, :, :h, :w]
        rpn_bbox_inside_weights_fpn = kwargs[
            'rpn_bbox_inside_weights_wide_fpn' + slvl][:, :, :h, :w]
        rpn_bbox_outside_weights_fpn = kwargs[
            'rpn_bbox_outside_weights_wide_fpn' + slvl][:, :, :h, :w]

        if cfg.RPN.CLS_ACTIVATION == 'softmax':
            # Reshape [b, 2a, h, w] logits to [b*a*h*w, 2]: one fg/bg softmax
            # per anchor.  Label -1 means ignore.
            rpn_cls_logits_fpn = kwargs['rpn_cls_logits_fpn' + slvl].view(
                b, 2, c // 2, h, w).permute(0, 2, 3, 4, 1).contiguous().view(-1, 2)
            rpn_labels_int32_fpn = rpn_labels_int32_fpn.contiguous().view(-1).long()
            # the loss is averaged over non-ignored targets
            loss_rpn_cls_fpn = F.cross_entropy(
                rpn_cls_logits_fpn, rpn_labels_int32_fpn, ignore_index=-1)
        else:  # sigmoid
            # Zero-weight the ignored (-1) anchors, then normalize the summed
            # BCE by the fixed per-image RPN batch size.
            weight = (rpn_labels_int32_fpn >= 0).float()
            loss_rpn_cls_fpn = F.binary_cross_entropy_with_logits(
                kwargs['rpn_cls_logits_fpn' + slvl], rpn_labels_int32_fpn.float(), weight,
                size_average=False)
            loss_rpn_cls_fpn /= cfg.TRAIN.RPN_BATCH_SIZE_PER_IM * cfg.TRAIN.IMS_PER_BATCH

        # Normalization by (1) RPN_BATCH_SIZE_PER_IM and (2) IMS_PER_BATCH is
        # handled by (1) setting bbox outside weights and (2) SmoothL1Loss
        # normalizes by IMS_PER_BATCH
        loss_rpn_bbox_fpn = net_utils.smooth_l1_loss(
            kwargs['rpn_bbox_pred_fpn' + slvl], rpn_bbox_targets_fpn,
            rpn_bbox_inside_weights_fpn, rpn_bbox_outside_weights_fpn,
            beta=1/9)

        losses_cls.append(loss_rpn_cls_fpn)
        losses_bbox.append(loss_rpn_bbox_fpn)

    return losses_cls, losses_bbox
Ejemplo n.º 14
0
    def forward(self, s1, s2, a, isterminal):
        """One transition-model training step.

        Encodes both observations, predicts the next latent from the
        (detached) current latent plus the action, decodes the next latent
        into a reconstruction and a terminal prediction, and combines:

        * recon_loss    -- 10 * BCE-with-logits between decode(z2) and s2
        * tran_loss     -- MSE between predicted and detached actual z2
        * terminal_loss -- MSE between predicted and actual terminal flags
        * z_loss        -- 0.001 * mean(z2**2) latent-magnitude penalty

        Returns (total, recon_loss, tran_loss, terminal_loss, z_loss).
        """
        z_curr = self.encode(s1)
        # Detach so the transition loss cannot shape the encoder through z1.
        z_next_prior = self.transition(torch.cat((z_curr.detach(), a), 1))
        z_next = self.encode(s2)
        s2_recon, term_pred = self.decode(z_next)

        z_loss = torch.mean(z_next ** 2) * .001
        recon_loss = 10. * F.binary_cross_entropy_with_logits(
            input=s2_recon, target=s2)
        tran_loss = torch.mean((z_next.detach() - z_next_prior) ** 2)
        terminal_loss = torch.mean((term_pred - isterminal) ** 2)

        total = recon_loss + tran_loss + terminal_loss + z_loss
        return total, recon_loss, tran_loss, terminal_loss, z_loss
Ejemplo n.º 15
0
    def calculate_loss_single_channel(self, output, target, meter, training, iter_size):
        """Compute the combined BCE + soft-dice loss for one
        gradient-accumulation chunk, backprop when training, and accumulate
        running metrics in ``meter`` (mutated and returned).

        NOTE(review): the ``.data.cpu().numpy()[0]`` pattern and
        ``F.sigmoid`` date this to PyTorch <= 0.3; on modern versions use
        ``.item()`` and ``torch.sigmoid`` -- confirm the target framework
        version before changing.
        """
        bce = F.binary_cross_entropy_with_logits(output, target)
        output = F.sigmoid(output)
        d = dice(output, target)
        # jacc = jaccard(output, target)
        dice_r = dice_round(output, target)
        # jacc_r = jaccard_round(output, target)

        # Weighted sum of BCE and (1 - dice); dividing by iter_size makes
        # gradients accumulated over iter_size chunks average correctly.
        loss = (self.config.loss['bce'] * bce + self.config.loss['dice'] * (1 - d)) / iter_size

        if training:
            loss.backward()

        meter['loss'] += loss.data.cpu().numpy()[0]
        meter['dice'] += d.data.cpu().numpy()[0] / iter_size
        # meter['jacc'] += jacc.data.cpu().numpy()[0] / iter_size
        meter['bce'] += bce.data.cpu().numpy()[0] / iter_size
        meter['dr'] += dice_r.data.cpu().numpy()[0] / iter_size
        # meter['jr'] += jacc_r.data.cpu().numpy()[0] / iter_size
        return meter
Ejemplo n.º 16
0
 def forward(self, predict, target, weight=None):
     """
         BCE-with-logits over the pixels whose target is >= 0 and not equal
         to ``self.ignore_label``.

         Args:
             predict:(n, 1, h, w)
             target:(n, 1, h, w)
             weight (Tensor, optional): a manual rescaling weight given to each class.
                                        If given, has to be a Tensor of size "nclasses"
         Returns:
             Scalar loss; a zero tensor if no pixel survives the mask.
     """
     assert not target.requires_grad
     assert predict.dim() == 4
     assert target.dim() == 4
     assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
     assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
     assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
     n, c, h, w = predict.size()
     # Keep only labelled pixels: non-negative and not the ignore label.
     target_mask = (target >= 0) * (target != self.ignore_label)
     target = target[target_mask]
     # No valid pixel left -> return zero instead of calling BCE on empty tensors.
     if not target.data.dim():
         return Variable(torch.zeros(1))
     predict = predict[target_mask]
     loss = F.binary_cross_entropy_with_logits(predict, target, weight=weight, size_average=self.size_average)
     return loss
Ejemplo n.º 17
0
    def focal_loss(self, x, y):
        '''Focal loss over per-class sigmoid scores.

        Args:
          x: (tensor) predicted logits, sized [N,D].
          y: (tensor) integer class targets, sized [N,].

        Return:
          (tensor) summed focal loss.
        '''
        alpha, gamma = 0.25, 2

        # One-hot encode, then drop column 0 (the background class).
        onehot = one_hot_embedding(y.data.cpu(), 1 + self.num_classes)
        target = Variable(onehot[:, 1:]).cuda()

        prob = x.sigmoid()
        # Probability assigned to the true label of each entry.
        p_true = prob * target + (1 - prob) * (1 - target)
        alpha_weight = alpha * target + (1 - alpha) * (1 - target)
        focal_weight = alpha_weight * (1 - p_true).pow(gamma)
        return F.binary_cross_entropy_with_logits(
            x, target, focal_weight, size_average=False)
Ejemplo n.º 18
0
    def forward(self, x, k=1):
        """VAE forward pass: encode, sample k latents, decode, and return an
        ELBO estimate.

        Args:
            x: input batch; its flattened size should match ``self.x_size``.
            k: number of latent samples per datapoint.

        Returns:
            (elbo, logpx, logpz, logqz), all scalar after averaging.

        NOTE(review): logpx rescales the *mean* BCE by self.x_size to
        recover a per-example log-likelihood scale -- assumes x holds
        Bernoulli targets in [0, 1]; confirm with the decoder.
        """
        self.B = x.size()[0]
        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
        x_hat = self.decode(z)  #[PB,X]
        # x_hat = x_hat.view(k, self.B, -1)

        # Negative mean BCE times the input dimensionality = log p(x|z).
        logpx = -F.binary_cross_entropy_with_logits(input=x_hat, target=x) * self.x_size

        # logpx = log_bernoulli(x_hat, x)  #[P,B]

        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)

        # ELBO = E[log p(x|z)] + E[log p(z)] - E[log q(z|x)]
        elbo = logpx + logpz - logqz  #[P,B]

        # if k>1:
        #     max_ = torch.max(elbo, 0)[0] #[B]
        #     elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]

        # elbo = torch.mean(elbo) #[1]

        # #for printing
        # logpx = torch.mean(logpx)
        # logpz = torch.mean(logpz)
        # logqz = torch.mean(logqz)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        return elbo, logpx, logpz, logqz
Ejemplo n.º 19
0
def configure_training(net_type, opt, lr, clip_grad, lr_decay, batch_size):
    """Build the training criterion and hyper-parameter dict.

    Args:
        net_type: 'ff' (per-element BCE-with-logits) or 'rnn'
            (padded sequence cross-entropy via ``sequence_loss``).
        opt: optimizer name; only 'adam' is supported.
        lr, clip_grad, lr_decay, batch_size: training hyper-parameters,
            passed through in the returned ``train_params`` dict.

    Returns:
        (criterion, train_params) where ``criterion(logit, target)``
        yields an unreduced loss tensor.
    """
    assert opt in ['adam']
    assert net_type in ['ff', 'rnn']
    opt_kwargs = {'lr': lr}

    train_params = {
        'optimizer':      (opt, opt_kwargs),
        'clip_grad_norm': clip_grad,
        'batch_size':     batch_size,
        'lr_decay':       lr_decay,
    }

    # PEP 8 (E731): use def instead of assigning lambdas to names.
    if net_type == 'ff':
        def criterion(logit, target):
            # Element-wise (unreduced) binary cross-entropy on logits.
            return F.binary_cross_entropy_with_logits(logit, target, reduce=False)
    else:
        def ce(logit, target):
            # Token-level unreduced cross-entropy.
            return F.cross_entropy(logit, target, reduce=False)

        def criterion(logits, targets):
            return sequence_loss(logits, targets, ce, pad_idx=-1)

    return criterion, train_params
Ejemplo n.º 20
0
 def entropy(self):
     """Element-wise entropy of the Bernoulli distribution.

     BCE-with-logits of the logits against the corresponding probabilities
     equals -(p*log(p) + (1-p)*log(1-p)), computed stably in logit space
     (assumes self.probs == sigmoid(self.logits), as in a Bernoulli
     distribution object).
     """
     return binary_cross_entropy_with_logits(self.logits, self.probs, reduce=False)
Ejemplo n.º 21
0
 def log_prob(self, value):
     """Log-probability of ``value`` under the Bernoulli given by self.logits.

     BCE-with-logits is the negative Bernoulli log-likelihood, so its
     negation yields log p(value); logits and value are broadcast together
     first.
     """
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     return -binary_cross_entropy_with_logits(logits, value, reduce=False)
 def get_loss(self, pred, target):
     """Mean BCE-with-logits between ``pred`` and ``target`` reshaped to a
     single column."""
     column_target = target.reshape(-1, 1)
     return F.binary_cross_entropy_with_logits(pred, column_target)
Ejemplo n.º 23
0
    def get_scores(self, silent=False):
        """Evaluate the model on ``self.eval_examples``.

        Featurizes the examples, runs the model without gradients over a
        sequential DataLoader, and accumulates loss plus predictions.
        Multi-label mode uses sigmoid + summed BCE-with-logits; otherwise
        argmax + cross-entropy.

        Returns:
            ([accuracy, precision, recall, f1, avg_loss], matching names).

        NOTE(review): precision/recall/f1 use ``average=None)[0]``, i.e. the
        score of class index 0 only -- confirm this is intended rather than
        a macro/micro average.
        """
        eval_features = convert_examples_to_features(self.eval_examples,
                                                     self.args.max_seq_length,
                                                     self.tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=self.args.batch_size)

        self.model.eval()

        total_loss = 0
        nb_eval_steps, nb_eval_examples = 0, 0
        predicted_labels, target_labels = list(), list()

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating", disable=silent):
            input_ids = input_ids.to(self.args.device)
            input_mask = input_mask.to(self.args.device)
            segment_ids = segment_ids.to(self.args.device)
            label_ids = label_ids.to(self.args.device)

            with torch.no_grad():
                logits = self.model(input_ids, segment_ids, input_mask)

            if self.args.is_multilabel:
                # Per-label sigmoid predictions; loss is the summed BCE.
                predicted_labels.extend(
                    F.sigmoid(logits).round().long().cpu().detach().numpy())
                target_labels.extend(label_ids.cpu().detach().numpy())
                loss = F.binary_cross_entropy_with_logits(logits,
                                                          label_ids.float(),
                                                          size_average=False)
            else:
                # Single-label: argmax both predictions and one-hot targets.
                predicted_labels.extend(
                    torch.argmax(logits, dim=1).cpu().detach().numpy())
                target_labels.extend(
                    torch.argmax(label_ids, dim=1).cpu().detach().numpy())
                loss = F.cross_entropy(logits, torch.argmax(label_ids, dim=1))

            if self.args.n_gpu > 1:
                loss = loss.mean()
            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps
            total_loss += loss.item()

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        predicted_labels, target_labels = np.array(predicted_labels), np.array(
            target_labels)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        precision = metrics.precision_score(target_labels,
                                            predicted_labels,
                                            average=None)[0]
        recall = metrics.recall_score(target_labels,
                                      predicted_labels,
                                      average=None)[0]
        f1 = metrics.f1_score(target_labels, predicted_labels, average=None)[0]
        avg_loss = total_loss / nb_eval_steps

        return [accuracy, precision, recall, f1, avg_loss
                ], ['accuracy', 'precision', 'recall', 'f1', 'avg_loss']
Ejemplo n.º 24
0
    def update(self, batch):
        """Run one training step on ``batch``.

        Moves the batch to GPU when configured, forwards the network, builds
        a per-context loss from the span-start, span-end and no-answer heads
        (optionally normalised per question), then backprops, clips
        gradients, steps the optimizer and refreshes partially-fixed
        embeddings.
        """
        # Train mode
        self.network.train()
        torch.set_grad_enabled(True)

        # Transfer to GPU
        if self.opt['cuda']:
            inputs = [e.cuda(non_blocking=True) for e in batch[:10]]
            # overall_mask: [bsz, max_q_num], =1 where a question exists
            overall_mask = batch[9].cuda(non_blocking=True)

            answer_s = batch[10].cuda(non_blocking=True)
            answer_e = batch[11].cuda(non_blocking=True)
            answer_c = batch[12].cuda(non_blocking=True)
        else:
            inputs = [e for e in batch[:10]]
            overall_mask = batch[9]

            answer_s = batch[10]
            answer_e = batch[11]
            answer_c = batch[12]

        # Run forward
        # output: [batch_size, question_num, context_len], [batch_size, question_num]
        score_s, score_e, score_no_answ = self.network(*inputs)

        # Compute loss and accuracies
        # Seed the loss with an L2 penalty on the ELMo scalar-mix weights
        # when ELMo is enabled; otherwise start from zero.
        if self.opt['use_elmo']:
            loss = self.opt['elmo_lambda'] * (
                self.network.elmo.scalar_mix_0.scalar_parameters[0]**2 +
                self.network.elmo.scalar_mix_0.scalar_parameters[1]**2 +
                self.network.elmo.scalar_mix_0.scalar_parameters[2]**2)
        else:
            loss = 0
            # ELMo L2 regularization
        # Questions with class 0 have no answer span; mask their span
        # targets so cross-entropy ignores them.
        all_no_answ = (answer_c == 0)
        answer_s.masked_fill_(all_no_answ,
                              -100)  # ignore_index is -100 in F.cross_entropy
        answer_e.masked_fill_(all_no_answ, -100)

        for i in range(overall_mask.size(0)):
            q_num = sum(overall_mask[i]
                        )  # the true question number for this sampled context

            target_s = answer_s[i, :q_num]  # Size: q_num
            target_e = answer_e[i, :q_num]
            target_c = answer_c[i, :q_num]
            target_no_answ = all_no_answ[i, :q_num]

            # single_loss is averaged across q_num
            # default:true
            if self.opt['question_normalize']:
                # NOTE(review): the /8.0 and /7.0 rescalings look like
                # dataset-specific question-count constants -- confirm.
                single_loss = F.binary_cross_entropy_with_logits(
                    score_no_answ[i, :q_num],
                    target_no_answ.float()) * q_num.item() / 8.0
                single_loss = single_loss + F.cross_entropy(
                    score_s[i, :q_num],
                    target_s) * (q_num - sum(target_no_answ)).item() / 7.0
                single_loss = single_loss + F.cross_entropy(
                    score_e[i, :q_num],
                    target_e) * (q_num - sum(target_no_answ)).item() / 7.0
            else:
                single_loss = F.binary_cross_entropy_with_logits(score_no_answ[i, :q_num], target_no_answ.float()) \
                              + F.cross_entropy(score_s[i, :q_num], target_s) + F.cross_entropy(score_e[i, :q_num],
                                                                                                target_e)

            loss = loss + (single_loss / overall_mask.size(0))
        self.train_loss.update(loss.item(), overall_mask.size(0))

        # Clear gradients and run backward
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       self.opt['grad_clipping'])

        # Update parameters
        self.optimizer.step()
        self.updates += 1

        # Reset any partially fixed parameters (e.g. rare words)
        self.reset_embeddings()
        self.eval_embed_transfer = True
Ejemplo n.º 25
0
    def train_multi(self):
        """Train StarGAN with multiple datasets.
        In the code below, 1 is related to CelebA and 2 is releated to RaFD.
        """
        # Fixed imagse and labels for debugging
        fixed_x = []
        real_c = []

        for i, (images, labels) in enumerate(self.celebA_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 2:
                break

        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)
        fixed_c1_list = self.make_celeb_labels(real_c)

        fixed_c2_list = []
        for i in range(self.c2_dim):
            fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c2_dim)
            fixed_c2_list.append(self.to_var(fixed_c, volatile=True))

        fixed_zero1 = self.to_var(torch.zeros(fixed_x.size(0), self.c2_dim))     # zero vector when training with CelebA
        fixed_mask1 = self.to_var(self.one_hot(torch.zeros(fixed_x.size(0)), 2)) # mask vector: [1, 0]
        fixed_zero2 = self.to_var(torch.zeros(fixed_x.size(0), self.c_dim))      # zero vector when training with RaFD
        fixed_mask2 = self.to_var(self.one_hot(torch.ones(fixed_x.size(0)), 2))  # mask vector: [0, 1]

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # data iterator
        data_iter1 = iter(self.celebA_loader)
        data_iter2 = iter(self.rafd_loader)

        # Start with trained model
        if self.pretrained_model:
            start = int(self.pretrained_model) + 1
        else:
            start = 0

        # # Start training
        start_time = time.time()
        for i in range(start, self.num_iters):

            # Fetch mini-batch images and labels
            try:
                real_x1, real_label1 = next(data_iter1)
            except:
                data_iter1 = iter(self.celebA_loader)
                real_x1, real_label1 = next(data_iter1)

            try:
                real_x2, real_label2 = next(data_iter2)
            except:
                data_iter2 = iter(self.rafd_loader)
                real_x2, real_label2 = next(data_iter2)

            # Generate fake labels randomly (target domain labels)
            rand_idx = torch.randperm(real_label1.size(0))
            fake_label1 = real_label1[rand_idx]
            rand_idx = torch.randperm(real_label2.size(0))
            fake_label2 = real_label2[rand_idx]

            real_c1 = real_label1.clone()
            fake_c1 = fake_label1.clone()
            zero1 = torch.zeros(real_x1.size(0), self.c2_dim)
            mask1 = self.one_hot(torch.zeros(real_x1.size(0)), 2)

            real_c2 = self.one_hot(real_label2, self.c2_dim)
            fake_c2 = self.one_hot(fake_label2, self.c2_dim)
            zero2 = torch.zeros(real_x2.size(0), self.c_dim)
            mask2 = self.one_hot(torch.ones(real_x2.size(0)), 2)

            # Convert tensor to variable
            real_x1 = self.to_var(real_x1)
            real_c1 = self.to_var(real_c1)
            fake_c1 = self.to_var(fake_c1)
            mask1 = self.to_var(mask1)
            zero1 = self.to_var(zero1)

            real_x2 = self.to_var(real_x2)
            real_c2 = self.to_var(real_c2)
            fake_c2 = self.to_var(fake_c2)
            mask2 = self.to_var(mask2)
            zero2 = self.to_var(zero2)

            real_label1 = self.to_var(real_label1)
            fake_label1 = self.to_var(fake_label1)
            real_label2 = self.to_var(real_label2)
            fake_label2 = self.to_var(fake_label2)

            # ================== Train D ================== #

            # Real images (CelebA)
            out_real, out_cls = self.D(real_x1)
            out_cls1 = out_cls[:, :self.c_dim]      # celebA part
            d_loss_real = - torch.mean(out_real)
            d_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, real_label1, size_average=False) / real_x1.size(0)

            # Real images (RaFD)
            out_real, out_cls = self.D(real_x2)
            out_cls2 = out_cls[:, self.c_dim:]      # rafd part
            d_loss_real += - torch.mean(out_real)
            d_loss_cls += F.cross_entropy(out_cls2, real_label2)

            # Compute classification accuracy of the discriminator
            if (i+1) % self.log_step == 0:
                accuracies = self.compute_accuracy(out_cls1, real_label1, 'CelebA')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                print(log)
                accuracies = self.compute_accuracy(out_cls2, real_label2, 'RaFD')
                log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                print('Classification Acc (8 emotional expressions): ', end='')
                print(log)

            # Fake images (CelebA)
            fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
            fake_x1 = self.G(real_x1, fake_c)
            fake_x1 = Variable(fake_x1.data)
            out_fake, _ = self.D(fake_x1)
            d_loss_fake = torch.mean(out_fake)

            # Fake images (RaFD)
            fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
            fake_x2 = self.G(real_x2, fake_c)
            out_fake, _ = self.D(fake_x2)
            d_loss_fake += torch.mean(out_fake)

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Compute gradient penalty
            if (i+1) % 2 == 0:
                real_x = real_x1
                fake_x = fake_x1
            else:
                real_x = real_x2
                fake_x = fake_x2

            alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
            interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
            out, out_cls = self.D(interpolated)

            if (i+1) % 2 == 0:
                out_cls = out_cls[:, :self.c_dim]  # CelebA
            else:
                out_cls = out_cls[:, self.c_dim:]  # RaFD

            grad = torch.autograd.grad(outputs=out,
                                       inputs=interpolated,
                                       grad_outputs=torch.ones(out.size()).cuda(),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]

            grad = grad.view(grad.size(0), -1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm - 1)**2)

            # Backward + Optimize
            d_loss = self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging
            loss = {}
            loss['D/loss_real'] = d_loss_real.data[0]
            loss['D/loss_fake'] = d_loss_fake.data[0]
            loss['D/loss_cls'] = d_loss_cls.data[0]
            loss['D/loss_gp'] = d_loss_gp.data[0]

            # ================== Train G ================== #
            if (i+1) % self.d_train_repeat == 0:
                # Original-to-target and target-to-original domain (CelebA)
                fake_c = torch.cat([fake_c1, zero1, mask1], dim=1)
                real_c = torch.cat([real_c1, zero1, mask1], dim=1)
                fake_x1 = self.G(real_x1, fake_c)
                rec_x1 = self.G(fake_x1, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x1)
                out_cls1 = out_cls[:, :self.c_dim]
                g_loss_fake = - torch.mean(out)
                g_loss_rec = torch.mean(torch.abs(real_x1 - rec_x1))
                g_loss_cls = F.binary_cross_entropy_with_logits(out_cls1, fake_label1, size_average=False) / fake_x1.size(0)

                # Original-to-target and target-to-original domain (RaFD)
                fake_c = torch.cat([zero2, fake_c2, mask2], dim=1)
                real_c = torch.cat([zero2, real_c2, mask2], dim=1)
                fake_x2 = self.G(real_x2, fake_c)
                rec_x2 = self.G(fake_x2, real_c)

                # Compute losses
                out, out_cls = self.D(fake_x2)
                out_cls2 = out_cls[:, self.c_dim:]
                g_loss_fake += - torch.mean(out)
                g_loss_rec += torch.mean(torch.abs(real_x2 - rec_x2))
                g_loss_cls += F.cross_entropy(out_cls2, fake_label2)

                # Backward + Optimize
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging
                loss['G/loss_fake'] = g_loss_fake.data[0]
                loss['G/loss_cls'] = g_loss_cls.data[0]
                loss['G/loss_rec'] = g_loss_rec.data[0]

            # Print out log info
            if (i+1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))

                log = "Elapsed [{}], Iter [{}/{}]".format(
                    elapsed, i+1, self.num_iters)

                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate the images (debugging)
            if (i+1) % self.sample_step == 0:
                fake_image_list = [fixed_x]

                # Changing hair color, gender, and age
                for j in range(self.c_dim):
                    fake_c = torch.cat([fixed_c1_list[j], fixed_zero1, fixed_mask1], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                # Changing emotional expressions
                for j in range(self.c2_dim):
                    fake_c = torch.cat([fixed_zero2, fixed_c2_list[j], fixed_mask2], dim=1)
                    fake_image_list.append(self.G(fixed_x, fake_c))
                fake = torch.cat(fake_image_list, dim=3)

                # Save the translated images
                save_image(self.denorm(fake.data),
                    os.path.join(self.sample_path, '{}_fake.png'.format(i+1)), nrow=1, padding=0)

            # Save model checkpoints
            if (i+1) % self.model_save_step == 0:
                torch.save(self.G.state_dict(),
                    os.path.join(self.model_save_path, '{}_G.pth'.format(i+1)))
                torch.save(self.D.state_dict(),
                    os.path.join(self.model_save_path, '{}_D.pth'.format(i+1)))

            # Decay learning rate
            decay_step = 1000
            if (i+1) > (self.num_iters - self.num_iters_decay) and (i+1) % decay_step==0:
                g_lr -= (self.g_lr / float(self.num_iters_decay) * decay_step)
                d_lr -= (self.d_lr / float(self.num_iters_decay) * decay_step)
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
Ejemplo n.º 26
0
    def forward(self, x1, x2, label, size_average=True):
        """Logistic loss over row-wise dot products of two embedding batches.

        The logit for each row ``e`` is the inner product ``<x1[e], x2[e]>``;
        the loss is the batch-mean binary cross entropy of those logits
        against ``label``.

        NOTE(review): ``size_average`` is accepted but never used — the loss
        is always the batch mean; confirm with callers before removing.
        """
        # <x1[e], x2[e]> for every row e (same as einsum "ef,ef->e").
        logits = (x1 * x2).sum(dim=1)
        return F.binary_cross_entropy_with_logits(logits, label)
Ejemplo n.º 27
0
def run(split, upto=None):
    """Run one pass over the train or test split of binarized MNIST.

    :param split: 'train' (backprop + optimizer steps) or 'test' (eval only).
    :param upto: optional cap on the number of batches (quick approximate eval).
    Prints the average per-batch loss at the end.
    """
    torch.set_grad_enabled(split == 'train')
    model.train() if split == 'train' else model.eval()

    # Fix: `x` was never assigned (NameError on `x.size()` below).
    x = xtr if split == 'train' else xte
    # Fix: was `else xte` (a data tensor); the connectivity-sample count for
    # inference comes from the -s/--samples CLI option.
    nsamples = 1 if split == 'train' else args.samples

    N, D = x.size()
    B = 128  # batch size
    n_steps = N // B if upto is None else min(N // B, upto)
    losses = []
    for step in range(n_steps):
        xb = Variable(x[step * B: step * B + B])

        # Average logits over several mask/order samples (order-agnostic MADE).
        xbhat = torch.zeros_like(xb)
        for s in range(nsamples):
            # Fix: this call was not indented under the `if`, which is an
            # IndentationError. Resample masks periodically during training
            # and on every test batch.
            if step % args.resample_every == 0 or split == 'test':
                model.update_masks()
            xbhat += model(xb)
        xbhat /= nsamples

        # Sum BCE over all pixels, then average over the batch.
        loss = F.binary_cross_entropy_with_logits(xbhat, xb, size_average=False) / B
        losses.append(loss.data.item())

        if split == 'train':
            opt.zero_grad()
            loss.backward()
            opt.step()

    print("%s epoch avg loss: %f" % (split, np.mean(losses)))

if __name__ == '__main__':
    # Fix: the original mixed a tab (parser line) with spaces -> TabError.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data-path', required=True, type=str, help="Path to binarized_mnist.npz")
    parser.add_argument('-q', '--hiddens', type=str, default='500', help="Comma separated sizes for hidden layers, e.g. 500, or 500,500")
    parser.add_argument('-n', '--num-masks', type=int, default=1, help="Number of orderings for order/connection-agnostic training")
    parser.add_argument('-r', '--resample-every', type=int, default=20, help="For efficiency we can choose to resample orders/masks only once every this many steps")
    parser.add_argument('-s', '--samples', type=int, default=1, help="How many samples of connectivity/masks to average logits over during inference")
    args = parser.parse_args()

    # Reproducibility across numpy and torch (CPU and all GPUs).
    np.random.seed(42)  # fix: was np.random_seed, which does not exist
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    print("loading binarized mnist from", args.data_path)
    mnist = np.load(args.data_path)
    xtr, xte = mnist['train_data'], mnist['valid_data']
    xtr = torch.from_numpy(xtr).cuda()
    xte = torch.from_numpy(xte).cuda()

    # construct model and ship to GPU
    hidden_list = list(map(int, args.hiddens.split(',')))
    model = MADE(xtr.size(1), hidden_list, xtr.size(1), num_masks=args.num_masks)
    print("number of model parameters:", sum([np.prod(p.size()) for p in model.parameters()]))
    model.cuda()

    # set up the optimizer
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=45, gamma=0.1)

    # start the training
    for epoch in range(100):
        print("epoch %d" % (epoch,))
        scheduler.step(epoch)
        # Fix: the epoch driver defined above is `run`, not `run_epoch`.
        run('test', upto=5)  # run only a few batches for approximate test accuracy
        run('train')

    print("optimization done. full test set eval:")
    run('test')
Ejemplo n.º 28
0
 def classification_loss(self, logit, target, dataset='CelebA'):
     """Compute binary or softmax cross entropy loss.

     CelebA attributes are multi-label, so each logit gets an independent
     sigmoid BCE, summed over elements and averaged over the batch.
     RaFD expressions are single-label, so standard softmax cross entropy
     is used. Any other ``dataset`` value returns None (as before).
     """
     if dataset == 'RaFD':
         return F.cross_entropy(logit, target)
     if dataset == 'CelebA':
         batch_size = logit.size(0)
         summed = F.binary_cross_entropy_with_logits(logit, target, size_average=False)
         return summed / batch_size
Ejemplo n.º 29
0
    def train_epoch(self):
        """Train the segmentation generator and both domain discriminators
        for one epoch.

        Per batch: (1) supervised segmentation loss on the source domain;
        then, once past ``self.warmup_epoch``: (2) adversarial loss pushing
        target-domain predictions to look source-like, (3) discriminator loss
        on source-domain (same-domain) outputs, (4) discriminator loss on
        target-domain (different-domain) outputs, (5) discriminator optimizer
        steps. Scalars/images go to TensorBoard and a CSV log.
        """
        source_domain_label = 1
        target_domain_label = 0
        smooth = 1e-7  # keeps torch.log() finite in the uncertainty map
        self.model_gen.train()
        self.model_dis.train()
        self.model_dis2.train()
        self.running_seg_loss = 0.0
        self.running_adv_loss = 0.0
        # Fix: the loop accumulates `self.running_adv_diff_loss` (and the
        # epilogue divides it by the loader length), but that attribute was
        # never reset here -> AttributeError or stale carry-over per epoch.
        self.running_adv_diff_loss = 0.0
        self.running_dis_diff_loss = 0.0
        self.running_dis_same_loss = 0.0
        self.running_total_loss = 0.0
        self.running_cup_dice_tr = 0.0
        self.running_disc_dice_tr = 0.0
        loss_adv_diff_data = 0
        loss_D_same_data = 0
        loss_D_diff_data = 0

        domain_t_loader = enumerate(self.domain_loaderT)
        start_time = timeit.default_timer()
        for batch_idx, sampleS in tqdm.tqdm(
                enumerate(self.domain_loaderS), total=len(self.domain_loaderS),
                desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):

            metrics = []

            iteration = batch_idx + self.epoch * len(self.domain_loaderS)
            self.iteration = iteration

            assert self.model_gen.training
            assert self.model_dis.training
            assert self.model_dis2.training

            self.optim_gen.zero_grad()
            self.optim_dis.zero_grad()
            self.optim_dis2.zero_grad()

            # 1. train generator with random images
            # Freeze discriminators; only the generator receives gradients.
            for param in self.model_dis.parameters():
                param.requires_grad = False
            for param in self.model_dis2.parameters():
                param.requires_grad = False
            for param in self.model_gen.parameters():
                param.requires_grad = True

            imageS = sampleS['image'].cuda()
            target_map = sampleS['map'].cuda()
            target_boundary = sampleS['boundary'].cuda()

            oS, boundaryS = self.model_gen(imageS)

            # Segmentation = BCE on the mask + MSE on the boundary map.
            loss_seg1 = bceloss(torch.sigmoid(oS), target_map)
            loss_seg2 = mseloss(torch.sigmoid(boundaryS), target_boundary)
            loss_seg = loss_seg1 + loss_seg2

            self.running_seg_loss += loss_seg.item()
            loss_seg_data = loss_seg.data.item()
            if np.isnan(loss_seg_data):
                raise ValueError('loss is nan while training')

            # Graph is reused for the adversarial pass below, hence retain.
            loss_seg.backward(retain_graph=True)
            # self.optim_gen.step()

            # write image log
            if iteration % 30 == 0:
                grid_image = make_grid(
                    imageS[0, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/image', grid_image, iteration)
                grid_image = make_grid(
                    target_map[0, 0, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/target_cup', grid_image, iteration)
                grid_image = make_grid(
                    target_map[0, 1, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/target_disc', grid_image, iteration)
                grid_image = make_grid(
                    target_boundary[0, 0, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/target_boundary', grid_image, iteration)
                grid_image = make_grid(torch.sigmoid(oS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/prediction_cup', grid_image, iteration)
                grid_image = make_grid(torch.sigmoid(oS)[0, 1, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/prediction_disc', grid_image, iteration)
                grid_image = make_grid(torch.sigmoid(boundaryS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
                self.writer.add_image('DomainS/prediction_boundary', grid_image, iteration)

            if self.epoch > self.warmup_epoch:
                # # 2. train generator with images from different domain
                try:
                    id_, sampleT = next(domain_t_loader)
                except:
                    # Target loader exhausted: restart it for this epoch.
                    domain_t_loader = enumerate(self.domain_loaderT)
                    id_, sampleT = next(domain_t_loader)

                imageT = sampleT['image'].cuda()

                oT, boundaryT = self.model_gen(imageT)
                # Entropy-style map: high where the prediction is uncertain.
                uncertainty_mapT = -1.0 * torch.sigmoid(oT) * torch.log(torch.sigmoid(oT) + smooth)
                D_out2 = self.model_dis(torch.sigmoid(boundaryT))
                D_out1 = self.model_dis2(uncertainty_mapT)

                # Generator tries to make target outputs look source-domain.
                loss_adv_diff1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_domain_label).cuda())
                loss_adv_diff2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_domain_label).cuda())
                loss_adv_diff = 0.01 * (loss_adv_diff1 + loss_adv_diff2)
                self.running_adv_diff_loss += loss_adv_diff.item()
                loss_adv_diff_data = loss_adv_diff.data.item()
                if np.isnan(loss_adv_diff_data):
                    raise ValueError('loss_adv_diff_data is nan while training')

                loss_adv_diff.backward()
                self.optim_gen.step()

                # 3. train discriminator with images from same domain
                # Unfreeze discriminators, freeze the generator.
                for param in self.model_dis.parameters():
                    param.requires_grad = True
                for param in self.model_dis2.parameters():
                    param.requires_grad = True
                for param in self.model_gen.parameters():
                    param.requires_grad = False

                # Detach so discriminator gradients do not reach the generator.
                boundaryS = boundaryS.detach()
                oS = oS.detach()
                uncertainty_mapS = -1.0 * torch.sigmoid(oS) * torch.log(torch.sigmoid(oS) + smooth)
                D_out2 = self.model_dis(torch.sigmoid(boundaryS))
                D_out1 = self.model_dis2(uncertainty_mapS)

                loss_D_same1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(
                    source_domain_label).cuda())
                loss_D_same2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(
                    source_domain_label).cuda())
                loss_D_same = loss_D_same1+loss_D_same2

                self.running_dis_same_loss += loss_D_same.item()
                loss_D_same_data = loss_D_same.data.item()
                if np.isnan(loss_D_same_data):
                    raise ValueError('loss is nan while training')
                loss_D_same.backward()

                # 4. train discriminator with images from different domain

                boundaryT = boundaryT.detach()
                oT = oT.detach()
                uncertainty_mapT = -1.0 * torch.sigmoid(oT) * torch.log(torch.sigmoid(oT) + smooth)
                D_out2 = self.model_dis(torch.sigmoid(boundaryT))
                D_out1 = self.model_dis2(uncertainty_mapT)

                loss_D_diff1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(
                    target_domain_label).cuda())
                loss_D_diff2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(
                    target_domain_label).cuda())
                loss_D_diff = loss_D_diff1 + loss_D_diff2
                self.running_dis_diff_loss += loss_D_diff.item()
                loss_D_diff_data = loss_D_diff.data.item()
                if np.isnan(loss_D_diff_data):
                    raise ValueError('loss is nan while training')
                loss_D_diff.backward()

                # 5. update parameters
                self.optim_dis.step()
                self.optim_dis2.step()

                if iteration % 30 == 0:
                    grid_image = make_grid(
                        imageT[0, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainT/image', grid_image, iteration)
                    grid_image = make_grid(
                        sampleT['map'][0, 0, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainT/target_cup', grid_image, iteration)
                    grid_image = make_grid(
                        sampleT['map'][0, 1, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainT/target_disc', grid_image, iteration)
                    grid_image = make_grid(torch.sigmoid(oT)[0, 0, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainT/prediction_cup', grid_image, iteration)
                    grid_image = make_grid(torch.sigmoid(oT)[0, 1, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainT/prediction_disc', grid_image, iteration)
                    grid_image = make_grid(boundaryS[0, 0, ...].clone().cpu().data, 1, normalize=True)
                    self.writer.add_image('DomainS/boundaryS', grid_image, iteration)
                    grid_image = make_grid(boundaryT[0, 0, ...].clone().cpu().data, 1,
                                           normalize=True)
                    self.writer.add_image('DomainT/boundaryT', grid_image, iteration)

                self.writer.add_scalar('train_adv/loss_adv_diff', loss_adv_diff_data, iteration)
                self.writer.add_scalar('train_dis/loss_D_same', loss_D_same_data, iteration)
                self.writer.add_scalar('train_dis/loss_D_diff', loss_D_diff_data, iteration)
            self.writer.add_scalar('train_gen/loss_seg', loss_seg_data, iteration)

            metrics.append((loss_seg_data, loss_adv_diff_data, loss_D_same_data, loss_D_diff_data))
            metrics = np.mean(metrics, axis=0)

            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.now(pytz.timezone(self.time_zone)) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration]  + \
                    metrics.tolist() + [''] * 5 + [elapsed_time]
                log = map(str, log)
                f.write(','.join(log) + '\n')

        # Epoch-average losses.
        self.running_seg_loss /= len(self.domain_loaderS)
        self.running_adv_diff_loss /= len(self.domain_loaderS)
        self.running_dis_same_loss /= len(self.domain_loaderS)
        self.running_dis_diff_loss /= len(self.domain_loaderS)

        stop_time = timeit.default_timer()

        print('\n[Epoch: %d] lr:%f,  Average segLoss: %f, '
              ' Average advLoss: %f, Average dis_same_Loss: %f, '
              'Average dis_diff_Lyoss: %f,'
              'Execution time: %.5f' %
              (self.epoch, get_lr(self.optim_gen), self.running_seg_loss,
               self.running_adv_diff_loss,
               self.running_dis_same_loss, self.running_dis_diff_loss, stop_time - start_time))
Ejemplo n.º 30
0
    def _add_losses(self, sigma_rpn=3.0):
        """Populate ``self._losses`` with RPN (3 levels), RCNN and optional
        mask losses; the accumulated sum is stored as ``total_loss``.

        RPN class/box losses are computed per anchor level unless
        ``cfg.FIX_RPN``; RCNN class/box losses unless ``cfg.FIX_CLASS`` (or
        when ``cfg.NYUV2_FINETUNE``); mask BCE when ``cfg.USE_MASK``.

        NOTE(review): ``sigma_rpn`` is accepted but unused — the smooth-L1
        calls below hard-code sigma=2.0 / 1.0; confirm before removing.
        """
        loss = Variable(torch.zeros(1).cuda())

        if not cfg.FIX_RPN:
            if cfg.NUM_ANCHORS_LEVEL1 != 0:
                #---------------------
                # level 1
                #---------------------
                # RPN, class loss
                rpn_cls_score_level1 = self._predictions['rpn_cls_score_level1'] #torch.Size([1, 2, 10, 5, 10, 9])
                rpn_label_level1 = self._anchor_targets['rpn_labels_level1'] #torch.Size([1, 10, 5, 10, 9])
                # Anchors labelled -1 are "don't care" and excluded.
                rpn_select_level1 = (rpn_label_level1.data != -1).nonzero()

                if rpn_select_level1.numel() != 0:
                    #TODO advanced indexing
                    rpn_cls_score_reshape_level1 = []
                    rpn_label_reshape_level1 = []
                    for i in rpn_select_level1:
                        rpn_cls_score_reshape_level1.append(rpn_cls_score_level1[i[0], :, i[1], i[2], i[3], i[4]])
                        rpn_label_reshape_level1.append(rpn_label_level1[i[0], i[1], i[2], i[3], i[4]])

                    rpn_cls_score_reshape_level1 = torch.stack(rpn_cls_score_reshape_level1, 0)
                    rpn_label_reshape_level1 = torch.stack(rpn_label_reshape_level1, 0)
                    rpn_cross_entropy_level1 = F.cross_entropy(rpn_cls_score_reshape_level1, rpn_label_reshape_level1)
                    self._losses['rpn_cross_entropy_level1'] = rpn_cross_entropy_level1

                    #RPN, bbox loss
                    rpn_bbox_pred_level1 = self._predictions['rpn_bbox_pred_level1']
                    rpn_bbox_targets_level1 = self._anchor_targets['rpn_bbox_targets_level1']
                    rpn_bbox_inside_weights_level1 = self._anchor_targets['rpn_bbox_inside_weights_level1']
                    rpn_bbox_outside_weights_level1 = self._anchor_targets['rpn_bbox_outside_weights_level1']
                    rpn_loss_box_level1 = self._smooth_l1_loss(rpn_bbox_pred_level1, rpn_bbox_targets_level1, 
                                                        rpn_bbox_inside_weights_level1, rpn_bbox_outside_weights_level1, 
                                                        sigma=2.0, dim=[1,2,3,4])
                    self._losses['rpn_loss_box_level1'] = rpn_loss_box_level1
                    loss += rpn_cross_entropy_level1 + rpn_loss_box_level1
                else:
                    # No valid anchors at this level: record zero losses.
                    self._losses['rpn_cross_entropy_level1'] = Variable(torch.FloatTensor([0.0]))
                    self._losses['rpn_loss_box_level1'] = Variable(torch.FloatTensor([0.0]))

            if cfg.NUM_ANCHORS_LEVEL2 != 0:
                #---------------------
                # level 2
                #---------------------
                # RPN, class loss
                rpn_cls_score_level2 = self._predictions['rpn_cls_score_level2'] #torch.Size([1, 2, 10, 5, 10, 9])
                rpn_label_level2 = self._anchor_targets['rpn_labels_level2'] #torch.Size([1, 10, 5, 10, 9])
                rpn_select_level2 = (rpn_label_level2.data != -1).nonzero()
                if rpn_select_level2.numel() != 0:
                    #TODO advanced indexing
                    rpn_cls_score_reshape_level2 = []
                    rpn_label_reshape_level2 = []
                    for i in rpn_select_level2:
                        rpn_cls_score_reshape_level2.append(rpn_cls_score_level2[i[0], :, i[1], i[2], i[3], i[4]])
                        rpn_label_reshape_level2.append(rpn_label_level2[i[0], i[1], i[2], i[3], i[4]])

                    rpn_cls_score_reshape_level2 = torch.stack(rpn_cls_score_reshape_level2, 0)
                    rpn_label_reshape_level2 = torch.stack(rpn_label_reshape_level2, 0)
                    rpn_cross_entropy_level2 = F.cross_entropy(rpn_cls_score_reshape_level2, rpn_label_reshape_level2)
                    self._losses['rpn_cross_entropy_level2'] = rpn_cross_entropy_level2

                    rpn_bbox_pred_level2 = self._predictions['rpn_bbox_pred_level2']
                    rpn_bbox_targets_level2 = self._anchor_targets['rpn_bbox_targets_level2']
                    rpn_bbox_inside_weights_level2 = self._anchor_targets['rpn_bbox_inside_weights_level2']
                    rpn_bbox_outside_weights_level2 = self._anchor_targets['rpn_bbox_outside_weights_level2']
                    rpn_loss_box_level2 = self._smooth_l1_loss(rpn_bbox_pred_level2, rpn_bbox_targets_level2, 
                                                        rpn_bbox_inside_weights_level2, rpn_bbox_outside_weights_level2, 
                                                        sigma=2.0, dim=[1,2,3,4])
                    self._losses['rpn_loss_box_level2'] = rpn_loss_box_level2
                    loss += rpn_cross_entropy_level2 + rpn_loss_box_level2
                else:
                    self._losses['rpn_cross_entropy_level2'] = Variable(torch.FloatTensor([0.0]))
                    self._losses['rpn_loss_box_level2'] = Variable(torch.FloatTensor([0.0]))

            if cfg.NUM_ANCHORS_LEVEL3 != 0:
                #---------------------
                # level 3
                #---------------------
                # RPN, class loss
                rpn_cls_score_level3 = self._predictions['rpn_cls_score_level3'] #torch.Size([1, 2, 10, 5, 10, 9])
                rpn_label_level3 = self._anchor_targets['rpn_labels_level3'] #torch.Size([1, 10, 5, 10, 9])
                rpn_select_level3 = (rpn_label_level3.data != -1).nonzero()
                if rpn_select_level3.numel() != 0:
                    #TODO advanced indexing
                    rpn_cls_score_reshape_level3 = []
                    rpn_label_reshape_level3 = []
                    for i in rpn_select_level3:
                        rpn_cls_score_reshape_level3.append(rpn_cls_score_level3[i[0], :, i[1], i[2], i[3], i[4]])
                        rpn_label_reshape_level3.append(rpn_label_level3[i[0], i[1], i[2], i[3], i[4]])

                    rpn_cls_score_reshape_level3 = torch.stack(rpn_cls_score_reshape_level3, 0)
                    # Fix: was torch.cat — inconsistent with levels 1/2 and
                    # invalid for the 0-dim label tensors collected above.
                    rpn_label_reshape_level3 = torch.stack(rpn_label_reshape_level3, 0)
                    rpn_cross_entropy_level3 = F.cross_entropy(rpn_cls_score_reshape_level3, rpn_label_reshape_level3)
                    self._losses['rpn_cross_entropy_level3'] = rpn_cross_entropy_level3

                    rpn_bbox_pred_level3 = self._predictions['rpn_bbox_pred_level3']
                    rpn_bbox_targets_level3 = self._anchor_targets['rpn_bbox_targets_level3']
                    rpn_bbox_inside_weights_level3 = self._anchor_targets['rpn_bbox_inside_weights_level3']
                    rpn_bbox_outside_weights_level3 = self._anchor_targets['rpn_bbox_outside_weights_level3']
                    rpn_loss_box_level3 = self._smooth_l1_loss(rpn_bbox_pred_level3, rpn_bbox_targets_level3, 
                                                        rpn_bbox_inside_weights_level3, rpn_bbox_outside_weights_level3, 
                                                        sigma=2.0, dim=[1,2,3,4])
                    self._losses['rpn_loss_box_level3'] = rpn_loss_box_level3
                    loss += rpn_cross_entropy_level3 + rpn_loss_box_level3
                else:
                    self._losses['rpn_cross_entropy_level3'] = Variable(torch.FloatTensor([0.0]))
                    self._losses['rpn_loss_box_level3'] = Variable(torch.FloatTensor([0.0]))

        else:
            # RPN frozen: report zero losses for all levels.
            self._losses['rpn_loss_box_level1'] = Variable(torch.FloatTensor([0.0]))
            self._losses['rpn_cross_entropy_level1'] = Variable(torch.FloatTensor([0.0]))
            self._losses['rpn_loss_box_level2'] = Variable(torch.FloatTensor([0.0]))
            self._losses['rpn_cross_entropy_level2'] = Variable(torch.FloatTensor([0.0]))
            self._losses['rpn_loss_box_level3'] = Variable(torch.FloatTensor([0.0]))
            self._losses['rpn_cross_entropy_level3'] = Variable(torch.FloatTensor([0.0]))

        if not cfg.FIX_CLASS or cfg.NYUV2_FINETUNE:
            #RCNN, class loss (per-class weights re-balance the label set)
            cls_score = self._predictions['cls_score']
            label = self._proposal_targets['labels'].view(-1)
            normalize_weights = torch.FloatTensor(cfg.NORMALIZE_WEIGHTS).cuda() 
            cross_entropy = F.cross_entropy(cls_score, label, weight=normalize_weights, size_average=True, reduce=True)
            self._losses['cross_entropy'] = cross_entropy
            loss += cross_entropy

            # RCNN, bbox loss
            bbox_pred = self._predictions['bbox_pred']
            bbox_targets = self._proposal_targets['bbox_targets']
            bbox_inside_weights = self._proposal_targets['bbox_inside_weights']
            bbox_outside_weights = self._proposal_targets['bbox_outside_weights']
            loss_box = self._smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, 
                                            bbox_outside_weights, sigma=1.0, dim=[1])
            self._losses['loss_box'] = loss_box
            loss += loss_box

        else:
            self._losses['loss_box'] = Variable(torch.FloatTensor([0.0]))
            self._losses['cross_entropy'] = Variable(torch.FloatTensor([0.0]))

        if cfg.USE_MASK:
            # Background (index 0) masks contribute nothing to the loss.
            normalize_weights = torch.FloatTensor(cfg.NORMALIZE_WEIGHTS).cuda() 
            normalize_weights[0] = 0.0

            loss_mask = Variable(torch.zeros(1).cuda())
            counter = 0  # number of masks with a non-zero class weight

            mask_preds = self._predictions['mask_pred']
            mask_targets = self._mask_targets['masks']
            mask_labels = self._mask_targets['labels']

            for i in range(self.batch_size):
                for mask_pred, mask_target, mask_label in zip(mask_preds[i], mask_targets[i], mask_labels[i]):
                    loss_mask += F.binary_cross_entropy_with_logits(mask_pred[0, mask_label], Variable(mask_target.float().cuda())) * normalize_weights[mask_label]
                    counter += normalize_weights[mask_label] != 0.0

            if counter != 0:
                # Average over the masks that actually contributed.
                self._losses['loss_mask'] = loss_mask / counter.item()
                loss += loss_mask / counter.item()
            else:
                self._losses['loss_mask'] = loss_mask

        self._losses['total_loss'] = loss
Ejemplo n.º 31
0
    def update(self, batch):
        """Run one optimizer step of FCN Q-learning on a batch of transitions.

        The batch is split into ``update_divide_factor`` equal sub-batches;
        ``loss.backward()`` is called per sub-batch so gradients accumulate,
        and a single clipped optimizer step is taken at the end.  This bounds
        peak GPU memory while matching a full-batch update.

        NOTE(review): assumes ``len(batch)`` is divisible by
        ``update_divide_factor`` -- any remainder samples are silently
        dropped.  Confirm callers always pass a divisible batch.

        :param batch: sequence of transitions accepted by _loadBatchToDevice.
        :return: ``(total_loss, td_errors)`` where ``total_loss`` is a float
            (mean loss over sub-batches) and ``td_errors`` is a 1-D CPU
            tensor of per-sample absolute TD errors |Q(s,a) - target|.
        """
        total_batch_size = len(batch)
        divide_factor = self.update_divide_factor
        batch_size = int(total_batch_size / divide_factor)
        total_loss = 0
        total_td_errors = []
        self.fcn_optimizer.zero_grad()
        for i in range(divide_factor):
            small_batch = batch[batch_size*i:batch_size*(i+1)]
            states, obs, action_idx, rewards, next_states, next_obs, non_final_masks, step_lefts, is_experts = self._loadBatchToDevice(
                small_batch)
            heightmap_size = obs[0].size(2)
            if self.sl:
                # Pure supervised target: discounted value gamma^(steps left).
                q = self.gamma ** step_lefts

            else:
                with torch.no_grad():
                    # Standard TD(0) target from the target network:
                    # r + gamma * max_a' Q_target(s', a') (masked at terminals).
                    q_map_prime = self.forwardFCN(next_states, next_obs, target_net=True)
                    q_prime = q_map_prime.reshape((batch_size, -1)).max(1)[0]
                    q = rewards + self.gamma * q_prime * non_final_masks
                    q = q.detach()
                if self.expert_sl:
                    # For expert transitions, override the TD target with the
                    # supervised discounted-return target.
                    q_target_sl = self.gamma ** step_lefts
                    q[is_experts] = q_target_sl[is_experts]

            q_map = self.forwardFCN(states, obs)
            # Q value of the executed action.  action_idx columns are indexed
            # as (channel=idx 2, row=idx 0, col=idx 1) -- presumably
            # (rotation, y, x); confirm against _loadBatchToDevice.
            q_output = q_map[torch.arange(0, batch_size), action_idx[:, 2], action_idx[:, 0], action_idx[:, 1]]
            td_loss = F.smooth_l1_loss(q_output, q)

            # cross entropy margin: push the expert action to be the argmax of
            # the (softmax-tempered) flattened Q map.
            if self.margin == 'ce':
                expert_q_map = q_map[is_experts]
                if expert_q_map.size(0) == 0:
                    margin_loss = 0
                else:
                    # Flatten (channel, row, col) into a single class index.
                    target = action_idx[is_experts, 2] * heightmap_size * heightmap_size + action_idx[
                        is_experts, 0] * heightmap_size + action_idx[is_experts, 1]
                    margin_loss = F.cross_entropy(self.softmax_beta*expert_q_map.reshape(expert_q_map.size(0), -1), target)

            # binary cross entropy against a one-hot map of the expert action
            # (probabilities via softmax over the flattened Q map).
            elif self.margin == 'bce':
                expert_q_map = q_map[is_experts]
                if expert_q_map.size(0) == 0:
                    margin_loss = 0
                else:
                    margin_map = torch.zeros_like(q_map)
                    margin_map[torch.arange(0, batch_size), action_idx[:, 2], action_idx[:, 0], action_idx[:, 1]] = 1
                    margin_map = margin_map[is_experts]
                    softmax = F.softmax(self.softmax_beta*expert_q_map.reshape(is_experts.sum(), -1), dim=1).reshape(expert_q_map.shape)
                    margin_loss = F.binary_cross_entropy(softmax, margin_map)

            # binary cross entropy with logits on the raw (tempered) Q map.
            elif self.margin == 'bcel':
                margin_map = torch.zeros_like(q_map)
                margin_map[torch.arange(0, batch_size), action_idx[:, 2], action_idx[:, 0], action_idx[:, 1]] = 1
                margin_loss = F.binary_cross_entropy_with_logits(self.softmax_beta*q_map[is_experts], margin_map[is_experts])
                # Guard: empty expert selection yields NaN from the mean.
                if torch.isnan(margin_loss):
                    margin_loss = 0

            # large-margin (DQfD-style): every non-expert action is penalized
            # by margin_l, then the gap max_a[Q+l] - Q(expert action) is paid.
            elif self.margin == 'oril':
                margin_map = torch.ones_like(q_map) * self.margin_l
                margin_map[torch.arange(0, batch_size), action_idx[:, 2], action_idx[:, 0], action_idx[:, 1]] = 0
                margin_q_map = q_map + margin_map
                margin_q_max = margin_q_map.reshape(batch_size, -1).max(1)[0]
                margin_loss = (margin_q_max - q_output)[is_experts]
                margin_loss = margin_loss.mean()
                # Guard: empty expert selection yields NaN from the mean.
                if torch.isnan(margin_loss):
                    margin_loss = 0

            # l margin (default): per-sample, penalize every Q value that
            # exceeds Q(expert action) - margin_l, pulled down to that level.
            else:
                # margin_map = torch.ones_like(q_map) * self.margin_l
                # margin_map[torch.arange(0, batch_size), action_idx[:, 2], action_idx[:, 0], action_idx[:, 1]] = 0
                # margin_q_map = q_map + margin_map
                # margin_q_max = margin_q_map.reshape(batch_size, -1).max(1)[0]
                # margin_loss = (margin_q_max - q_output)[is_experts]
                # margin_loss = margin_loss.mean()
                # if torch.isnan(margin_loss):
                #     margin_loss = 0

                margin_losses = []
                for j in range(batch_size):
                    if not is_experts[j]:
                        # Non-expert samples contribute zero margin loss.
                        margin_losses.append(torch.tensor(0).float().to(self.device))
                        continue
                    qm = q_map[j]
                    qe = q_output[j]
                    over_q = qm[qm > qe - self.margin_l]
                    if over_q.shape[0] == 0:
                        margin_losses.append(torch.tensor(0).float().to(self.device))
                        continue
                    over_q_target = torch.ones_like(over_q) * qe - self.margin_l

                    margin_losses.append((over_q - over_q_target).mean())
                margin_loss = torch.stack(margin_losses).mean()

            loss = td_loss + self.margin_weight * margin_loss

            # Accumulate gradients across sub-batches; step once after loop.
            loss.backward()
            total_loss += (loss.item()/divide_factor)
            total_td_errors.append(torch.abs(q_output - q).detach().cpu())

        # Element-wise gradient clipping to [-1, 1] before the single step.
        for param in self.fcn.parameters():
            param.grad.data.clamp_(-1, 1)
        self.fcn_optimizer.step()

        return total_loss, torch.cat(total_td_errors)
Ejemplo n.º 32
0
def mask_rcnn_loss(pred_mask_logits, instances, vis_period=0):
    """
    Mask R-CNN mask-head loss: mean binary cross-entropy (with logits)
    between the predicted masks and rasterized ground-truth masks.

    Args:
        pred_mask_logits (Tensor): shape (B, C, Hmask, Wmask) for
            class-specific heads or (B, 1, Hmask, Wmask) for class-agnostic
            ones, where B is the total number of predicted masks across the
            batch and Hmask == Wmask. Values are raw logits.
        instances (list[Instances]): per-image ground truth (classes, boxes,
            masks, ...), aligned 1:1 with the rows of pred_mask_logits.
        vis_period (int): if > 0, dump a side-by-side prediction/GT
            visualization to the event storage every vis_period steps.

    Returns:
        Tensor: a scalar loss.
    """
    is_class_agnostic = pred_mask_logits.size(1) == 1
    n_masks = pred_mask_logits.size(0)
    side = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(
        3), "Mask prediction must be square!"

    class_ids = []
    target_masks = []
    for per_image in instances:
        if len(per_image) == 0:
            continue
        if not is_class_agnostic:
            class_ids.append(per_image.gt_classes.to(dtype=torch.int64))

        # Rasterize each GT mask inside its proposal box to a (side, side)
        # patch; result is (N, side, side) on the prediction's device.
        per_image_targets = per_image.gt_masks.crop_and_resize(
            per_image.proposal_boxes.tensor,
            side).to(device=pred_mask_logits.device)
        target_masks.append(per_image_targets)

    # No ground-truth instances anywhere in the batch: return a zero that
    # still participates in the autograd graph.
    if not target_masks:
        return pred_mask_logits.sum() * 0

    target_masks = cat(target_masks, dim=0)

    # Select the channel to supervise: the single channel when agnostic,
    # otherwise the channel matching each instance's ground-truth class.
    if is_class_agnostic:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        class_ids = cat(class_ids, dim=0)
        pred_mask_logits = pred_mask_logits[torch.arange(n_masks), class_ids]

    # GT masks may be bool or float (depends on the rasterizer); keep a
    # boolean view for statistics and a float32 view as the BCE target.
    if target_masks.dtype == torch.bool:
        target_bool = target_masks
    else:
        target_bool = target_masks > 0.5
    target_masks = target_masks.to(dtype=torch.float32)

    # Training diagnostics with gt classes at a 0.5 probability threshold.
    wrong = (pred_mask_logits > 0.0) != target_bool
    accuracy = 1 - (wrong.sum().item() / max(wrong.numel(), 1.0))
    n_positive = target_bool.sum().item()
    fp_rate = (wrong & ~target_bool).sum().item() / max(
        target_bool.numel() - n_positive, 1.0)
    fn_rate = (wrong & target_bool).sum().item() / max(n_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", accuracy)
    storage.put_scalar("mask_rcnn/false_positive", fp_rate)
    storage.put_scalar("mask_rcnn/false_negative", fn_rate)
    if vis_period > 0 and storage.iter % vis_period == 0:
        probs = pred_mask_logits.sigmoid()
        panels = torch.cat([probs, target_masks], axis=2)
        name = "Left: mask prediction;   Right: mask GT"
        for idx, panel in enumerate(panels):
            panel = torch.stack([panel] * 3, axis=0)
            storage.put_image(name + f" ({idx})", panel)

    return F.binary_cross_entropy_with_logits(pred_mask_logits,
                                              target_masks,
                                              reduction="mean")
Ejemplo n.º 33
0
    def forward(self,
                features,
                gt_segmap=None,
                gt_contour=None,
                gt_objmask=None,
                gt_classes=None):
        """Semantic-segmentation head forward pass.

        Produces three full-resolution outputs from FPN features: a semantic
        segmentation map, a per-class contour map, and a pixel embedding map.
        In training mode, returns ``([], losses)``; in eval mode, returns
        ``([segmap_probs, contour_probs, emb], {})``.

        NOTE(review): the ASPP output is assigned to ``x`` and then
        immediately overwritten by the decoder call below -- confirm whether
        the decoder was meant to consume the ASPP result.
        """
        #         for i, f in enumerate(self.in_features):
        #             if i == 0:
        #                 x = self.scale_heads[i](features[f])
        #             else:
        #                 x = x + self.scale_heads[i](features[f])
        x = self.aspp(features[self.in_features[-1]])
        # Decoder fuses the deepest and shallowest feature levels.
        x = self.decoder(features[self.in_features[-1]],
                         features[self.in_features[0]])

        x = F.upsample(x, scale_factor=2, mode="bilinear")
        # Segmentation logits, upsampled back to input resolution.
        segmap = self.predictor_segmap(x)
        segmap = F.interpolate(segmap,
                               scale_factor=self.common_stride,
                               mode="bilinear",
                               align_corners=False)
        #         # TODO: add a pointwise softmax here for inter-class competition
        #         segmap = F.softmax(segmap, dim=-3)

        # Contour logits (kept as logits here; BCE-with-logits is applied in
        # training, sigmoid in eval).
        contour = self.predictor_contour(x)
        contour = F.interpolate(contour,
                                scale_factor=self.common_stride,
                                mode="bilinear",
                                align_corners=False)
        #contour = F.softmax(contour, dim=-3)  # the contours are not mutual exclusive between object classes

        # Embedding
        emb = self.predictor_emb(x)
        emb = F.interpolate(emb,
                            scale_factor=self.common_stride,
                            mode="bilinear",
                            align_corners=False)

        if self.training:
            losses = {}

            #pos_weight_s = torch.ones([self.num_classes+1])  # *10
            # Hand-tuned per-class weights for cross-entropy; presumably
            # upweights rare/small classes -- confirm against the dataset's
            # class list.
            pos_weight_s = torch.tensor([
                1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 1, 10, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 3, 1, 10, 10, 3, 10, 10, 10, 10, 3, 10, 10, 10, 1, 3,
                3, 3, 1, 10, 10, 10, 1, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 3, 1, 1,
                1, 1, 1, 1, 1, 10, 10, 1, 1, 3, 1, 10, 1, 1, 10, 1, 1, 10, 1,
                10, 10, 0.3
            ])
            #             pos_weight_s = torch.reshape(pos_weight_s,(1,self.num_classes,1,1))
            #             assert segmap.size(1)==self.num_classes, segmap.size()
            #             losses["loss_seg_map"] = (
            #                 F.binary_cross_entropy_with_logits(segmap, gt_segmap, reduction="mean", pos_weight=pos_weight_s.to(self.device))
            #                 * self.loss_weight
            #             )
            # after softmax the segmap are already probabilities
            # TODO: add weights to bias towards small objects such as spoons
            #             losses["loss_seg_map"] = (
            #                 F.binary_cross_entropy(segmap, gt_segmap, reduction="mean")
            #                 * self.loss_weight
            #             )
            losses["loss_sem_seg"] = (
                F.cross_entropy(segmap,
                                gt_segmap,
                                weight=pos_weight_s.to(self.device),
                                reduction="mean",
                                ignore_index=self.ignore_value) *
                self.loss_weight)
            #print("loss_seg_map: ", losses["loss_seg_map"])

            # Contours are sparse, so positives are heavily upweighted (x20)
            # and the whole contour loss is further scaled by 10.
            pos_weight_c = torch.ones([self.num_classes]) * 20  # 20
            pos_weight_c = torch.reshape(pos_weight_c,
                                         (1, self.num_classes, 1, 1))
            losses["loss_contour"] = (F.binary_cross_entropy_with_logits(
                contour,
                gt_contour,
                reduction="mean",
                pos_weight=pos_weight_c.to(self.device)) * self.loss_weight *
                                      10)

            # Embedding loss, including intra-class variance, inter-class distance, and regularization loss
            # Computed per image, then averaged over the batch.
            loss_emb1d = torch.zeros(emb.size(0)).to(self.device)
            for i in range(emb.size(0)):
                loss_emb1d[i] = self.emb_loss(emb[i].squeeze(dim=0),
                                              gt_objmask[i],
                                              gt_classes[i])  # /emb.size(0)

            if len(loss_emb1d):
                losses["embedding"] = loss_emb1d.mean(
                ) * self.loss_weight * 0.1
            else:
                # Empty batch: keep the key with a zero tensor on device.
                losses["embedding"] = torch.tensor(0.).to(self.device)

            return [], losses
        else:
            #TODO: combine segmap and contour into detections
            # turning segmap logits into probabilities
            segmap = F.softmax(
                segmap,
                dim=-3)  # F.sigmoid(segmap)  # use sigmoid on logits ONLY
            contour = F.sigmoid(contour)
            return [segmap, contour, emb], {}
Ejemplo n.º 34
0
    def train(self):
        """Train StarGAN within a single dataset.

        WGAN-GP style training loop: the discriminator is trained every
        iteration (adversarial + domain-classification + gradient-penalty
        losses), the generator every ``d_train_repeat`` iterations
        (adversarial + reconstruction + classification losses).  Also handles
        periodic logging, sample image dumps, checkpointing, and linear
        learning-rate decay over the final ``num_epochs_decay`` epochs.

        NOTE(review): uses legacy pre-0.4 PyTorch APIs (``volatile=True``,
        ``.data[0]``, ``size_average``) -- this code targets an old torch
        version.
        """

        # Set dataloader
        if self.dataset == 'CelebA':
            self.data_loader = self.celebA_loader
        else:
            self.data_loader = self.rafd_loader

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Grab the first 4 batches as fixed debug inputs.
        fixed_x = []
        real_c = []
        for i, (images, labels) in enumerate(self.data_loader):
            fixed_x.append(images)
            real_c.append(labels)
            if i == 3:
                break

        # Fixed inputs and target domain labels for debugging
        fixed_x = torch.cat(fixed_x, dim=0)
        fixed_x = self.to_var(fixed_x, volatile=True)
        real_c = torch.cat(real_c, dim=0)

        if self.dataset == 'CelebA':
            fixed_c_list = self.make_celeb_labels(real_c)
        elif self.dataset == 'RaFD':
            # One fixed one-hot target label per expression class.
            fixed_c_list = []
            for i in range(self.c_dim):
                fixed_c = self.one_hot(torch.ones(fixed_x.size(0)) * i, self.c_dim)
                fixed_c_list.append(self.to_var(fixed_c, volatile=True))

        # lr cache for decaying
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start with trained model if exists
        if self.pretrained_model:
            # Checkpoint names are '<epoch>_<iter>_...'; resume from epoch.
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_x, real_label) in enumerate(self.data_loader):
                
                # Generate fake labels randomly (target domain labels) by
                # permuting the real labels within the batch.
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                if self.dataset == 'CelebA':
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
                else:
                    real_c = self.one_hot(real_label, self.c_dim)
                    fake_c = self.one_hot(fake_label, self.c_dim)

                # Convert tensor to variable
                real_x = self.to_var(real_x)
                real_c = self.to_var(real_c)           # input for the generator
                fake_c = self.to_var(fake_c)
                real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
                fake_label = self.to_var(fake_label)
                
                # ================== Train D ================== #

                # Compute loss with real images (WGAN critic: maximize D(real))
                out_src, out_cls = self.D(real_x)
                d_loss_real = - torch.mean(out_src)

                # Domain classification loss on real images: multi-label BCE
                # for CelebA attributes, softmax CE for RaFD expressions.
                if self.dataset == 'CelebA':
                    d_loss_cls = F.binary_cross_entropy_with_logits(
                        out_cls, real_label, size_average=False) / real_x.size(0)
                else:
                    d_loss_cls = F.cross_entropy(out_cls, real_label)

                # Compute classification accuracy of the discriminator
                if (i+1) % self.log_step == 0:
                    accuracies = self.compute_accuracy(out_cls, real_label, self.dataset)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    if self.dataset == 'CelebA':
                        print('Classification Acc (Black/Blond/Brown/Gender/Aged): ', end='')
                    else:
                        print('Classification Acc (8 emotional expressions): ', end='')
                    print(log)

                # Compute loss with fake images (detached via Variable(.data)
                # so G receives no gradient from the D update).
                fake_x = self.G(real_x, fake_c)
                fake_x = Variable(fake_x.data)
                out_src, out_cls = self.D(fake_x)
                d_loss_fake = torch.mean(out_src)

                # Backward + Optimize
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute gradient penalty on random interpolations between
                # real and fake images (WGAN-GP); applied as a separate step.
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
                out, out_cls = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize deviation of the gradient L2 norm from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train G ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain
                    fake_x = self.G(real_x, fake_c)
                    rec_x = self.G(fake_x, real_c)

                    # Compute losses: fool D, cycle-reconstruct the input.
                    out_src, out_cls = self.D(fake_x)
                    g_loss_fake = - torch.mean(out_src)
                    g_loss_rec = torch.mean(torch.abs(real_x - rec_x))

                    # Classification loss pushes fakes toward the target domain.
                    if self.dataset == 'CelebA':
                        g_loss_cls = F.binary_cross_entropy_with_logits(
                            out_cls, fake_label, size_average=False) / fake_x.size(0)
                    else:
                        g_loss_cls = F.cross_entropy(out_cls, fake_label)

                    # Backward + Optimize
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_rec'] = g_loss_rec.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out log info
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Translate fixed images for debugging
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [fixed_x]
                    for fixed_c in fixed_c_list:
                        fake_image_list.append(self.G(fixed_x, fixed_c))
                    # Concatenate along width so each row shows all domains.
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

                # Save model checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_G.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, '{}_{}_D.pth'.format(e+1, i+1)))

            # Decay learning rate linearly over the last num_epochs_decay epochs.
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                g_lr -= (self.g_lr / float(self.num_epochs_decay))
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
Ejemplo n.º 35
0
    def validate(self):
        """Evaluate the generator on the validation set and checkpoint.

        Computes the mean BCE-with-logits loss and cup/disc Dice scores over
        the validation loader, logs them to TensorBoard and a CSV file, and
        saves a checkpoint when the combined Dice improves (or every 50
        epochs otherwise).  Restores training mode on exit if the model was
        training when called.

        :raises ValueError: if the validation loss becomes NaN.
        """
        training = self.model_gen.training
        self.model_gen.eval()

        val_loss = 0
        val_cup_dice = 0
        val_disc_dice = 0
        metrics = []
        with torch.no_grad():

            for batch_idx, sample in tqdm.tqdm(
                    enumerate(self.val_loader), total=len(self.val_loader),
                    desc='Valid iteration=%d' % self.iteration, ncols=80,
                    leave=False):
                data = sample['image']
                target_map = sample['map']
                target_boundary = sample['boundary']
                if self.cuda:
                    data, target_map, target_boundary = data.cuda(), target_map.cuda(), target_boundary.cuda()
                with torch.no_grad():
                    predictions, boundary = self.model_gen(data)

                loss = F.binary_cross_entropy_with_logits(predictions, target_map)
                loss_data = loss.data.item()
                if np.isnan(loss_data):
                    raise ValueError('loss is nan while validating')
                val_loss += loss_data

                # Dice for the two segmentation labels (optic cup and disc).
                dice_cup, dice_disc = dice_coeff_2label(predictions, target_map)
                val_cup_dice += dice_cup
                val_disc_dice += dice_disc
            # Average the accumulated metrics over the whole loader.
            val_loss /= len(self.val_loader)
            val_cup_dice /= len(self.val_loader)
            val_disc_dice /= len(self.val_loader)
            metrics.append((val_loss, val_cup_dice, val_disc_dice))
            self.writer.add_scalar('val_data/loss_CE', val_loss, self.epoch * (len(self.domain_loaderS)))
            self.writer.add_scalar('val_data/val_CUP_dice', val_cup_dice, self.epoch * (len(self.domain_loaderS)))
            self.writer.add_scalar('val_data/val_DISC_dice', val_disc_dice, self.epoch * (len(self.domain_loaderS)))

            # Model selection on the (unweighted) sum of the two Dice scores.
            mean_dice = val_cup_dice + val_disc_dice
            is_best = mean_dice > self.best_mean_dice
            if is_best:
                self.best_epoch = self.epoch + 1
                self.best_mean_dice = mean_dice

                torch.save({
                    'epoch': self.epoch,
                    'iteration': self.iteration,
                    'arch': self.model_gen.__class__.__name__,
                    'optim_state_dict': self.optim_gen.state_dict(),
                    'optim_dis_state_dict': self.optim_dis.state_dict(),
                    'optim_dis2_state_dict': self.optim_dis2.state_dict(),
                    'model_state_dict': self.model_gen.state_dict(),
                    'model_dis_state_dict': self.model_dis.state_dict(),
                    'model_dis2_state_dict': self.model_dis2.state_dict(),
                    'learning_rate_gen': get_lr(self.optim_gen),
                    'learning_rate_dis': get_lr(self.optim_dis),
                    'learning_rate_dis2': get_lr(self.optim_dis2),
                    'best_mean_dice': self.best_mean_dice,
                }, osp.join(self.out, 'checkpoint_%d.pth.tar' % self.best_epoch))
            else:
                # Periodic safety checkpoint even without improvement.
                if (self.epoch + 1) % 50 == 0:
                    torch.save({
                        'epoch': self.epoch,
                    'iteration': self.iteration,
                    'arch': self.model_gen.__class__.__name__,
                    'optim_state_dict': self.optim_gen.state_dict(),
                    'optim_dis_state_dict': self.optim_dis.state_dict(),
                    'optim_dis2_state_dict': self.optim_dis2.state_dict(),
                    'model_state_dict': self.model_gen.state_dict(),
                    'model_dis_state_dict': self.model_dis.state_dict(),
                    'model_dis2_state_dict': self.model_dis2.state_dict(),
                    'learning_rate_gen': get_lr(self.optim_gen),
                    'learning_rate_dis': get_lr(self.optim_dis),
                    'learning_rate_dis2': get_lr(self.optim_dis2),
                    'best_mean_dice': self.best_mean_dice,
                    }, osp.join(self.out, 'checkpoint_%d.pth.tar' % (self.epoch + 1)))


            # Append one CSV row: epoch, iteration, 5 empty train columns,
            # metrics, elapsed time, best-epoch note.
            # NOTE(review): `list(metrics)` embeds the list of tuples as a
            # single CSV field -- presumably `list(metrics[0])` was intended;
            # confirm against the training-side log writer.
            with open(osp.join(self.out, 'log.csv'), 'a') as f:
                elapsed_time = (
                    datetime.now(pytz.timezone(self.time_zone)) -
                    self.timestamp_start).total_seconds()
                log = [self.epoch, self.iteration] + [''] * 5 + \
                       list(metrics) + [elapsed_time] + ['best model epoch: %d' % self.best_epoch]
                log = map(str, log)
                f.write(','.join(log) + '\n')
            self.writer.add_scalar('best_model_epoch', self.best_epoch, self.epoch * (len(self.domain_loaderS)))
            if training:
                self.model_gen.train()
                self.model_dis.train()
                self.model_dis2.train()
Ejemplo n.º 36
0
def run(init_lr=0.1, max_steps=64e3, mode='rgb', root='/ssd/Charades_v1_rgb', train_split='charades/charades.json', batch_size=8*5, save_model=''):
    """Fine-tune an ImageNet/Kinetics-pretrained I3D on Charades.

    Per-frame multi-label action classification: the loss is the mean of a
    per-frame localization BCE and a per-clip (time-max-pooled)
    classification BCE.  Gradients are accumulated over
    num_steps_per_update mini-batches before each optimizer step.
    Checkpoints are written to `save_model` every 10 steps.

    NOTE(review): this is Python 2 code (print statements) and uses the
    deprecated F.upsample API.
    """
    # setup dataset
    train_transforms = transforms.Compose([videotransforms.RandomCrop(224),
                                           videotransforms.RandomHorizontalFlip(),
    ])
    test_transforms = transforms.Compose([videotransforms.CenterCrop(224)])

    dataset = Dataset(train_split, 'training', root, mode, train_transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=36, pin_memory=True)

    val_dataset = Dataset(train_split, 'testing', root, mode, test_transforms)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=36, pin_memory=True)    

    dataloaders = {'train': dataloader, 'val': val_dataloader}
    datasets = {'train': dataset, 'val': val_dataset}

    
    # setup the model: load pretrained 400-class weights, then replace the
    # logits layer for the 157 Charades classes.
    if mode == 'flow':
        i3d = InceptionI3d(400, in_channels=2)
        i3d.load_state_dict(torch.load('models/flow_imagenet.pt'))
    else:
        i3d = InceptionI3d(400, in_channels=3)
        i3d.load_state_dict(torch.load('models/rgb_imagenet.pt'))
    i3d.replace_logits(157)
    #i3d.load_state_dict(torch.load('/ssd/models/000920.pt'))
    i3d.cuda()
    i3d = nn.DataParallel(i3d)

    lr = init_lr
    optimizer = optim.SGD(i3d.parameters(), lr=lr, momentum=0.9, weight_decay=0.0000001)
    lr_sched = optim.lr_scheduler.MultiStepLR(optimizer, [300, 1000])


    num_steps_per_update = 4 # accum gradient
    steps = 0
    # train it
    while steps < max_steps:#for epoch in range(num_epochs):
        print 'Step {}/{}'.format(steps, max_steps)
        print '-' * 10

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                i3d.train(True)
            else:
                i3d.train(False)  # Set model to evaluate mode
                
            tot_loss = 0.0
            tot_loc_loss = 0.0
            tot_cls_loss = 0.0
            num_iter = 0
            optimizer.zero_grad()
            
            # Iterate over data.
            for data in dataloaders[phase]:
                num_iter += 1
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                inputs = Variable(inputs.cuda())
                t = inputs.size(2)
                labels = Variable(labels.cuda())

                per_frame_logits = i3d(inputs)
                # upsample the temporal axis back to the input's t frames
                per_frame_logits = F.upsample(per_frame_logits, t, mode='linear')

                # compute localization loss (per-frame multi-label BCE)
                loc_loss = F.binary_cross_entropy_with_logits(per_frame_logits, labels)
                tot_loc_loss += loc_loss.data[0]

                # compute classification loss (with max-pooling along time B x C x T)
                cls_loss = F.binary_cross_entropy_with_logits(torch.max(per_frame_logits, dim=2)[0], torch.max(labels, dim=2)[0])
                tot_cls_loss += cls_loss.data[0]

                # Scale by the accumulation factor so the effective step
                # matches a single large batch.
                loss = (0.5*loc_loss + 0.5*cls_loss)/num_steps_per_update
                tot_loss += loss.data[0]
                loss.backward()

                if num_iter == num_steps_per_update and phase == 'train':
                    steps += 1
                    num_iter = 0
                    optimizer.step()
                    optimizer.zero_grad()
                    lr_sched.step()
                    if steps % 10 == 0:
                        # Running averages over the last 10 optimizer steps.
                        print '{} Loc Loss: {:.4f} Cls Loss: {:.4f} Tot Loss: {:.4f}'.format(phase, tot_loc_loss/(10*num_steps_per_update), tot_cls_loss/(10*num_steps_per_update), tot_loss/10)
                        # save model
                        torch.save(i3d.module.state_dict(), save_model+str(steps).zfill(6)+'.pt')
                        tot_loss = tot_loc_loss = tot_cls_loss = 0.
            if phase == 'val':
                print '{} Loc Loss: {:.4f} Cls Loss: {:.4f} Tot Loss: {:.4f}'.format(phase, tot_loc_loss/num_iter, tot_cls_loss/num_iter, (tot_loss*num_steps_per_update)/num_iter) 
Ejemplo n.º 37
0
 def classification_loss(self, logit, target):
     """Return the per-sample mean binary cross-entropy of logits vs targets.

     Element-wise BCE-with-logits terms are summed over all elements, then
     divided by the batch size (first dimension of `logit`), so each sample
     contributes the sum of its per-attribute losses.
     """
     batch_size = logit.size(0)
     summed_bce = F.binary_cross_entropy_with_logits(
         logit, target, reduction='sum')
     return summed_bce / batch_size
Ejemplo n.º 38
0
def train_epoch(train_loader: torch.utils.data.DataLoader,
                base_model: torch.nn.Module,
                classification_layer: torch.nn.Module,
                forg_layer: torch.nn.Module, adv_models: List[torch.nn.Module],
                epoch: int, optimizer: torch.optim.Optimizer,
                lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
                callback: Optional[VisdomLogger], device: torch.device,
                args: Any):
    """ Trains the network for one epoch on a mix of clean and adversarial
        examples: each batch contributes ``alpha * clean_loss`` plus
        ``(1 - alpha) * adversarial_loss`` to the gradient.

        Parameters
        ----------
        train_loader: torch.utils.data.DataLoader
            Iterable that loads the training set (x, y, yforg) tuples
        base_model: torch.nn.Module
            The model architecture that "extract features" from signatures
        classification_layer: torch.nn.Module
            The classification layer (from features to predictions of which user
            wrote the signature)
        forg_layer: torch.nn.Module
            The forgery prediction layer (from features to predictions of whether
            the signature is a forgery). Only used in args.forg = True
        adv_models: List[torch.nn.Module]
            Source models used (round-robin, one per batch) to craft the
            adversarial examples via ``create_adversarial``
        epoch: int
            The current epoch (used for reporting)
        optimizer: torch.optim.Optimizer
            The optimizer (already initialized)
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler (stepped once, at the end of the epoch)
        callback: VisdomLogger (optional)
            A callback to report the training progress
        device: torch.device
            The device (CPU or GPU) to use for training
        args: Namespace
            Extra arguments used for training:
            args.forg: bool
                Whether forgeries are being used for training
            args.lamb: float
                The weight used for the forgery loss (training with forgeries only)
            args.alpha: float
                The weight of the clean loss; the adversarial loss is
                weighted by (1 - alpha)
            args.eps: float
                Perturbation budget passed to ``create_adversarial``

        Returns
        -------
        None
        """

    step = 0
    n_steps = len(train_loader)

    # Index of the adversarial source model for the current batch; advanced
    # round-robin at the bottom of the loop so consecutive batches use
    # different source models.
    adv_model_idx = 0
    for batch in train_loader:
        x, y = batch[0], batch[1]
        # NOTE(review): torch.tensor(...) copies its argument even when it is
        # already a tensor; torch.as_tensor(...) would avoid the copy.
        x = torch.tensor(x, dtype=torch.float).to(device)
        y = torch.tensor(y, dtype=torch.long).to(device)
        # Per-sample forgery indicator; 0 marks genuine signatures (see the
        # yforg == 0 filters below) -- presumably 1 marks forgeries; confirm
        # against the dataset definition.
        yforg = torch.tensor(batch[2], dtype=torch.float).to(device)

        # Create adversarial example using the current source model
        adv = create_adversarial(adv_models, adv_model_idx, x, y, args.eps)

        # Clean example
        features = base_model(x)

        if args.forg:
            # Eq (4) in https://arxiv.org/abs/1705.05787
            # User classification is computed only on genuine signatures,
            # while forgery detection below sees the whole batch.
            logits = classification_layer(features[yforg == 0])
            class_loss = F.cross_entropy(logits, y[yforg == 0])

            forg_logits = forg_layer(features).squeeze()
            forg_loss = F.binary_cross_entropy_with_logits(forg_logits, yforg)

            loss = (1 - args.lamb) * class_loss
            loss += args.lamb * forg_loss
        else:
            # Eq (1) in https://arxiv.org/abs/1705.05787
            logits = classification_layer(features)
            loss = class_loss = F.cross_entropy(logits, y)

        # Back propagation: gradients from the weighted clean loss...
        loss = args.alpha * loss
        optimizer.zero_grad()
        loss.backward()

        # ...are accumulated with gradients from the weighted adversarial loss
        # (no zero_grad in between, so the two backward passes add up).
        adv_features = base_model(adv)
        adv_logits = classification_layer(adv_features)
        adv_loss = F.cross_entropy(adv_logits, y)
        loss2 = (1 - args.alpha) * adv_loss
        loss2.backward()

        # NOTE(review): only the first param group is clipped; fine if the
        # optimizer has a single group, otherwise other groups go unclipped.
        torch.nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'],
                                        10)

        # Update weights
        optimizer.step()

        # Logging (every 11th step)
        if callback and step % 11 == 0:
            with torch.no_grad():
                # NOTE(review): y[yforg == 0] is filtered to genuine samples,
                # but pred_clean matches it only in the args.forg branch; in
                # the else branch (and for pred_adv always) the predictions
                # cover the full batch, so the sizes differ whenever the batch
                # contains forgeries -- likely a latent bug.
                pred_clean = logits.argmax(1)
                acc_clean = y[yforg == 0].eq(pred_clean).float().mean()

                pred_adv = adv_logits.argmax(1)
                acc_adv = y[yforg == 0].eq(pred_adv).float().mean()

            # Fractional epoch used as the x-axis for all scalars below.
            iteration = epoch + (step / n_steps)
            callback.scalars(
                ['closs_clean', 'closs_adv'], iteration,
                [class_loss.detach(), adv_loss.detach()])
            callback.scalar('closs_adv_{}'.format(adv_model_idx), iteration,
                            adv_loss.detach())
            # NOTE(review): 'acc_addv' looks like a typo for 'acc_adv'; left
            # unchanged since dashboards may already key on this label.
            callback.scalars(['acc_clean', 'acc_addv'],
                             epoch + (step / n_steps),
                             [acc_clean, acc_adv.detach()])
            if args.forg:
                # Positive forgery logit => predicted forgery (0.5 threshold
                # in probability space).
                forg_pred = forg_logits > 0
                forg_acc = yforg.long().eq(forg_pred.long()).float().mean()
                callback.scalar('forg_loss', iteration, forg_loss.detach())
                callback.scalar('forg_acc', iteration, forg_acc.detach())

        step += 1
        adv_model_idx = (adv_model_idx + 1) % len(adv_models)
    # Scheduler is stepped once per epoch, not per batch.
    lr_scheduler.step()

# Select the GPU when available; model and every batch are moved to it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyModel().to(device)
# SGD with Nesterov momentum and light weight decay.
optimiser = torch.optim.SGD(
    model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-5, nesterov=True
)

# Binary-classification training loop: BCE-with-logits on float targets.
for epoch in range(150):
	model.train()
	start = time.time()
	for idx, (data, target) in enumerate(train_loader):
		# Targets cast to float because BCE-with-logits expects float labels.
		data, target = data.to(device), target.to(device).float()
		optimiser.zero_grad()
		# NOTE(review): .squeeze() drops *all* singleton dims, so a batch of
		# size 1 collapses to a scalar and breaks the loss; squeeze(1) is safer.
		output = model(data).squeeze()
		loss = F.binary_cross_entropy_with_logits(output, target)
		loss.backward()
		optimiser.step()
		
		# Progress report every 20 batches.
		if idx % 20 == 0:
			acc = accuracy(output, target)
			print("Batch {}/{} Loss: {} Acc: {}".format(idx, len(train_loader), loss.item(), acc.detach().cpu().numpy()))
	end = time.time()
	print('Time for one epoch: {:.1f} secs'.format(end-start))
	
	# Validation every other epoch.
	if epoch % 2 == 0:
		model.eval()
		with torch.no_grad():
			# NOTE(review): the validation loop appears truncated in this
			# file -- accumulators are initialised but never updated here.
			loss = 0
			acc = 0
			num_batches = len(valid_loader)
Ejemplo n.º 40
0
def mask_rcnn_loss(pred_mask_logits, instances):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch, in 1:1 correspondence with pred_mask_logits; each carries the
            ground-truth fields (class, box, mask, ...) for its image.

    Returns:
        mask_loss (Tensor): A scalar tensor containing the loss.
        and groundtruth masks for visualization
    """
    num_masks = pred_mask_logits.size(0)
    side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(
        3), "Mask prediction must be square!"
    # A single channel means one mask shared by all classes.
    class_agnostic = pred_mask_logits.size(1) == 1

    classes_chunks = []
    gt_chunks = []
    for inst in instances:
        # Images without any instances contribute nothing.
        if len(inst) == 0:
            continue
        if not class_agnostic:
            classes_chunks.append(inst.gt_classes.to(dtype=torch.int64))

        # Rasterize the GT masks inside their proposal boxes at the
        # prediction resolution, on the same device as the predictions.
        cropped = batch_crop_masks_within_box(
            inst.gt_masks, inst.proposal_boxes.tensor,
            side_len).to(device=pred_mask_logits.device)
        gt_chunks.append(cropped)

    # No foreground anywhere: return a zero loss that still participates
    # in autograd, plus the (empty) ground-truth list.
    if not gt_chunks:
        return pred_mask_logits.sum() * 0, gt_chunks

    gt_mask_logits = cat(gt_chunks, dim=0)
    assert gt_mask_logits.numel() > 0, gt_mask_logits.shape

    # Pick the predicted channel per mask: the lone channel when
    # class-agnostic, otherwise the channel of the GT class.
    if class_agnostic:
        selected_logits = pred_mask_logits[:, 0]
    else:
        gt_classes = cat(classes_chunks, dim=0)
        selected_logits = pred_mask_logits[torch.arange(num_masks), gt_classes]

    # Log the training accuracy (using gt classes and 0.5 threshold)
    # Note that here we allow gt_mask_logits to be float as well
    # (depend on the implementation of rasterize())
    correct = (selected_logits > 0.5) == (gt_mask_logits > 0.5)
    accuracy_value = correct.nonzero().size(0) / correct.numel()
    get_event_storage().put_scalar("mask_rcnn/accuracy", accuracy_value)

    mask_loss = F.binary_cross_entropy_with_logits(
        selected_logits,
        gt_mask_logits.to(dtype=torch.float32),
        reduction="mean")
    return mask_loss, gt_mask_logits