Exemple #1
0
def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:
    """Evaluate both classifier heads on every batch of ``val_loader``.

    Args:
        val_loader: loader yielding ``(images, target)`` batches.
        G: feature extractor whose output is fed to both heads.
        F1: first classifier head.
        F2: second classifier head.
        args: namespace providing ``per_class_eval`` and ``print_freq``.

    Returns:
        The ``avg`` value of the accuracy meter of ``F1`` and of ``F2``,
        in that order.
    """
    timer = AverageMeter('Time', ':6.3f')
    head1_meter = AverageMeter('Acc_1', ':6.2f')
    head2_meter = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(len(val_loader),
                             [timer, head1_meter, head2_meter],
                             prefix='Test: ')

    # all three networks go to evaluation mode
    for module in (G, F1, F2):
        module.eval()

    # the confusion matrix is only built when per-class evaluation is requested
    confmat = None
    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)
            n_samples = images.size(0)

            # shared features, one prediction per head
            features = G(images)
            logits1 = F1(features)
            logits2 = F2(features)

            # per-head accuracy; the confusion matrix follows head 1 only
            (acc1,) = accuracy(logits1, target)
            (acc2,) = accuracy(logits2, target)
            if confmat:
                confmat.update(target, logits1.argmax(1))
            head1_meter.update(acc1.item(), n_samples)
            head2_meter.update(acc2.item(), n_samples)

            # time spent on this batch
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        summary = ' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'
        print(summary.format(top1_1=head1_meter, top1_2=head2_meter))
        if confmat:
            print(confmat.format(classes))

    return head1_meter.avg, head2_meter.avg
Exemple #2
0
def main(args: argparse.Namespace):
    """Train, test, or analyse the feature extractor and its two classifier heads.

    ``args.phase`` selects what happens after the data loaders and models
    have been built:

    * any phase other than ``'train'`` first loads the ``'best'`` checkpoint;
    * ``'analysis'`` saves a t-SNE plot of source/target features, prints the
      A-distance and returns;
    * ``'test'`` evaluates on the test loader, prints the result and returns;
    * otherwise the model is trained for ``args.epochs`` epochs, a ``'latest'``
      checkpoint is written after every epoch, the best one is copied to
      ``'best'``, and the best checkpoint is finally evaluated on the test
      loader.

    The logger is closed on every exit path, including the early returns of
    the ``'analysis'`` and ``'test'`` phases and any exception.
    """
    logger = CompleteLogger(args.log, args.phase)
    try:
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        cudnn.benchmark = True

        # Data loading code
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        if args.center_crop:
            train_transform = T.Compose([
                ResizeImage(256),
                T.CenterCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        else:
            train_transform = T.Compose([
                ResizeImage(256),
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        val_transform = T.Compose(
            [ResizeImage(256),
             T.CenterCrop(224),
             T.ToTensor(), normalize])

        dataset = datasets.__dict__[args.data]
        train_source_dataset = dataset(root=args.root,
                                       task=args.source,
                                       download=True,
                                       transform=train_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_dataset = dataset(root=args.root,
                                       task=args.target,
                                       download=True,
                                       transform=train_transform)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_dataset = dataset(root=args.root,
                              task=args.validation,
                              download=True,
                              transform=val_transform)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers)
        if args.data == 'DomainNet':
            # DomainNet gets a dedicated test split; other datasets reuse
            # the validation loader for testing
            test_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='test',
                                   download=True,
                                   transform=val_transform)
            test_loader = DataLoader(test_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.workers)
        else:
            test_loader = val_loader

        train_source_iter = ForeverDataIterator(train_source_loader)
        train_target_iter = ForeverDataIterator(train_target_loader)

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        G = models.__dict__[args.arch](pretrained=True).to(
            device)  # feature extractor
        num_classes = train_source_dataset.num_classes
        # two image classifier heads
        F1 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim).to(device)
        F2 = ImageClassifierHead(G.out_features, num_classes,
                                 args.bottleneck_dim).to(device)

        # define optimizer
        # the learning rate is fixed according to origin paper
        optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
        optimizer_f = SGD([
            {
                "params": F1.parameters()
            },
            {
                "params": F2.parameters()
            },
        ],
                          momentum=0.9,
                          lr=args.lr,
                          weight_decay=0.0005)

        # resume from the best checkpoint
        if args.phase != 'train':
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])

        # analyse the model
        if args.phase == 'analysis':
            # extract features from both domains
            feature_extractor = G.to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(train_target_loader,
                                             feature_extractor, device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, G, F1, F2, args)
            print(acc1)
            return

        # start training
        best_acc1 = 0.
        best_results = None
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g,
                  optimizer_f, epoch, args)

            # evaluate on validation set
            results = validate(val_loader, G, F1, F2, args)

            # remember best acc@1 and save checkpoint
            torch.save(
                {
                    'G': G.state_dict(),
                    'F1': F1.state_dict(),
                    'F2': F2.state_dict()
                }, logger.get_checkpoint_path('latest'))
            # The first evaluated epoch is always recorded so that the 'best'
            # checkpoint read below was written by this run, even when the
            # accuracy never rises above 0.
            if best_results is None or max(results) > best_acc1:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
                best_acc1 = max(results)
                best_results = results

        print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

        # evaluate on test set
        if best_results is not None:
            # no epoch ran -> nothing was checkpointed; keep current weights
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])
        results = validate(test_loader, G, F1, F2, args)
        print("test_acc1 = {:3.1f}".format(max(results)))
    finally:
        # runs for the early returns above as well as on errors
        logger.close()
Exemple #3
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, G: nn.Module,
          F1: ImageClassifierHead, F2: ImageClassifierHead, optimizer_g: SGD,
          optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` iterations of the three-step training scheme.

    Every iteration draws one source and one target batch and performs:

    * Step A -- update ``G``, ``F1`` and ``F2`` on source cross-entropy plus a
      0.01-weighted entropy term on the target predictions;
    * Step B -- update only ``F1``/``F2`` on the same loss minus the
      classifier discrepancy scaled by ``args.trade_off``;
    * Step C -- update only ``G`` for ``args.num_k`` rounds on the scaled
      classifier discrepancy.

    Args:
        train_source_iter: iterator yielding ``(x_s, labels_s)`` batches.
        train_target_iter: iterator yielding ``(x_t, labels_t)`` batches;
            the target labels are used for monitoring accuracy only.
        G: feature extractor.
        F1: first classifier head.
        F2: second classifier head.
        optimizer_g: optimizer stepping the parameters of ``G``.
        optimizer_f: optimizer stepping the parameters of ``F1`` and ``F2``.
        epoch: epoch index, used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off``,
            ``num_k`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        # source samples first, target samples second; the outputs are split
        # back into halves with chunk(2) below
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)

        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B train classifier to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        # Only optimizer_f steps in this phase, and optimizer_g.zero_grad()
        # runs before G is stepped again, so gradients w.r.t. G would be
        # thrown away. Skipping the graph for G leaves the gradients of
        # F1/F2 unchanged.
        with torch.no_grad():
            g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t)) - classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss.backward()
        optimizer_f.step()

        # Step C train generator to minimize discrepancy
        mcd_loss = None  # stays None when args.num_k == 0
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            g = G(x)
            y_1 = F1(g)
            y_2 = F2(g)
            y1_s, y1_t = y_1.chunk(2, dim=0)
            y2_s, y2_t = y_2.chunk(2, dim=0)
            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        # accuracies use the most recent outputs of F1 (from the last
        # Step C round, or from Step B when args.num_k == 0)
        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        # `loss` is the Step B loss at this point
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        if mcd_loss is not None:
            trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def main(args: argparse.Namespace):
    """Train the feature extractor and both classifier heads, then test the best model.

    Builds the source/target/validation loaders for ``args.data``, trains for
    ``args.epochs`` epochs, keeps an in-memory copy of the state dicts of the
    epoch with the highest validation accuracy (maximum over the two heads)
    and finally evaluates those weights on the test loader.

    If no epoch is run, there is no recorded best state and the current
    weights are evaluated as they are.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.center_crop:
        train_transform = transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
    else:
        train_transform = transforms.Compose([
            ResizeImage(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    if args.data == 'DomainNet':
        # DomainNet gets a dedicated evaluation dataset; other datasets
        # reuse the validation loader for testing
        test_dataset = dataset(root=args.root,
                               task=args.target,
                               evaluate=True,
                               download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    G = models.__dict__[args.arch](pretrained=True).to(
        device)  # feature extractor
    num_classes = train_source_dataset.num_classes
    # two image classifier heads
    F1 = ImageClassifierHead(G.out_features, num_classes,
                             args.bottleneck_dim).to(device)
    F2 = ImageClassifierHead(G.out_features, num_classes,
                             args.bottleneck_dim).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper
    optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_f = SGD(F1.get_parameters() + F2.get_parameters(),
                      momentum=0.9,
                      lr=args.lr,
                      weight_decay=0.0005)

    # start training
    best_acc1 = 0.
    best_results = None
    # (G, F1, F2) state dicts of the best epoch; None until an epoch has run
    best_state = None
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g,
              optimizer_f, epoch, args)

        # evaluate on validation set
        results = validate(val_loader, G, F1, F2, args)

        # remember best acc@1 and keep a copy of the weights; the first
        # epoch is always recorded so a best state exists even when the
        # accuracy never rises above 0
        if best_state is None or max(results) > best_acc1:
            best_state = (copy.deepcopy(G.state_dict()),
                          copy.deepcopy(F1.state_dict()),
                          copy.deepcopy(F2.state_dict()))
            best_acc1 = max(results)
            best_results = results

    print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

    # evaluate on test set
    if best_state is not None:
        best_G, best_F1, best_F2 = best_state
        G.load_state_dict(best_G)
        F1.load_state_dict(best_F1)
        F2.load_state_dict(best_F2)
    results = validate(test_loader, G, F1, F2, args)
    print("test_acc1 = {:3.1f}".format(max(results)))
def main(args: argparse.Namespace):
    """Train, test, or analyse the feature extractor and its two classifier heads.

    ``args.phase`` selects what happens after the data loaders and models
    have been built:

    * any phase other than ``'train'`` first loads the ``'best'`` checkpoint;
    * ``'analysis'`` saves a t-SNE plot of source/target features, prints the
      A-distance and returns;
    * ``'test'`` evaluates on the test loader, prints the result and returns;
    * otherwise the model is trained for ``args.epochs`` epochs, a ``'latest'``
      checkpoint is written after every epoch, the best one is copied to
      ``'best'``, and the best checkpoint is finally evaluated on the test
      loader.

    The logger is closed on every exit path, including the early returns of
    the ``'analysis'`` and ``'test'`` phases and any exception.
    """
    logger = CompleteLogger(args.log, args.phase)
    try:
        print(args)

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        cudnn.benchmark = True

        # Data loading code
        train_transform = utils.get_train_transform(
            args.train_resizing,
            random_horizontal_flip=not args.no_hflip,
            random_color_jitter=False,
            resize_size=args.resize_size,
            norm_mean=args.norm_mean,
            norm_std=args.norm_std)
        val_transform = utils.get_val_transform(args.val_resizing,
                                                resize_size=args.resize_size,
                                                norm_mean=args.norm_mean,
                                                norm_std=args.norm_std)
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

        train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
            utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

        train_source_iter = ForeverDataIterator(train_source_loader)
        train_target_iter = ForeverDataIterator(train_target_loader)

        # create model
        print("=> using model '{}'".format(args.arch))
        G = utils.get_model(args.arch, pretrain=not args.scratch).to(
            device)  # feature extractor
        # two image classifier heads
        pool_layer = nn.Identity() if args.no_pool else None
        F1 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim,
                                 pool_layer).to(device)
        F2 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim,
                                 pool_layer).to(device)

        # define optimizer
        # the learning rate is fixed according to origin paper
        optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
        optimizer_f = SGD([
            {
                "params": F1.parameters()
            },
            {
                "params": F2.parameters()
            },
        ],
                          momentum=0.9,
                          lr=args.lr,
                          weight_decay=0.0005)

        # resume from the best checkpoint
        if args.phase != 'train':
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])

        # analyse the model
        if args.phase == 'analysis':
            # extract features from both domains
            feature_extractor = nn.Sequential(G, F1.pool_layer).to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(train_target_loader,
                                             feature_extractor, device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, G, F1, F2, args)
            print(acc1)
            return

        # start training
        best_acc1 = 0.
        best_results = None
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g,
                  optimizer_f, epoch, args)

            # evaluate on validation set
            results = validate(val_loader, G, F1, F2, args)

            # remember best acc@1 and save checkpoint
            torch.save(
                {
                    'G': G.state_dict(),
                    'F1': F1.state_dict(),
                    'F2': F2.state_dict()
                }, logger.get_checkpoint_path('latest'))
            # The first evaluated epoch is always recorded so that the 'best'
            # checkpoint read below was written by this run, even when the
            # accuracy never rises above 0.
            if best_results is None or max(results) > best_acc1:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
                best_acc1 = max(results)
                best_results = results

        print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

        # evaluate on test set
        if best_results is not None:
            # no epoch ran -> nothing was checkpointed; keep current weights
            checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                    map_location='cpu')
            G.load_state_dict(checkpoint['G'])
            F1.load_state_dict(checkpoint['F1'])
            F2.load_state_dict(checkpoint['F2'])
        results = validate(test_loader, G, F1, F2, args)
        print("test_acc1 = {:3.1f}".format(max(results)))
    finally:
        # runs for the early returns above as well as on errors
        logger.close()