def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of training with a classification loss plus a domain-adversarial loss.

    Each of the ``args.iters_per_epoch`` iterations draws one source batch and
    one target batch, forwards their concatenation through ``model`` and
    optimizes ``cls_loss + args.trade_off * transfer_loss``.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source batches.
        train_target_iter: iterator yielding ``(inputs, _)`` target batches.
        model: network returning ``(predictions, features)`` for a batch.
        domain_adv: loss module called on ``(source_features, target_features)``;
            its ``domain_discriminator_accuracy`` attribute is read for logging.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped at the start of every iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # Record data time only after both batches are fetched and moved;
        # recording it earlier would exclude the loading itself.
        data_time.update(time.time() - end)

        # One forward pass over source and target, then split the halves.
        # NOTE(review): chunk(2) assumes both batches have the same size - confirm.
        y, f = model(torch.cat((x_s, x_t), dim=0))
        y_s, _ = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        # Read after the forward call above, as in the original ordering.
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        batch_size = x_s.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)
        domain_accs.update(domain_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_sheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of supervised training on source batches only.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` batches.
        model: network returning a ``(predictions, features)`` pair; only the
            predictions are used here.
        optimizer: optimizer stepped once per iteration.
        lr_sheduler: optional scheduler; when not ``None`` it is stepped at
            the start of every iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        if lr_sheduler is not None:
            lr_sheduler.step()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        # Record data time only after the batch is fetched and moved;
        # recording it earlier would exclude the loading itself.
        data_time.update(time.time() - end)

        # compute output (the feature output is not needed here)
        y_s, _ = model(x_s)
        loss = F.cross_entropy(y_s, labels_s)

        cls_acc = accuracy(y_s, labels_s)[0]
        batch_size = x_s.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train_ssl(inferred_dataloader: DataLoader, model: ImageClassifier, optimizer: SGD,
              lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one pass over ``inferred_dataloader`` with cross-entropy.

    Args:
        inferred_dataloader: loader yielding ``(inputs, labels)`` batches.
        model: network returning a pair whose first element is the prediction.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: scheduler stepped at the start of every batch.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``print_freq``.
    """
    timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    top1_meter = AverageMeter('Acc@1', ':6.2f')
    top5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(inferred_dataloader),
        [timer, loss_meter, top1_meter, top5_meter],
        prefix="Epoch: [{}]".format(epoch))

    # put the network in training mode
    model.train()

    tick = time.time()
    for step, batch in enumerate(inferred_dataloader):
        lr_scheduler.step()

        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward pass; the second model output is not used
        logits, _ = model(inputs)
        loss = F.cross_entropy(logits, targets)

        # record loss and top-1 / top-5 accuracy, weighted by batch size
        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
        n = inputs.size(0)
        loss_meter.update(loss.item(), n)
        top1_meter.update(acc1[0], n)
        top5_meter.update(acc5[0], n)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # per-iteration wall-clock time
        now = time.time()
        timer.update(now - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def main(args: argparse.Namespace):
    """Train a classifier with a domain-adversarial loss and report accuracy.

    Builds source/target training loaders and a target validation loader,
    trains for ``args.epochs`` epochs, keeps a copy of the weights with the
    best validation accuracy, and evaluates that copy on the test loader.
    The test loader is a separate ``evaluate=True`` split only when
    ``args.data == 'DomainNet'``; otherwise it is the validation loader.

    Args:
        args: namespace providing ``seed``, ``data``, ``root``, ``source``,
            ``target``, ``batch_size``, ``workers``, ``arch``, ``lr``,
            ``momentum``, ``weight_decay`` and ``epochs``, plus whatever
            ``train`` and ``validate`` read.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even when a seed requested
    # deterministic mode above - confirm this is intended.
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True, download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    # Start from a snapshot of the initial weights so best_model is always
    # bound, even if no epoch runs or none beats best_acc1.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and keep a detached copy of its weights
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(best_model)
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: nn.Module, jmmd_loss, optimizer: SGD, lr_sheduler: StepwiseLR,
          pseudo_labels: PseudoLabeling, epoch: int, args: argparse.Namespace):
    """Run one epoch of training with classification, transfer and joint losses.

    The total loss is ``cls_loss + args.lambda1 * transfer_loss``. From epoch
    ``args.start_epoch`` onward, ``args.lambda2 * joint_loss`` is added, where
    ``joint_loss`` averages ``CE_disagreement`` over ``args.loss_sample_num``
    pairs of head forward passes on the same backbone output.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels, indices)`` source batches.
        train_target_iter: iterator yielding ``(inputs, labels, indices)`` target
            batches; target labels are used only for the logged accuracy.
        model: module exposing ``backbone``, ``backbone_forward`` and ``head_forward``.
        jmmd_loss: callable taking two ``(features, softmax_predictions)`` tuples.
        optimizer: optimizer stepped once per iteration.
        lr_sheduler: scheduler stepped at the start of every iteration.
        pseudo_labels: provides ``get_hard_pseudo_label`` and ``get_weight`` by sample index.
        epoch: current epoch index.
        args: namespace providing ``iters_per_epoch``, ``freeze_bn``, ``lambda1``,
            ``lambda2``, ``start_epoch``, ``temperature``, ``loss_sample_num``,
            ``gradclip`` and ``print_freq``.
    """
    losses = AverageMeter('Loss', ':3.2f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.2f')
    joint_losses = AverageMeter('Joint Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [losses, cls_losses, trans_losses, joint_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    if args.freeze_bn:
        # Put the backbone's BatchNorm2d layers back in eval mode after
        # model.train() switched everything to training mode.
        for m in model.backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    use_joint_loss = epoch >= args.start_epoch

    for i in range(args.iters_per_epoch):
        lr_sheduler.step()

        x_s, labels_s, _ = next(train_source_iter)
        x_t, labels_t, index_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        h = model.backbone_forward(x)
        y, f = model.head_forward(h)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = jmmd_loss((f_s, F.softmax(y_s, dim=1)), (f_t, F.softmax(y_t, dim=1)))
        loss = cls_loss + transfer_loss * args.lambda1

        # Value logged for the joint loss; stays 0.0 while the joint loss is disabled.
        joint_loss_value = 0.0
        if use_joint_loss:
            with torch.no_grad():
                pseudo_labels_t = pseudo_labels.get_hard_pseudo_label(index_t)
                weights_t = pseudo_labels.get_weight(index_t)
                weights_t = F.softmax(weights_t / args.temperature, dim=0)

            joint_loss = 0.0
            for _ in range(args.loss_sample_num):
                # Two head passes on the same backbone output.
                # NOTE(review): presumably these differ through dropout in the head - confirm.
                y_1, _ = model.head_forward(h)
                y1_s, y1_t = y_1.chunk(2, dim=0)
                y_2, _ = model.head_forward(h)
                y2_s, y2_t = y_2.chunk(2, dim=0)
                joint_loss += CE_disagreement(y1_s, y1_t, y2_s, y2_t, labels_s, pseudo_labels_t, weights_t)
            joint_loss = joint_loss / args.loss_sample_num
            loss += joint_loss * args.lambda2
            joint_loss_value = joint_loss.item() if torch.is_tensor(joint_loss) else float(joint_loss)

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        joint_losses.update(joint_loss_value, x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if args.gradclip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip)
        optimizer.step()

        if i % args.print_freq == 0:
            progress.display(i)
def main(args: argparse.Namespace):
    """Train an ``MCdropClassifier`` with a JMMD loss and pseudo labels.

    Each epoch first refreshes the pseudo labels via
    ``psudo_label_update_and_weight_calculate`` and ``pseudo_labels.copy_history()``,
    then trains and validates. A copy of the weights with the best validation
    accuracy is kept and evaluated on the test loader at the end. The test
    loader is a separate ``evaluate=True`` split only when
    ``args.data == 'DomainNet'``; otherwise it is the validation loader.

    Note:
        ``args.root`` is overwritten in place with ``os.path.join(args.root, args.data)``.

    Args:
        args: namespace with the data, model, optimizer and schedule options
            read below, plus whatever ``train`` and ``validate`` read.
    """
    setup_seed(args.seed)
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    args.root = os.path.join(args.root, args.data)
    train_source_dataset = dataset(root=args.root, task=args.source, download=False, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=False, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=False, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False,
                            num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True, download=False,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.val_batch_size, shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # pseudo labels
    pseudo_labels = PseudoLabeling(len(train_target_dataset), train_target_dataset.num_classes,
                                   args.prob_ema_gamma, args.tc_ema_gamma, args.threshold, device)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    model = MCdropClassifier(backbone=backbone, num_classes=num_classes,
                             bottleneck_dim=args.bottleneck_dim,
                             classifier_width=args.classifier_width,
                             dropout_rate=args.dropout_rate,
                             dropout_type=args.dropout_type).to(device)

    # define loss function
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=([GaussianKernel(alpha=2**k) for k in range(-3, 2)],
                 (GaussianKernel(sigma=0.92, track_running_stats=False), )),
        linear=args.linear,
        thetas=None).to(device)

    # define optimizer
    parameters = model.get_parameters(args.freeze_backbone, args.backbone_decay)
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_sheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=args.gamma, decay_rate=args.decay_rate)

    # start training
    best_acc1 = 0.
    # Start from a snapshot of the initial weights so best_model is always
    # bound, even if no epoch runs or none beats best_acc1.
    best_model = copy.deepcopy(model.state_dict())
    for epoch in range(args.epochs):
        psudo_label_update_and_weight_calculate(train_source_loader, val_loader, model,
                                                pseudo_labels, args, epoch)
        pseudo_labels.copy_history()

        # train for one epoch
        train(train_source_iter, train_target_iter, model, jmmd_loss, optimizer,
              lr_sheduler, pseudo_labels, epoch, args)

        acc1 = validate(val_loader, model, args)

        # remember best acc@1 and keep a detached copy of its weights
        if acc1 > best_acc1:
            best_model = copy.deepcopy(model.state_dict())
            best_acc1 = acc1
            print('find best!')
        print("current best = {:3.3f}".format(best_acc1))

    print("best_acc1 = {:3.3f}".format(best_acc1))

    # evaluate on test set
    model.load_state_dict(best_model)
    acc1 = validate(test_loader, model, args)
    print("test_acc1 = {:3.3f}".format(acc1))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of training with cross-entropy plus a margin disparity discrepancy loss.

    Each iteration forwards the concatenated source and target batch through
    ``classifier``, which returns ``(outputs, outputs_adv)``, and optimizes
    ``cls_loss + args.trade_off * transfer_loss``. ``classifier.step()`` is
    called once per iteration, before the backward pass.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source batches.
        train_target_iter: iterator yielding ``(inputs, labels)`` target batches;
            target labels are used only for the logged accuracy.
        classifier: network returning ``(outputs, outputs_adv)`` and exposing ``step()``.
        mdd: loss module called as ``mdd(y_s, y_s_adv, y_t, y_t_adv)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped at the start of every iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    classifier.train()
    mdd.train()

    criterion = nn.CrossEntropyLoss().to(device)

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # Record data time only after both batches are fetched and moved;
        # recording it earlier would exclude the loading itself.
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = classifier(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on source domain
        cls_loss = criterion(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        # NOTE(review): presumably advances an internal per-iteration counter
        # of the classifier - confirm against ImageClassifier.step.
        classifier.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step (gradients were zeroed at the top of the loop)
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def main(args: argparse.Namespace):
    """Train a classifier with a joint MMD loss and report accuracy.

    Builds source/target training loaders and a target validation loader,
    trains for ``args.epochs`` epochs, keeps a copy of the weights with the
    best validation accuracy, and evaluates that copy on the test loader.
    The test loader is a separate ``evaluate=True`` split only when
    ``args.data == 'DomainNet'``; otherwise it is the validation loader.
    When ``args.adversarial`` is set, one ``Theta`` module per kernel input
    is created and its parameters are added to the optimizer with
    ``lr_mult`` 0.1.

    Args:
        args: namespace providing ``seed``, ``data``, ``root``, ``source``,
            ``target``, ``batch_size``, ``workers``, ``arch``, ``adversarial``,
            ``linear``, ``lr``, ``momentum``, ``wd`` and ``epochs``, plus
            whatever ``train`` and ``validate`` read.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even when a seed requested
    # deterministic mode above - confirm this is intended.
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True, download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    classifier = ImageClassifier(backbone, num_classes).to(device)

    # define loss function
    if args.adversarial:
        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]
    else:
        thetas = None
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=([GaussianKernel(alpha=2**k) for k in range(-3, 2)],
                 (GaussianKernel(sigma=0.92, track_running_stats=False), )),
        linear=args.linear,
        thetas=thetas).to(device)

    parameters = classifier.get_parameters()
    if thetas is not None:
        parameters += [{"params": theta.parameters(), 'lr_mult': 0.1} for theta in thetas]

    # define optimizer
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_sheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.0003, decay_rate=0.75)

    # start training
    best_acc1 = 0.
    # Start from a snapshot of the initial weights so best_model is always
    # bound, even if no epoch runs or none beats best_acc1.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer,
              lr_sheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and keep a detached copy of its weights
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(best_model)
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy,
          optimizer: SGD, lr_sheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of training with cross-entropy plus a joint MMD transfer loss.

    Each iteration forwards the concatenated source and target batch through
    ``model`` and optimizes ``cls_loss + args.trade_off * transfer_loss``.
    The transfer loss is computed on ``(features, softmax(predictions))``
    tuples for the source and target halves.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source batches.
        train_target_iter: iterator yielding ``(inputs, labels)`` target batches;
            target labels are used only for the logged accuracy.
        model: network returning ``(predictions, features)`` for a batch.
        jmmd_loss: loss module taking the two tuples described above.
        optimizer: optimizer stepped once per iteration.
        lr_sheduler: scheduler stepped at the start of every iteration.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    jmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_sheduler.step()

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # Record data time only after both batches are fetched and moved;
        # recording it earlier would exclude the loading itself.
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = jmmd_loss(
            (f_s, F.softmax(y_s, dim=1)),
            (f_t, F.softmax(y_t, dim=1)))
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def main(args: argparse.Namespace):
    """Train a classifier with a domain-adversarial loss, then plot its features with t-SNE.

    After training, the weights with the best validation accuracy are
    restored. Features of the source and target validation sets are then
    extracted, embedded in 2-D with t-SNE, and saved as a scatter plot
    coloured by domain in ``dann_<source>2<target>.pdf``. The best weights
    are also written to ``best_model.pth.tar`` each time they improve.

    Args:
        args: namespace providing ``seed``, ``data``, ``root``, ``source``,
            ``target``, ``batch_size``, ``workers``, ``arch``, ``lr``,
            ``momentum``, ``weight_decay`` and ``epochs``, plus whatever
            ``train`` and ``validate`` read.
    """
    # Local import keeps this block self-contained; the module may not import copy.
    import copy

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even when a seed requested
    # deterministic mode above - confirm this is intended.
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    # state_dict() returns references to the live tensors, which the optimizer
    # updates in place; deep-copy so the snapshot really is the best model.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
            torch.save(best_model, 'best_model.pth.tar')
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # visualize the results using T-SNE
    classifier.load_state_dict(best_model)
    classifier.eval()

    features, labels = [], []
    source_val_dataset = dataset(root=args.root, task=args.source, download=True, transform=val_transform)
    source_val_loader = DataLoader(source_val_dataset, batch_size=args.batch_size, shuffle=False,
                                   num_workers=args.workers)

    with torch.no_grad():
        # Source first, then target: this order must match `domains` below.
        for loader in [source_val_loader, val_loader]:
            for images, target in loader:
                images = images.to(device)

                # compute output
                _, f = classifier(images)
                features.extend(f.cpu().numpy().tolist())
                # Store plain Python values: device tensors in this list
                # would break np.array(labels) when running on CUDA.
                labels.extend(target.tolist())

    # 1 marks source samples, 0 marks target samples.
    domains = np.concatenate((np.ones(len(source_val_dataset)), np.zeros(len(val_dataset))))
    features, labels = np.array(features), np.array(labels)
    print("source:", len(source_val_dataset), "target:", len(val_dataset))

    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)
    plt.figure(figsize=(10, 10))
    plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap(["r", "b"]), s=2)
    plt.savefig('{}_{}2{}.pdf'.format("dann", args.source, args.target))