def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
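
# A minimal sketch of how the DANN train() loop above might be wired up. The
# import paths and constructor signatures below assume the Transfer Learning
# Library (dalib) layout; treat them as assumptions, not the definitive API.
def build_dann_example(backbone: nn.Module, num_classes: int, args: argparse.Namespace):
    from dalib.modules.domain_discriminator import DomainDiscriminator
    from dalib.adaptation.dann import DomainAdversarialLoss

    classifier = ImageClassifier(backbone, num_classes).to(device)
    # the discriminator's input size must match the classifier's feature dimension
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)
    return classifier, domain_adv, optimizer, lr_scheduler
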
def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        if lr_scheduler is not None:
            lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        # compute output
        y_s, f_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        loss = cls_loss

        cls_acc = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
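
# `ForeverDataIterator` is assumed to be a thin wrapper that restarts its
# DataLoader whenever it is exhausted, so the loops above can call next()
# indefinitely. A minimal sketch of that behaviour (the library's actual
# implementation may differ):
class ForeverDataIteratorSketch:
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.iter = iter(data_loader)

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # restart from the beginning once the underlying loader runs out
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        return data

    def __len__(self):
        return len(self.data_loader)
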
def train_ssl(inferred_dataloader: DataLoader, model: ImageClassifier, optimizer: SGD,
              lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(inferred_dataloader),
        [batch_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (x, labels) in enumerate(inferred_dataloader):
        lr_scheduler.step()

        x = x.to(device)
        labels = labels.to(device)

        # compute output
        output, _ = model(x)
        loss = F.cross_entropy(output, labels)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        losses.update(loss.item(), x.size(0))
        top1.update(acc1[0], x.size(0))
        top5.update(acc5[0], x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
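
# train_ssl() expects a loader over (image, inferred label) pairs. One
# hypothetical way to build such a loader is to pseudo-label unlabeled target
# images with the current model and keep only confident predictions; every
# name and the threshold below are illustrative, not the library's API.
def build_pseudo_label_loader_example(unlabeled_loader: DataLoader, model: ImageClassifier,
                                      threshold: float = 0.9, batch_size: int = 32) -> DataLoader:
    images, pseudo_labels = [], []
    model.eval()
    with torch.no_grad():
        for x, _ in unlabeled_loader:
            output, _ = model(x.to(device))
            confidence, prediction = F.softmax(output, dim=1).max(dim=1)
            keep = (confidence > threshold).cpu()  # keep only confident samples
            images.append(x[keep])
            pseudo_labels.append(prediction.cpu()[keep])
    dataset = torch.utils.data.TensorDataset(torch.cat(images), torch.cat(pseudo_labels))
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
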
def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:
    batch_time = AverageMeter('Time', ':6.3f')
    top1_1 = AverageMeter('Acc_1', ':6.2f')
    top1_2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1_1, top1_2],
        prefix='Test: ')

    # switch to evaluate mode
    G.eval()
    F1.eval()
    F2.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            g = G(images)
            y1, y2 = F1(g), F2(g)

            # measure accuracy and record loss
            acc1, = accuracy(y1, target)
            acc2, = accuracy(y2, target)
            top1_1.update(acc1[0], images.size(0))
            top1_2.update(acc2[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'.format(
            top1_1=top1_1, top1_2=top1_2))

    return top1_1.avg, top1_2.avg
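
# A hedged usage sketch for the two-head validate() above: MCD-style training
# commonly reports the better of the two classifier heads (names illustrative).
def evaluate_best_head_example(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
                               F2: ImageClassifierHead, args: argparse.Namespace) -> float:
    acc1, acc2 = validate(val_loader, G, F1, F2, args)
    # keep whichever head generalizes better on the validation set
    return max(acc1, acc2)
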
def validate(val_loader: DataLoader, model: ImageClassifier, args: argparse.Namespace) -> float:
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
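
# `AverageMeter` follows the PyTorch ImageNet-example convention: it tracks the
# latest value plus a running average weighted by batch size, and renders both
# with the format string passed at construction. A minimal sketch consistent
# with the calls above (the actual class may differ in details):
class AverageMeterSketch:
    def __init__(self, name: str, fmt: str = ':f'):
        self.name = name
        self.fmt = fmt
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n: int = 1):
        # weight each update by the number of samples n it summarizes
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)
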
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    classifier.train()
    mdd.train()

    criterion = nn.CrossEntropyLoss().to(device)

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()
        optimizer.zero_grad()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = classifier(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on source domain
        cls_loss = criterion(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        # advance the classifier's internal schedule (e.g. the gradient reverse
        # layer coefficient) by one iteration
        classifier.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
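
# A hedged sketch of constructing the MDD pieces used above, assuming dalib's
# mdd module; the import paths, the `margin` argument, and get_parameters()
# are assumptions to check against your version of the library.
def build_mdd_example(backbone: nn.Module, num_classes: int, args: argparse.Namespace):
    from dalib.adaptation.mdd import ImageClassifier as MDDImageClassifier
    from dalib.adaptation.mdd import MarginDisparityDiscrepancy

    # this classifier variant owns both the main and adversarial heads plus a
    # gradient reverse layer, which is what classifier.step() advances in train()
    classifier = MDDImageClassifier(backbone, num_classes).to(device)
    mdd = MarginDisparityDiscrepancy(margin=args.margin).to(device)
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.0002, decay_rate=0.75)
    return classifier, mdd, optimizer, lr_scheduler
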
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy,
          optimizer: SGD, lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    jmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = jmmd_loss(
            (f_s, F.softmax(y_s, dim=1)),
            (f_t, F.softmax(y_t, dim=1))
        )
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
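
# A hedged sketch of constructing the JMMD loss consumed above. It assumes
# dalib's JointMultipleKernelMaximumMeanDiscrepancy takes one kernel family per
# joined input, here (features f, softmax predictions y); the kernel choices
# and import paths are assumptions taken from the usual JAN-style setup.
def build_jmmd_example():
    from dalib.modules.kernels import GaussianKernel
    from dalib.adaptation.jan import JointMultipleKernelMaximumMeanDiscrepancy

    return JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],      # kernels over features f
            (GaussianKernel(sigma=0.92, track_running_stats=False),),  # kernel over predictions y
        ),
        linear=False,
    ).to(device)
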
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead,
          optimizer_g: SGD, optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # Step A: train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)

        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
            0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B: train the classifiers to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)

        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
            0.01 * (entropy(y1_t) + entropy(y2_t)) - \
            classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss.backward()
        optimizer_f.step()

        # Step C: train the generator to minimize discrepancy
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            g = G(x)
            y_1 = F1(g)
            y_2 = F2(g)
            y1_s, y1_t = y_1.chunk(2, dim=0)
            y2_s, y2_t = y_2.chunk(2, dim=0)
            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
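
# `entropy` and `classifier_discrepancy` come from the MCD formulation; a
# plausible sketch consistent with how they are called above (the library's
# exact definitions may differ):
def classifier_discrepancy_sketch(predictions1: torch.Tensor,
                                  predictions2: torch.Tensor) -> torch.Tensor:
    # L1 discrepancy: mean absolute difference between the two heads' softmax outputs
    return torch.mean(torch.abs(predictions1 - predictions2))


def entropy_sketch(predictions: torch.Tensor) -> torch.Tensor:
    # entropy-style regularizer on the batch-averaged prediction
    return -torch.mean(torch.log(torch.mean(predictions, dim=0) + 1e-6))
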