Example #1
    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2 = parsed_data
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)

        # Generate artificial label
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u, reduction='none')
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(output_x, label_x)[0].item(),
            'loss_u': loss_u.item(),
            'acc_u': compute_accuracy(output_u, label_u)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
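
compute_accuracy is called in nearly every example but never defined here. Below is a minimal sketch of a top-k accuracy helper with a compatible call pattern (it returns a list of percentages, so [0].item() is top-1 accuracy); this is an assumption about its behavior, not the library's actual implementation.

import torch


def compute_accuracy(output, target, topk=(1, )):
    """Top-k accuracy (in %) for each k in topk; element [0] is top-1."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # (B, maxk) predicted classes
    pred = pred.t()  # (maxk, B)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res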
Example #2
    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x, label_x)[0].item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
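
self.revgrad is used above (and with a grad_scaling argument in the next example) but is not defined in these snippets. A minimal sketch of a gradient-reversal op, assuming it passes the input through unchanged in the forward pass and flips (and optionally scales) the gradient in the backward pass; in the trainers it is exposed as a module attribute, here it is shown as a plain function.

from torch.autograd import Function


class _ReverseGrad(Function):

    @staticmethod
    def forward(ctx, x, grad_scaling):
        ctx.grad_scaling = grad_scaling
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of (and scale) the incoming gradient.
        return -ctx.grad_scaling * grad_output, None


def revgrad(x, grad_scaling=1.0):
    return _ReverseGrad.apply(x, grad_scaling)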
Example #3
    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        loss_x = self.ce(logit_x, label_x)

        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x, label_x)[0].item(),
            'loss_d': loss_d.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
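
The lmda computed above is the standard DANN adaptation factor 2 / (1 + exp(-10 * p)) - 1, which ramps from 0 at the start of training towards 1 at the end so that the domain-adversarial signal is weak early on. A few illustrative values:

import numpy as np

for progress in (0.0, 0.1, 0.5, 1.0):
    lmda = 2 / (1 + np.exp(-10 * progress)) - 1
    print(f"progress={progress:.1f} -> lmda={lmda:.3f}")
# progress=0.0 -> lmda=0.000
# progress=0.1 -> lmda=0.462
# progress=0.5 -> lmda=0.987
# progress=1.0 -> lmda=1.000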
Example #4
    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        ema_alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        output_dict = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x.detach(), label_x)[0].item(),
            'loss_u': loss_u.item(),
            'lr': self.optim.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict
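
ema_model_update is not shown in these snippets. A plausible sketch, assuming the teacher's parameters are kept as an exponential moving average of the student's (the usual Mean Teacher update); this is an assumption, not the verified implementation.

import torch


def ema_model_update(model, ema_model, alpha):
    """Teacher = alpha * teacher + (1 - alpha) * student, parameter-wise."""
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(alpha).add_(p, alpha=1 - alpha)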
Example #5
    def forward_backward(self, batch_u):
        input_u, label = self.parse_batch_train(batch_u)
        with torch.no_grad():
            # _, features = self.model(input_u, return_feature=True)
            if self.model.head is not None:
                features = self.netH(self.netB(input_u))
            else:
                features = self.netB(input_u)
            pred = self.obtain_label(features, self.center)

        outputs, features = self.model(input_u, return_feature=True)
        classifier_loss = CrossEntropyLabelSmooth(self.num_classes, 0)(outputs,
                                                                       pred)
        softmax_out = nn.Softmax(dim=1)(outputs)
        im_loss = torch.mean(Entropy(softmax_out))
        msoftmax = softmax_out.mean(dim=0)
        im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        loss = im_loss + self.cfg.TRAINER.SHOT.PAR * classifier_loss

        self.model_backward_and_update(loss)
        loss_summary = {
            'loss': loss.item(),
            'acc': compute_accuracy(outputs, label)[0].item()
        }

        return loss_summary
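
Entropy and the rest of the SHOT-style helpers are external to this snippet. A minimal sketch of Entropy consistent with how it is used here (per-sample entropy of a (B, C) softmax output, later averaged with torch.mean); the epsilon value is an assumption.

import torch


def Entropy(p, eps=1e-5):
    """Per-sample entropy of a (B, C) probability tensor."""
    return -(p * torch.log(p + eps)).sum(dim=1)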
Example #6
    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u * weight_u
        self.model_backward_and_update(loss)

        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x, label_x)[0].item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
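
sigmoid_rampup is the usual consistency-weight ramp-up from the Mean Teacher line of work; a sketch, assuming the standard exp(-5 * (1 - x)^2) form clamped to [0, 1]. Treat the exact constants as an assumption.

import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from 0 to 1 over rampup_length steps/epochs."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))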
Example #7
    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
        domain = [d[0].item() for d in domain]

        loss = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        loss += loss_cr
        self.model_backward_and_update(loss)

        output_dict = {
            'loss': loss.item(),
            'acc': acc,
            'loss_cr': loss_cr.item(),
            'lr': self.optim_F.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict
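
self.E(i, feat) in this example (and in examples #14-#16) selects a per-domain expert classifier by index. A minimal sketch of such a multi-expert head, assuming one linear + softmax classifier per source domain; the actual module may differ.

import torch.nn as nn


class Experts(nn.Module):
    """One linear + softmax classifier per source domain."""

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        # i: expert (source-domain) index, x: (B, fdim) features
        return self.softmax(self.linears[i](x))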
Example #8
    def forward_backward(self, batch_x):
        input, label = self.parse_batch_train(batch_x)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss': loss.item(),
            'acc': compute_accuracy(output, label)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
Example #9
    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        output_dict = {
            'loss': loss.item(),
            'acc': compute_accuracy(output.detach(), label)[0].item(),
            'lr': self.optim.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict
Example #10
    def forward_backward(self, batch_x):
        input, label = self.parse_batch_train(batch_x)
        output = self.model(input)
        loss = CrossEntropyLabelSmooth(self.num_classes,
                                       self.cfg.TRAINER.SHOT.SMOOTH)(output,
                                                                     label)
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss': loss.item(),
            'acc': compute_accuracy(output, label)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
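
CrossEntropyLabelSmooth (used here and in example #5) is not defined in these snippets. A sketch of a standard label-smoothing cross-entropy with the same constructor signature (num_classes, epsilon) and index targets; with epsilon=0 it reduces to plain cross-entropy. The details are assumptions, not the verified implementation.

import torch
import torch.nn as nn


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with smoothed targets: (1 - eps) * onehot + eps / C."""

    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        onehot = torch.zeros_like(log_probs).scatter_(
            1, targets.unsqueeze(1), 1
        )
        smoothed = (1 - self.epsilon) * onehot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()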
Example #11
    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

            # Evaluate pseudo labels' accuracy
            y_u_pred_stats = self.assess_y_pred_quality(
                label_u_pred[n_x:], label_u, mask_u[n_x:]
            )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction='none')
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(output_x, label_x)[0].item(),
            'loss_u': loss_u.item(),
            'y_u_pred_acc_raw': y_u_pred_stats['acc_raw'],
            'y_u_pred_acc_thre': y_u_pred_stats['acc_thre'],
            'y_u_pred_keep': y_u_pred_stats['keep_rate']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
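
assess_y_pred_quality is trainer-specific and not shown. A hedged sketch of what it plausibly computes, judging from the keys consumed above: pseudo-label accuracy over all unlabeled samples (acc_raw), accuracy over the confidence-kept subset (acc_thre), and the fraction kept (keep_rate).

import torch


def assess_y_pred_quality(y_pred, y_true, mask):
    """Sketch: quality statistics for thresholded pseudo labels."""
    correct = y_pred.eq(y_true).float()
    acc_raw = correct.mean()                                  # all samples
    acc_thre = (correct * mask).sum() / (mask.sum() + 1e-5)   # kept samples only
    keep_rate = mask.sum() / mask.numel()
    return {
        'acc_raw': acc_raw.item(),
        'acc_thre': acc_thre.item(),
        'keep_rate': keep_rate.item()
    }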
Example #12
    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda

        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(output_x, label_x)[0].item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
Example #13
    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

        if self.conf_thre:
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

        ema_alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        output_dict = {
            'loss_x': loss_x.item(),
            'acc_x': compute_accuracy(logit_x.detach(), label_x)[0].item(),
            'loss_u': loss_u.item(),
            'lr': self.optim.param_groups[0]['lr']
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return output_dict
Example #14
    def forward_backward(self, batch_x, batch_u):
        # Load data
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # x = data with small augmentations. x2 = data with large augmentations
        # They both correspond to the same datapoints. Same scheme for u and u2.

        # Generate pseudo label
        with torch.no_grad():
            # Unsupervised predictions
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Pseudolabel = weighted predictions
            u_filter = self.G(feat_u)  # (B, K)
            # (B): 1 if the most confident filter score >= conf_thre, 0 otherwise
            label_u_mask = (u_filter.max(1)[0] >= self.conf_thre).float()
            new_u_filter = torch.zeros(*u_filter.shape).to(self.device)
            for i, row in enumerate(u_filter):
                j_max = row.max(0)[1]
                new_u_filter[i, j_max] = 1
            u_filter = new_u_filter
            d_closest = self.d_closest(u_filter).max(0)[1]
            u_filter = u_filter.unsqueeze(2).expand(*pred_u.shape)
            # Zero out all non-chosen experts, then sum over the expert dim
            pred_fu = (pred_u * u_filter).sum(1)
            pseudo_label_u = pred_fu.max(1)[1]  # (B)
            pseudo_label_u = create_onehot(pseudo_label_u,
                                           self.num_classes).to(self.device)
        # Init losses
        loss_x = 0
        loss_cr = 0
        acc_x = 0
        loss_filter = 0
        acc_filter = 0

        # Supervised and unsupervised features
        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            expert_label_xi = pred_xi.detach()
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()

            x_filter = self.G(feat_xi)
            # Filter target: 1 for this expert's own domain, 0 otherwise
            filter_label = torch.zeros(len(domain_x)).to(self.device)
            filter_label[i] = 1
            filter_label = filter_label.unsqueeze(0).expand(*x_filter.shape)
            loss_filter += (-filter_label *
                            torch.log(x_filter + 1e-5)).sum(1).mean()
            acc_filter += compute_accuracy(x_filter.detach(),
                                           filter_label.max(1)[1])[0].item()

            # Consistency regularization - Mean must follow the leading expert
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1).mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain
        loss_filter /= self.n_domain
        acc_filter /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1).to(self.device)
        pred_u = pred_u.mean(1)
        if self.is_regressive:
            l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        else:
            l_u = ((pseudo_label_u - pred_u)**2).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_filter
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_filter': loss_filter.item(),
            'acc_filter': acc_filter,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item(),
            #'d_closest': d_closest.max(0)[1]
            'd_closest': d_closest.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
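
The per-row loop that builds new_u_filter above simply one-hot encodes the row-wise argmax of u_filter. An equivalent vectorized form (same (B, K) shape, assuming the same intent) is sketched below.

import torch.nn.functional as F


def onehot_top_expert(u_filter):
    """One-hot of the most confident expert per row; equivalent to the new_u_filter loop."""
    return F.one_hot(u_filter.argmax(dim=1),
                     num_classes=u_filter.size(1)).to(u_filter.dtype)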
Example #15
    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]
        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            pseudo_label_u = []
            for i, experts_label in zip(max_expert_idx, experts_max_idx):
                pseudo_label_u.append(experts_label[i])
            pseudo_label_u = torch.stack(pseudo_label_u, 0)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.dm.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'acc_x': acc_x,
            'loss_cr': loss_cr.item(),
            'loss_u': loss_u.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
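
create_onehot (used in examples #14 and #15) converts index labels into one-hot vectors. A minimal sketch consistent with how it is called; the caller moves the result to the right device afterwards, as both examples do.

import torch


def create_onehot(label, num_classes):
    """(B,) class indices -> (B, num_classes) one-hot floats (on CPU)."""
    onehot = torch.zeros(label.size(0), num_classes)
    onehot.scatter_(1, label.unsqueeze(1).long().cpu(), 1)
    return onehot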
Example #16
    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.dm.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the median prediction for each action
            pred_u = pred_u.median(1).values
            # Note: there is no leading expert; we just take the median of all predictions

        loss_x = 0
        loss_cr = 0
        if not self.is_regressive:
            acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        #         feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(feat_x, feat_x2, label_x,
                                                  domain_x):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            if self.is_regressive:
                loss_x += ((pred_xi - label_xi)**2).sum(1).mean()
            else:
                loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
                acc_x += compute_accuracy(pred_xi.detach(),
                                          label_xi.max(1)[1])[0].item()
            expert_label_xi = pred_xi.detach()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        if not self.is_regressive:
            acc_x /= self.n_domain

        # Unsupervised loss -> None yet
        # Pending: provide a means of establishing a lead expert so that loss can be calculated

        loss = 0
        loss += loss_x
        loss += loss_cr
        #         loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            'loss_x': loss_x.item(),
            'loss_cr': loss_cr.item(),
            #             'loss_u': loss_u.item()
        }
        if not self.is_regressive:
            loss_summary['acc_x'] = acc_x
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary