コード例 #1
0
    def forward(self,
                f_s: torch.Tensor,
                f_t: torch.Tensor,
                w_s: Optional[torch.Tensor] = None,
                w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the domain-adversarial loss for source and target features.

        The source and target features are concatenated, passed through
        ``self.grl`` and then ``self.domain_discriminator``. Source rows are
        labelled 1 and target rows 0. As a side effect, the mean of the
        source and target ``binary_accuracy`` values is stored in
        ``self.domain_discriminator_accuracy``.

        Args:
            f_s: Source-domain features; the first dimension is the batch.
            f_t: Target-domain features; the first dimension is the batch.
                The source and target batch sizes may differ.
            w_s: Optional per-sample weights for the source rows. It is
                reshaped to the shape of the source discriminator output.
                Defaults to all ones.
            w_t: Optional per-sample weights for the target rows, handled
                like ``w_s``. Defaults to all ones.

        Returns:
            The average of the source and target values of ``self.bce``.
        """
        n_s, n_t = f_s.size(0), f_t.size(0)
        # NOTE(review): ``grl`` is presumably a gradient-reversal layer;
        # confirm against the class constructor.
        f = self.grl(torch.cat((f_s, f_t), dim=0))
        d = self.domain_discriminator(f)
        # Split by the real batch sizes: ``chunk(2)`` halves the rows and
        # would assign them to the wrong domain whenever n_s != n_t.
        d_s, d_t = d.split([n_s, n_t], dim=0)
        # Create the labels directly on the feature device (no host copy).
        d_label_s = torch.ones((n_s, 1), device=f_s.device)
        d_label_t = torch.zeros((n_t, 1), device=f_t.device)
        self.domain_discriminator_accuracy = 0.5 * (
            binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))

        if w_s is None:
            w_s = torch.ones_like(d_label_s)
        if w_t is None:
            w_t = torch.ones_like(d_label_t)
        loss_s = self.bce(d_s, d_label_s, w_s.view_as(d_s))
        loss_t = self.bce(d_t, d_label_t, w_t.view_as(d_t))
        return 0.5 * (loss_s + loss_t)
コード例 #2
0
 def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
     """Compute the weighted conditional domain-adversarial loss.

     Features and classifier outputs of both domains are concatenated
     (source first). The classifier outputs are turned into detached
     softmax probabilities and combined with the features by ``self.map``.
     The result goes through ``self.grl`` and ``self.domain_discriminator``.
     Source rows are labelled 1 and target rows 0. As a side effect, the
     ``binary_accuracy`` of the discriminator is stored in
     ``self.domain_discriminator_accuracy``.

     Args:
         g_s: Classifier outputs for the source batch.
         f_s: Features for the source batch.
         g_t: Classifier outputs for the target batch.
         f_t: Features for the target batch.

     Returns:
         The value of ``self.bce`` on the discriminator output, using
         per-sample weights ``1 + exp(-entropy(g))`` rescaled so that they
         sum to the combined batch size.
     """
     features = torch.cat((f_s, f_t), dim=0)
     # Detach the probabilities so they act as constants in the graph.
     probs = F.softmax(torch.cat((g_s, g_t), dim=0), dim=1).detach()
     conditioned = self.map(features, probs)
     d = self.domain_discriminator(self.grl(conditioned))

     source_labels = torch.ones((g_s.size(0), 1)).to(g_s.device)
     target_labels = torch.zeros((g_t.size(0), 1)).to(g_t.device)
     d_label = torch.cat((source_labels, target_labels))

     # Rescale the raw weights so that they sum to the number of samples.
     raw_weight = 1.0 + torch.exp(-entropy(probs))
     total = features.size(0)
     weight = raw_weight / torch.sum(raw_weight) * total

     self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
     return self.bce(d, d_label, weight.view_as(d))
コード例 #3
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    """Run one epoch of alternating classifier / discriminator training.

    Each of the ``args.iters_per_epoch`` iterations does two updates:

    1. The classifier is updated (discriminator frozen) on the source
       cross-entropy plus ``args.trade_off`` times a transfer loss. The
       transfer loss calls ``domain_adv`` with swapped domain names.
    2. The discriminator is updated (classifier frozen) on the detached
       features, calling ``domain_adv`` with the true domain names.

    Args:
        train_source_iter: Iterator yielding ``(inputs, labels)`` source
            batches.
        train_target_iter: Iterator yielding ``(inputs, _)`` target batches;
            the second element is ignored.
        model: Called on the concatenated batch; returns ``(y, f)``.
        domain_discri: Discriminator applied to the features ``f``.
        domain_adv: Called as ``domain_adv(d, 'source' | 'target')``.
        gl: Layer applied to ``f`` before the discriminator in step 1.
            NOTE(review): presumably a gradient-scaling layer; confirm.
        optimizer: Optimizer stepped after the classifier loss.
        lr_scheduler: Scheduler stepped together with ``optimizer``.
        optimizer_d: Optimizer stepped after the discriminator loss.
        lr_scheduler_d: Scheduler stepped together with ``optimizer_d``.
        epoch: Epoch index, used only in the progress prefix.
        args: Namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        cls_accs, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        # Per-domain batch sizes. The concatenated outputs are split by
        # these sizes rather than halved, so unequal batches stay aligned
        # with their labels.
        sizes = [x_s.size(0), x_t.size(0)]

        # measure data loading time
        data_time.update(time.time() - end)

        # Step 1: Train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        # Only the source predictions have labels.
        y_s, _ = y.split(sizes, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)

        # adversarial training to fool the discriminator: the domain names
        # are swapped on purpose
        d = domain_discri(gl(f))
        d_s, d_t = d.split(sizes, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') +
                               domain_adv(d_t, 'source'))

        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()

        # Step 2: Train the discriminator
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        # Detach the features so this loss does not reach the classifier.
        d = domain_discri(f.detach())
        d_s, d_t = d.split(sizes, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') +
                                    domain_adv(d_t, 'target'))

        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # Domain accuracy uses the step-2 discriminator outputs
        # (source = 1, target = 0).
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)