def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:
    """Evaluate the feature extractor ``G`` with both classifier heads on ``val_loader``.

    The accuracy of each head (the single value returned by ``accuracy``,
    presumably top-1) is tracked in a separate meter. When
    ``args.per_class_eval`` is set, a confusion matrix of ``F1``'s
    predictions is accumulated and printed using
    ``val_loader.dataset.classes``. Batches are moved to the module-level
    ``device``.

    Args:
        val_loader: loader yielding ``(images, target)`` batches.
        G: feature extractor whose output feeds both heads.
        F1: first classifier head.
        F2: second classifier head.
        args: reads ``per_class_eval`` and ``print_freq``.

    Returns:
        ``(average accuracy of F1, average accuracy of F2)`` over the loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1_1 = AverageMeter('Acc_1', ':6.2f')
    top1_2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1_1, top1_2],
        prefix='Test: ')

    # switch to evaluate mode
    G.eval()
    F1.eval()
    F2.eval()

    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output of both heads on shared features
            g = G(images)
            y1, y2 = F1(g), F2(g)

            # measure accuracy and record it
            acc1, = accuracy(y1, target)
            acc2, = accuracy(y2, target)
            # explicit None check: do not depend on ConfusionMatrix truthiness
            if confmat is not None:
                confmat.update(target, y1.argmax(1))
            top1_1.update(acc1.item(), images.size(0))
            top1_2.update(acc2.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'
              .format(top1_1=top1_1, top1_2=top1_2))
        if confmat is not None:
            print(confmat.format(classes))

    return top1_1.avg, top1_2.avg
def main(args: argparse.Namespace):
    """Train, test or analyse an MCD model (feature extractor ``G`` + heads ``F1``/``F2``).

    ``args.phase`` selects the mode:
      * ``'train'``: train for ``args.epochs`` epochs, save a 'latest' checkpoint
        each epoch and copy it to 'best' whenever the better of the two heads'
        validation accuracy improves; finally reload 'best' and evaluate it on
        the test loader.
      * ``'test'``: load the 'best' checkpoint and evaluate it on the test loader.
      * ``'analysis'``: load the 'best' checkpoint, save a t-SNE plot of source
        vs. target features and print the A-distance.

    The test loader is a separate ``split='test'`` dataset for DomainNet and the
    validation loader otherwise. The logger is closed on every exit path,
    including the early-returning phases and exceptions.
    """
    logger = CompleteLogger(args.log, args.phase)
    try:
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            # cuDNN benchmark mode chooses algorithms by timing, which can
            # differ between runs; keep it off when reproducibility is asked for.
            cudnn.benchmark = False
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
        else:
            cudnn.benchmark = True

        # Data loading code
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        if args.center_crop:
            train_transform = T.Compose([
                ResizeImage(256),
                T.CenterCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize
            ])
        else:
            train_transform = T.Compose([
                ResizeImage(256),
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize
            ])
        val_transform = T.Compose(
            [ResizeImage(256), T.CenterCrop(224), T.ToTensor(), normalize])

        dataset = datasets.__dict__[args.data]
        train_source_dataset = dataset(root=args.root, task=args.source,
                                       download=True, transform=train_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_dataset = dataset(root=args.root, task=args.target,
                                       download=True, transform=train_transform)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_dataset = dataset(root=args.root, task=args.validation,
                              download=True, transform=val_transform)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=False, num_workers=args.workers)
        if args.data == 'DomainNet':
            test_dataset = dataset(root=args.root, task=args.target,
                                   split='test', download=True,
                                   transform=val_transform)
            test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                     shuffle=False, num_workers=args.workers)
        else:
            test_loader = val_loader

        train_source_iter = ForeverDataIterator(train_source_loader)
        train_target_iter = ForeverDataIterator(train_target_loader)

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        G = models.__dict__[args.arch](pretrained=True).to(device)  # feature extractor
        num_classes = train_source_dataset.num_classes
        # two image classifier heads
        F1 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim).to(device)
        F2 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim).to(device)

        # define optimizer
        # the learning rate is fixed according to origin paper
        optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
        optimizer_f = SGD([
            {"params": F1.parameters()},
            {"params": F2.parameters()},
        ], momentum=0.9, lr=args.lr, weight_decay=0.0005)

        # resume from the best checkpoint
        if args.phase != 'train':
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])

        # analysis the model
        if args.phase == 'analysis':
            # extract features from both domains
            feature_extractor = G.to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(train_target_loader,
                                             feature_extractor, device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, G, F1, F2, args)
            print(acc1)
            return

        # start training
        best_acc1 = 0.
        best_results = None
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, train_target_iter, G, F1, F2,
                  optimizer_g, optimizer_f, epoch, args)

            # evaluate on validation set
            results = validate(val_loader, G, F1, F2, args)

            # remember best acc@1 and save checkpoint
            torch.save(
                {
                    'G': G.state_dict(),
                    'F1': F1.state_dict(),
                    'F2': F2.state_dict()
                }, logger.get_checkpoint_path('latest'))
            # The first epoch always seeds 'best', so the checkpoint reloaded
            # below exists even if no epoch scores above 0.
            if best_results is None or max(results) > best_acc1:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
                best_acc1 = max(results)
                best_results = results

        print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

        # evaluate on test set
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        G.load_state_dict(checkpoint['G'])
        F1.load_state_dict(checkpoint['F1'])
        F2.load_state_dict(checkpoint['F2'])
        results = validate(test_loader, G, F1, F2, args)
        print("test_acc1 = {:3.1f}".format(max(results)))
    finally:
        logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead,
          optimizer_g: SGD, optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of MCD training.

    Each iteration concatenates a source batch and a target batch and performs:
      A. update ``G``, ``F1`` and ``F2`` on source cross-entropy of both heads
         plus ``0.01 * (entropy(y1_t) + entropy(y2_t))`` of target softmax outputs;
      B. update only the heads (``optimizer_f``) on the same objective minus
         ``args.trade_off * classifier_discrepancy(y1_t, y2_t)``, i.e. the heads
         are pushed to disagree on target samples;
      C. ``args.num_k`` times, update only ``G`` to minimise that discrepancy.

    Target labels are used only for the reported target accuracy, never in a loss.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    def _forward_heads(inputs):
        """Run G and both heads; return (y1_s, y1_t, y2_s, y2_t) with target outputs softmaxed."""
        features = G(inputs)
        # chunk(2) splits the concatenated batch back into source/target halves;
        # assumes x_s and x_t have the same batch size — TODO confirm with loaders
        y1_s, y1_t = F1(features).chunk(2, dim=0)
        y2_s, y2_t = F2(features).chunk(2, dim=0)
        return y1_s, F.softmax(y1_t, dim=1), y2_s, F.softmax(y2_t, dim=1)

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A: train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        y1_s, y1_t, y2_s, y2_t = _forward_heads(x)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
            0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B: train classifiers to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        y1_s, y1_t, y2_s, y2_t = _forward_heads(x)
        # Bound here as well so the transfer-loss meter below still has a value
        # when args.num_k == 0 (Step C then never runs).
        mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
            0.01 * (entropy(y1_t) + entropy(y2_t)) - mcd_loss
        loss.backward()
        optimizer_f.step()

        # Step C: train generator to minimize discrepancy
        for _ in range(args.num_k):
            optimizer_g.zero_grad()
            y1_s, y1_t, y2_s, y2_t = _forward_heads(x)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        # metrics use the outputs of the most recent forward pass
        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def main(args: argparse.Namespace):
    """Train an MCD model (feature extractor ``G`` + heads ``F1``/``F2``) and report test accuracy.

    Trains for ``args.epochs`` epochs and keeps an in-memory deep copy of the
    weights from the epoch with the best validation accuracy (the better of the
    two heads). It then restores those weights and evaluates on the test loader.
    The test loader is a separate ``evaluate=True`` dataset for DomainNet and the
    validation loader otherwise.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # cuDNN benchmark mode chooses algorithms by timing, which can differ
        # between runs; keep it off when reproducibility is asked for.
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.center_crop:
        train_transform = transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = transforms.Compose([
            ResizeImage(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source,
                                   download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target,
                                   download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target,
                          download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root, task=args.target, evaluate=True,
                               download=True, transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    G = models.__dict__[args.arch](pretrained=True).to(device)  # feature extractor
    num_classes = train_source_dataset.num_classes
    # two image classifier heads
    F1 = ImageClassifierHead(G.out_features, num_classes,
                             args.bottleneck_dim).to(device)
    F2 = ImageClassifierHead(G.out_features, num_classes,
                             args.bottleneck_dim).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper
    optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_f = SGD(F1.get_parameters() + F2.get_parameters(),
                      momentum=0.9, lr=args.lr, weight_decay=0.0005)

    # start training
    best_acc1 = 0.
    best_results = None
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, G, F1, F2,
              optimizer_g, optimizer_f, epoch, args)

        # evaluate on validation set
        results = validate(val_loader, G, F1, F2, args)

        # remember best acc@1 and keep a copy of its weights. The first epoch
        # always seeds the best state, so best_G/best_F1/best_F2 are bound even
        # if no epoch scores above 0.
        if best_results is None or max(results) > best_acc1:
            best_G = copy.deepcopy(G.state_dict())
            best_F1 = copy.deepcopy(F1.state_dict())
            best_F2 = copy.deepcopy(F2.state_dict())
            best_acc1 = max(results)
            best_results = results

    print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

    # evaluate on test set
    G.load_state_dict(best_G)
    F1.load_state_dict(best_F1)
    F2.load_state_dict(best_F2)
    results = validate(test_loader, G, F1, F2, args)
    print("test_acc1 = {:3.1f}".format(max(results)))
def main(args: argparse.Namespace):
    """Train, test or analyse an MCD model (feature extractor ``G`` + heads ``F1``/``F2``).

    Datasets, transforms and the backbone come from ``utils``. ``args.phase``
    selects the mode:
      * ``'train'``: train for ``args.epochs`` epochs, save a 'latest' checkpoint
        each epoch and copy it to 'best' whenever the better of the two heads'
        validation accuracy improves; finally reload 'best' and evaluate it on
        the test loader.
      * ``'test'``: load the 'best' checkpoint and evaluate it on the test loader.
      * ``'analysis'``: load the 'best' checkpoint, save a t-SNE plot of source
        vs. target features (``G`` followed by ``F1.pool_layer``) and print the
        A-distance.

    The logger is closed on every exit path, including the early-returning
    phases and exceptions.
    """
    logger = CompleteLogger(args.log, args.phase)
    try:
        print(args)

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            # cuDNN benchmark mode chooses algorithms by timing, which can
            # differ between runs; keep it off when reproducibility is asked for.
            cudnn.benchmark = False
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
        else:
            cudnn.benchmark = True

        # Data loading code
        train_transform = utils.get_train_transform(
            args.train_resizing,
            random_horizontal_flip=not args.no_hflip,
            random_color_jitter=False,
            resize_size=args.resize_size,
            norm_mean=args.norm_mean,
            norm_std=args.norm_std)
        val_transform = utils.get_val_transform(args.val_resizing,
                                                resize_size=args.resize_size,
                                                norm_mean=args.norm_mean,
                                                norm_std=args.norm_std)
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

        train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
            utils.get_dataset(args.data, args.root, args.source, args.target,
                              train_transform, val_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=False, num_workers=args.workers)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.workers)

        train_source_iter = ForeverDataIterator(train_source_loader)
        train_target_iter = ForeverDataIterator(train_target_loader)

        # create model
        print("=> using model '{}'".format(args.arch))
        G = utils.get_model(args.arch, pretrain=not args.scratch).to(device)  # feature extractor
        # two image classifier heads
        pool_layer = nn.Identity() if args.no_pool else None
        F1 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim, pool_layer).to(device)
        F2 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim, pool_layer).to(device)

        # define optimizer
        # the learning rate is fixed according to origin paper
        optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
        optimizer_f = SGD([
            {"params": F1.parameters()},
            {"params": F2.parameters()},
        ], momentum=0.9, lr=args.lr, weight_decay=0.0005)

        # resume from the best checkpoint
        if args.phase != 'train':
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])

        # analysis the model
        if args.phase == 'analysis':
            # extract features from both domains
            feature_extractor = nn.Sequential(G, F1.pool_layer).to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(train_target_loader,
                                             feature_extractor, device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, G, F1, F2, args)
            print(acc1)
            return

        # start training
        best_acc1 = 0.
        best_results = None
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, train_target_iter, G, F1, F2,
                  optimizer_g, optimizer_f, epoch, args)

            # evaluate on validation set
            results = validate(val_loader, G, F1, F2, args)

            # remember best acc@1 and save checkpoint
            torch.save(
                {
                    'G': G.state_dict(),
                    'F1': F1.state_dict(),
                    'F2': F2.state_dict()
                }, logger.get_checkpoint_path('latest'))
            # The first epoch always seeds 'best', so the checkpoint reloaded
            # below exists even if no epoch scores above 0.
            if best_results is None or max(results) > best_acc1:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
                best_acc1 = max(results)
                best_results = results

        print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

        # evaluate on test set
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        G.load_state_dict(checkpoint['G'])
        F1.load_state_dict(checkpoint['F1'])
        F2.load_state_dict(checkpoint['F2'])
        results = validate(test_loader, G, F1, F2, args)
        print("test_acc1 = {:3.1f}".format(max(results)))
    finally:
        logger.close()