示例#1
0
def get_dataset(dataset_name,
                root,
                source,
                target,
                train_source_transform,
                val_transform,
                train_target_transform=None):
    """Build the source/target datasets used for open-set training and evaluation.

    The dataset class is looked up by name in ``datasets.__dict__`` and wrapped
    with ``open_set`` once for the source side and once for the target side.

    Args:
        dataset_name: Key of the dataset class inside ``datasets.__dict__``.
        root: Passed as ``root`` to every dataset constructor.
        source: Task passed to the source dataset.
        target: Task passed to the target datasets.
        train_source_transform: Transform for the source training set.
        val_transform: Transform for the validation and test sets.
        train_target_transform: Transform for the target training set; when
            ``None`` the source training transform is reused.

    Returns:
        A tuple ``(train_source_dataset, train_target_dataset, val_dataset,
        test_dataset, num_classes, class_names)``. ``test_dataset`` is the same
        object as ``val_dataset`` unless ``dataset_name == 'DomainNet'``, in
        which case it is built with ``split='test'``.
    """
    target_train_transform = (train_source_transform
                              if train_target_transform is None else
                              train_target_transform)

    # load datasets from common.vision.datasets
    base_dataset = datasets.__dict__[dataset_name]
    make_source = open_set(base_dataset, source=True)
    make_target = open_set(base_dataset, source=False)

    # keyword arguments shared by every constructor call below
    shared_kwargs = dict(root=root, download=True)

    train_source_dataset = make_source(task=source,
                                       transform=train_source_transform,
                                       **shared_kwargs)
    train_target_dataset = make_target(task=target,
                                       transform=target_train_transform,
                                       **shared_kwargs)
    val_dataset = make_target(task=target,
                              transform=val_transform,
                              **shared_kwargs)

    # only DomainNet is given a dedicated test split here
    test_dataset = val_dataset
    if dataset_name == 'DomainNet':
        test_dataset = make_target(task=target,
                                   split='test',
                                   transform=val_transform,
                                   **shared_kwargs)

    class_names = train_source_dataset.classes
    return (train_source_dataset, train_target_dataset, val_dataset,
            test_dataset, len(class_names), class_names)
def main(args: argparse.Namespace):
    """Run the source-only open-set baseline in the phase given by ``args.phase``.

    Builds the source training set and the target validation/test sets,
    creates a pre-trained backbone wrapped in ``Classifier`` and then:

    * ``'analysis'``: saves a t-SNE plot of source/target features and prints
      the A-distance between them;
    * ``'test'``: prints the value returned by ``validate`` on the test loader;
    * any other phase: trains for ``args.epochs`` epochs, keeps the checkpoint
      with the highest validation score, and finally prints that checkpoint's
      score on the test loader.

    For the ``'analysis'`` and ``'test'`` phases the ``'best'`` checkpoint is
    loaded first when it exists, so that the evaluation runs on trained
    weights; if it is missing a warning is emitted and the freshly built model
    is used. The logger is always closed, including on early returns.

    Args:
        args: Parsed command-line options. Fields read directly here are
            ``log``, ``phase``, ``seed``, ``center_crop``, ``data``, ``root``,
            ``source``, ``target``, ``batch_size``, ``workers``, ``arch``,
            ``lr``, ``momentum``, ``wd``, ``lr_gamma``, ``lr_decay`` and
            ``epochs``; the namespace is also forwarded to ``train`` and
            ``validate``.
    """
    logger = CompleteLogger(args.log, args.phase)
    # try/finally so the logger is closed on the early returns of the
    # 'analysis' and 'test' phases (and on errors), not only after training.
    try:
        print(args)

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        # NOTE(review): benchmark mode stays enabled even when a seed is set;
        # PyTorch's reproducibility notes suggest disabling it for
        # deterministic algorithm selection - confirm this is intended.
        cudnn.benchmark = True

        # Data loading code
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        if args.center_crop:
            train_transform = T.Compose([
                ResizeImage(256),
                T.CenterCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        else:
            train_transform = T.Compose([
                ResizeImage(256),
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        val_transform = T.Compose(
            [ResizeImage(256),
             T.CenterCrop(224),
             T.ToTensor(), normalize])

        dataset = datasets.__dict__[args.data]
        source_dataset = open_set(dataset, source=True)
        target_dataset = open_set(dataset, source=False)
        train_source_dataset = source_dataset(root=args.root,
                                              task=args.source,
                                              download=True,
                                              transform=train_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_dataset = target_dataset(root=args.root,
                                     task=args.target,
                                     download=True,
                                     transform=val_transform)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers)
        # only DomainNet is given a dedicated test split; every other dataset
        # is tested on the validation loader
        if args.data == 'DomainNet':
            test_dataset = target_dataset(root=args.root,
                                          task=args.target,
                                          split='test',
                                          download=True,
                                          transform=val_transform)
            test_loader = DataLoader(test_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.workers)
        else:
            test_loader = val_loader

        train_source_iter = ForeverDataIterator(train_source_loader)

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        backbone = models.__dict__[args.arch](pretrained=True)
        num_classes = train_source_dataset.num_classes
        classifier = Classifier(backbone, num_classes).to(device)

        # define optimizer and lr scheduler
        optimizer = SGD(classifier.get_parameters(),
                        args.lr,
                        momentum=args.momentum,
                        weight_decay=args.wd,
                        nesterov=True)
        # NOTE(review): LambdaLR multiplies each parameter group's initial lr
        # by this factor, so this presumably relies on get_parameters()
        # returning relative per-group learning rates - confirm.
        lr_scheduler = LambdaLR(
            optimizer, lambda x: args.lr *
            (1. + args.lr_gamma * float(x))**(-args.lr_decay))

        # Evaluation-only phases must run on trained weights rather than on
        # the freshly initialised classifier built above.
        if args.phase in ('analysis', 'test'):
            best_checkpoint = logger.get_checkpoint_path('best')
            if osp.exists(best_checkpoint):
                classifier.load_state_dict(torch.load(best_checkpoint))
            else:
                warnings.warn("No 'best' checkpoint found at {}; running "
                              "phase '{}' with a model that has not been "
                              "trained.".format(best_checkpoint, args.phase))

        # analysis the model
        if args.phase == 'analysis':
            # using shuffled val loader
            val_loader = DataLoader(val_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.workers)
            # extract features from both domains
            feature_extractor = nn.Sequential(
                classifier.backbone, classifier.bottleneck).to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(val_loader, feature_extractor,
                                             device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution
            # discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, classifier, args)
            print(acc1)
            return

        # start training
        best_h_score = 0.
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, classifier, optimizer, lr_scheduler,
                  epoch, args)

            # evaluate on validation set
            h_score = validate(val_loader, classifier, args)

            # remember the best score and save checkpoint
            torch.save(classifier.state_dict(),
                       logger.get_checkpoint_path('latest'))
            # Also copy on the first epoch so that a 'best' checkpoint from
            # this run always exists for the final evaluation, even if the
            # score never rises above the initial 0.
            if epoch == 0 or h_score > best_h_score:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
            best_h_score = max(h_score, best_h_score)

        print("best_h_score = {:3.1f}".format(best_h_score))

        # evaluate on test set
        classifier.load_state_dict(
            torch.load(logger.get_checkpoint_path('best')))
        h_score = validate(test_loader, classifier, args)
        print("test_h_score = {:3.1f}".format(h_score))
    finally:
        logger.close()
示例#3
0
def main(args: argparse.Namespace):
    """Run the SECC open-set training script in the phase given by ``args.phase``.

    Builds source/target datasets from a fixed Office31 class split, creates a
    pre-trained backbone wrapped in ``ImageClassifier`` together with an EMA
    teacher, a cluster distribution over the target training data and a
    cluster-assignment head, and then:

    * ``'analysis'``: saves a t-SNE plot of source/target features and prints
      the A-distance between them;
    * ``'test'``: prints the value returned by ``validate`` on the test loader;
    * any other phase: trains for ``args.epochs`` epochs, keeps the checkpoint
      with the highest validation score, and finally prints that checkpoint's
      score on the test loader (which is the validation loader here).

    For the ``'analysis'`` and ``'test'`` phases the ``'best'`` checkpoint is
    loaded first when it exists, so that the evaluation runs on trained
    weights; if it is missing a warning is emitted and the freshly built model
    is used. The logger is always closed, including on early returns.

    Args:
        args: Parsed command-line options. Fields read directly here are
            ``log``, ``phase``, ``seed``, ``center_crop``, ``data``, ``root``,
            ``source``, ``target``, ``batch_size``, ``workers``, ``arch``,
            ``bottleneck_dim``, ``lr``, ``momentum``, ``weight_decay``,
            ``lr_gamma``, ``lr_decay`` and ``epochs``; the namespace is also
            forwarded to ``train`` and ``validate``.
    """
    logger = CompleteLogger(args.log, args.phase)
    # try/finally so the logger is closed on the early returns of the
    # 'analysis' and 'test' phases (and on errors), not only after training.
    try:
        print(args)

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        # NOTE(review): benchmark mode stays enabled even when a seed is set;
        # PyTorch's reproducibility notes suggest disabling it for
        # deterministic algorithm selection - confirm this is intended.
        cudnn.benchmark = True

        # Data loading code
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        if args.center_crop:
            train_transform = T.Compose([
                ResizeImage(256),
                T.CenterCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        else:
            train_transform = T.Compose([
                ResizeImage(256),
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ToTensor(), normalize
            ])
        val_transform = T.Compose(
            [ResizeImage(256),
             T.CenterCrop(224),
             T.ToTensor(), normalize])

        dataset = datasets.__dict__[args.data]

        # dataset settings for SECC
        # NOTE(review): the class split is always taken from Office31.CLASSES,
        # independently of args.data - confirm no other dataset is expected.
        from common.vision.datasets.office31 import Office31
        public_classes = Office31.CLASSES[:10]
        source_private = Office31.CLASSES[10:20]
        target_private = Office31.CLASSES[20:]

        source_dataset = open_set(dataset, public_classes, source_private)
        target_dataset = open_set(dataset, public_classes, target_private)

        train_source_dataset = source_dataset(root=args.root,
                                              task=args.source,
                                              download=True,
                                              transform=train_transform)
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_dataset = target_dataset(root=args.root,
                                              task=args.target,
                                              download=True,
                                              transform=train_transform)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        val_dataset = target_dataset(root=args.root,
                                     task=args.target,
                                     download=True,
                                     transform=val_transform)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers)

        train_source_iter = ForeverDataIterator(train_source_loader)
        train_target_iter = ForeverDataIterator(train_target_loader)

        # no separate test split is built here: testing reuses validation data
        test_loader = val_loader

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        num_classes = train_source_dataset.num_classes
        backbone = models.__dict__[args.arch](pretrained=True)

        # mean teacher model for SECC
        classifier = ImageClassifier(
            backbone, num_classes,
            bottleneck_dim=args.bottleneck_dim).to(device)
        teacher = EmaTeacher(classifier, 0.9)

        # number of clusters, shared by the distribution and assignment head
        k = 25

        # distribution cluster for SECC
        print("=> initiating k-means clusters")
        feature_extractor = FeatureExtractor(backbone)
        cluster_distribution = ClusterDistribution(train_target_loader,
                                                   feature_extractor,
                                                   k=k)

        # cluster assignment for SECC
        cluster_assignment = ASoftmax(
            feature_extractor,
            num_clusters=k,
            num_features=cluster_distribution.num_features).to(device)

        # loss functions for SECC
        # NOTE(review): nn.KLDivLoss() uses its default reduction; PyTorch
        # documents 'batchmean' as the reduction matching the mathematical KL
        # divergence - confirm the default is intended.
        kl_loss = nn.KLDivLoss().to(device)
        conditional_loss = ConditionalEntropyLoss().to(device)
        consistent_loss = L2ConsistencyLoss().to(device)
        class_balance_loss = ClassBalanceLoss(num_classes).to(device)

        # define optimizer and lr scheduler
        optimizer = SGD(classifier.get_parameters() +
                        cluster_assignment.get_parameters(),
                        args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay,
                        nesterov=True)
        # NOTE(review): LambdaLR multiplies each parameter group's initial lr
        # by this factor, so this presumably relies on get_parameters()
        # returning relative per-group learning rates - confirm.
        lr_scheduler = LambdaLR(
            optimizer, lambda x: args.lr *
            (1. + args.lr_gamma * float(x))**(-args.lr_decay))

        # Evaluation-only phases must run on trained weights rather than on
        # the freshly initialised classifier built above.
        if args.phase in ('analysis', 'test'):
            best_checkpoint = logger.get_checkpoint_path('best')
            if osp.exists(best_checkpoint):
                classifier.load_state_dict(torch.load(best_checkpoint))
            else:
                warnings.warn("No 'best' checkpoint found at {}; running "
                              "phase '{}' with a model that has not been "
                              "trained.".format(best_checkpoint, args.phase))

        # analysis the model
        if args.phase == 'analysis':
            # extract features from both domains
            feature_extractor = nn.Sequential(
                classifier.backbone, classifier.bottleneck).to(device)
            source_feature = collect_feature(train_source_loader,
                                             feature_extractor, device)
            target_feature = collect_feature(train_target_loader,
                                             feature_extractor, device)
            # plot t-SNE
            tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
            tsne.visualize(source_feature, target_feature, tSNE_filename)
            print("Saving t-SNE to", tSNE_filename)
            # calculate A-distance, which is a measure for distribution
            # discrepancy
            A_distance = a_distance.calculate(source_feature, target_feature,
                                              device)
            print("A-distance =", A_distance)
            return

        if args.phase == 'test':
            acc1 = validate(test_loader, classifier, args)
            print(acc1)
            return

        # start training
        best_h_score = 0.
        for epoch in range(args.epochs):
            # train for one epoch
            train(train_source_iter, train_target_iter, classifier, teacher,
                  cluster_assignment, cluster_distribution, consistent_loss,
                  class_balance_loss, kl_loss, conditional_loss, optimizer,
                  lr_scheduler, epoch, args)

            # evaluate on validation set
            h_score = validate(val_loader, classifier, args)

            # remember the best score and save checkpoint
            torch.save(classifier.state_dict(),
                       logger.get_checkpoint_path('latest'))
            # Also copy on the first epoch so that a 'best' checkpoint from
            # this run always exists for the final evaluation, even if the
            # score never rises above the initial 0.
            if epoch == 0 or h_score > best_h_score:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
            best_h_score = max(h_score, best_h_score)

        print("best_h_score = {:3.1f}".format(best_h_score))

        # evaluate on test set
        classifier.load_state_dict(
            torch.load(logger.get_checkpoint_path('best')))
        h_score = validate(test_loader, classifier, args)
        print("test_h_score = {:3.1f}".format(h_score))
    finally:
        logger.close()