Example #1
0
    def forward(self, source, target):
        """Forward pass with an optional CORAL domain-adaptation loss.

        Arguments:
            source -- source-domain batch
            target -- target-domain batch (only consumed while training)

        Returns:
            (predictions, loss) -- classifier output for `source`, and the
            accumulated CORAL loss (0 when not training).
        """
        loss = 0
        source = self.sharedNet(source)

        # Fix: compare truthiness directly instead of `== True` (PEP 8).
        # The target branch (and the CORAL alignment term) only runs in
        # training mode; at eval time `loss` stays 0.
        if self.training:
            target = self.sharedNet(target)
            loss += CORAL(source, target)
        source = self.cls_fc(source)

        return source, loss
Example #2
0
    def train(model, config, epoch):
        """Run one epoch of CORAL-regularized training.

        Arguments:
            model -- network exposing `feature`, `class_classifier`, and
                     `class_classify`
            config -- dict with 'source_train_loader', 'target_train_loader',
                      and 'mmd_gamma' (CORAL loss weight)
            epoch -- current epoch index (unused here; kept for the caller's
                     signature)

        NOTE(review): relies on module-level `optimizer`, `criterion`, and
        `CORAL` being defined elsewhere in this file — confirm before reuse.
        """
        model.class_classifier.train()
        model.feature.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for i in range(1, num_iter):
            # Fix: `iterator.next()` is Python 2 only and raises
            # AttributeError in Python 3 — use the builtin next().
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the (typically shorter) target loader once a full
            # pass over it has been consumed.
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(
                ), label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()

            # Classification loss on labeled source data.
            preds = model.class_classify(data_source)
            loss_cls = criterion(preds, label_source)

            # CORAL loss aligns flattened source/target feature statistics.
            source = model.feature(data_source)
            source = source.view(source.size(0), -1)
            target = model.feature(data_target)
            target = target.view(target.size(0), -1)
            loss_coral = CORAL(source, target)

            loss = loss_cls + config['mmd_gamma'] * loss_coral
            if i % 50 == 0:
                print('loss_cls {}, loss_coral {}, gamma {}, total loss {}'.
                      format(loss_cls.item(), loss_coral.item(),
                             config['mmd_gamma'], loss.item()))
            loss.backward()
            optimizer.step()
Example #3
0
    def forward(self, source, target):
        """Classify `source`; while training, also compute the CORAL loss
        between source and target features.

        Returns:
            (clf, coral_loss) -- class scores for `source` and the CORAL
            term (0 outside of training).
        """
        # Shared feature extractor, flattened to the backbone's fc width.
        feats = self.sharedNet(source)
        feats = feats.view(feats.size(0), fc_layer[self.backbone])
        if self.backbone == 'alexnet':
            feats = self.fc(feats)

        coral_loss = 0
        if self.isTrain:
            # Run the target batch through the identical pipeline so both
            # feature sets live in the same space before alignment.
            tgt = self.sharedNet(target)
            tgt = tgt.view(tgt.size(0), fc_layer[self.backbone])
            if self.backbone == 'alexnet':
                tgt = self.fc(tgt)
            coral_loss = CORAL(feats, tgt)

        return self.cls_fc(feats), coral_loss
Example #4
0
 def adapt_loss(self, X, Y, adapt_loss):
     """Compute adaptation loss, currently we support mmd and coral

     Arguments:
         X {tensor} -- source matrix
         Y {tensor} -- target matrix
         adapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own loss

     Returns:
         [tensor] -- adaptation loss tensor
     """
     # Dispatch on the requested loss name; unrecognized names contribute
     # no adaptation term (0).
     if adapt_loss == 'mmd':
         return mmd.MMD_loss()(X, Y)
     if adapt_loss == 'coral':
         return CORAL(X, Y)
     return 0