def forward(self, f_s: torch.Tensor, f_t: torch.Tensor, w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute the domain-adversarial (GRL-based) loss.

    Source and target features are concatenated, passed through the
    gradient-reversal layer (``self.grl``) and the domain discriminator,
    then scored with weighted BCE against hard domain labels
    (source = 1, target = 0).

    Args:
        f_s: source-domain features, first dim is the source batch size.
        f_t: target-domain features, first dim is the target batch size.
        w_s: optional per-example weights for the source BCE term
             (defaults to all-ones).
        w_t: optional per-example weights for the target BCE term
             (defaults to all-ones).

    Returns:
        Scalar loss ``0.5 * (bce_source + bce_target)``.

    Side effects:
        Updates ``self.domain_discriminator_accuracy`` with the mean
        discriminator accuracy over both domains.
    """
    # Single joint pass through GRL + discriminator for both domains.
    f = self.grl(torch.cat((f_s, f_t), dim=0))
    d = self.domain_discriminator(f)
    # NOTE(review): chunk(2) splits the batch evenly, so this assumes
    # f_s.size(0) == f_t.size(0); with unequal batches d_s/d_t would not
    # line up with the labels below — confirm callers always pass equal
    # batch sizes.
    d_s, d_t = d.chunk(2, dim=0)
    # Fix: allocate labels directly on the feature device instead of
    # building them on CPU and copying with .to(device) — avoids a
    # needless CPU allocation + host-to-device transfer every step.
    d_label_s = torch.ones((f_s.size(0), 1), device=f_s.device)
    d_label_t = torch.zeros((f_t.size(0), 1), device=f_t.device)
    self.domain_discriminator_accuracy = 0.5 * (
        binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))
    # Default to uniform per-example weights when none are supplied.
    if w_s is None:
        w_s = torch.ones_like(d_label_s)
    if w_t is None:
        w_t = torch.ones_like(d_label_t)
    return 0.5 * (self.bce(d_s, d_label_s, w_s.view_as(d_s)) +
                  self.bce(d_t, d_label_t, w_t.view_as(d_t)))
def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
    """Compute the conditional, entropy-weighted domain-adversarial loss.

    Features are conditioned on the (detached) classifier predictions via
    ``self.map`` before being passed through the gradient-reversal layer
    and the domain discriminator. Each example's BCE contribution is
    re-weighted by ``1 + exp(-entropy)`` of its prediction, so confident
    examples count more; weights are renormalized to sum to the batch size.

    Args:
        g_s: source classifier outputs (logits; softmaxed below).
        g_t: target classifier outputs (logits; softmaxed below).
        f_s: source-domain features.
        f_t: target-domain features.

    Returns:
        Scalar weighted BCE loss over the joint batch.

    Side effects:
        Updates ``self.domain_discriminator_accuracy``.
    """
    f = torch.cat((f_s, f_t), dim=0)
    g = torch.cat((g_s, g_t), dim=0)
    # Detach: predictions only condition the discriminator input; no
    # gradient flows back into the classifier head through g.
    g = F.softmax(g, dim=1).detach()
    h = self.grl(self.map(f, g))
    d = self.domain_discriminator(h)
    # Fix: allocate domain labels directly on the right device instead of
    # CPU tensors followed by .to() — avoids an extra allocation and
    # host-to-device copy every step.
    d_label = torch.cat((
        torch.ones((g_s.size(0), 1), device=g_s.device),
        torch.zeros((g_t.size(0), 1), device=g_t.device),
    ))
    # Entropy-based importance: low-entropy (confident) predictions get
    # weight up to 2, high-entropy ones approach 1.
    weight = 1.0 + torch.exp(-entropy(g))
    batch_size = f.size(0)
    # Renormalize so the weights sum to the batch size (keeps the loss
    # scale comparable to unweighted BCE).
    weight = weight / torch.sum(weight) * batch_size
    self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
    return self.bce(d, d_label, weight.view_as(d))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD, lr_scheduler: LambdaLR,
          optimizer_d: SGD, lr_scheduler_d: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of GAN-style adversarial domain-adaptation training.

    Each iteration alternates two phases on the same mini-batch:
      1. update the classifier/backbone (discriminator frozen) with the
         source classification loss plus a transfer loss that rewards
         fooling the domain discriminator (flipped domain labels);
      2. update the domain discriminator (backbone frozen) on the
         detached features to distinguish source from target.

    Args:
        train_source_iter: endless iterator over labeled source batches.
        train_target_iter: endless iterator over (unlabeled) target batches.
        model: classifier whose forward returns (logits, features).
        domain_discri: binary source-vs-target discriminator.
        domain_adv: adversarial loss; called as domain_adv(scores, 'source'|'target').
        gl: layer applied to features before the discriminator in step 1 —
            presumably a (warm-start) gradient layer; TODO confirm.
        optimizer / lr_scheduler: optimizer and per-step LR schedule for `model`.
        optimizer_d / lr_scheduler_d: same for the discriminator.
        epoch: current epoch index (display only).
        args: needs `iters_per_epoch`, `trade_off`, `print_freq`.
    """
    # Running meters for timing, losses, and accuracies shown by `progress`.
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer,
        losses_discriminator, cls_accs, domain_accs
    ], prefix="Epoch: [{}]".format(epoch))
    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)  # target labels unused (unsupervised)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # Step 1: Train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        # One joint forward pass over both domains; logits are split back
        # into the source half (supervised loss) and target half (unused).
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)
        # adversarial training to fool the discriminator:
        # labels are flipped (source scored as 'target' and vice versa) so
        # minimizing this loss pushes features toward domain confusion.
        d = domain_discri(gl(f))
        d_s, d_t = d.chunk(2, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') + domain_adv(d_t, 'source'))
        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()
        # Step 2: Train the discriminator
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        # Reuses `f` from step 1 (computed before the model update);
        # detach() blocks gradients from flowing back into the backbone.
        # True labels this time: the discriminator learns to separate domains.
        d = domain_discri(f.detach())
        d_s, d_t = d.chunk(2, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') + domain_adv(d_t, 'target'))
        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()
        # Bookkeeping: all meters use the source batch size as the count.
        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # Discriminator accuracy against the true domain labels.
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)