Beispiel #1
0
def main(args: argparse.Namespace):
    """Entry point of the open-set domain-adversarial example.

    Builds the source/target data loaders, the classifier and the domain
    discriminator, then dispatches on ``args.phase``:

    * ``'analysis'`` -- extract features from both domains, save a t-SNE
      plot and print the A-distance, then return;
    * ``'test'`` -- print the result of ``validate`` on the test loader,
      then return;
    * anything else -- train for ``args.epochs`` epochs, saving the latest
      and the best classifier weights through ``logger``, and finally
      evaluate the best weights on the test loader.

    Args:
        args: parsed command-line options. The function reads ``log``,
            ``phase``, ``seed``, ``center_crop``, ``data``, ``root``,
            ``source``, ``target``, ``batch_size``, ``workers``, ``arch``,
            ``bottleneck_dim``, ``lr``, ``momentum``, ``weight_decay``,
            ``lr_gamma``, ``lr_decay`` and ``epochs``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # The seed lives on ``args``; a bare ``seed`` name is not defined in this
    # function and would raise NameError.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        # NOTE: the deterministic cuDNN setting can slow training down.
        cudnn.deterministic = True

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if args.center_crop:
        train_transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    else:
        train_transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize
        ])
    val_transform = T.Compose([
        ResizeImage(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    # ``open_set`` wraps the dataset class once for the source side and once
    # for the target side.
    dataset = datasets.__dict__[args.data]
    source_dataset = open_set(dataset, source=True)
    target_dataset = open_set(dataset, source=False)
    train_source_dataset = source_dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = target_dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = target_dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    if args.data == 'DomainNet':
        # DomainNet gets a dedicated test split; every other dataset reuses
        # the validation loader for testing.
        test_dataset = target_dataset(root=args.root, task=args.target, split='test', download=True,
                                      transform=val_transform)
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_classes = train_source_dataset.num_classes
    backbone = models.__dict__[args.arch](pretrained=True)

    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler (one optimizer drives both networks)
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        h_score = validate(val_loader, classifier, args)

        # remember the best h-score and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)

    print("best_h_score = {:3.1f}".format(best_h_score))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` alternating adversarial updates.

    Each iteration draws one source and one target batch and performs two
    optimisation steps:

    1. the classifier is updated (discriminator frozen) on the source
       classification loss plus ``args.trade_off`` times a transfer loss
       computed with swapped domain labels on ``domain_discri(gl(f))``;
    2. the discriminator is updated (classifier frozen) on the detached
       features with the true domain labels.

    Both learning-rate schedulers are stepped once per iteration, and the
    running meters are printed every ``args.print_freq`` iterations.
    """
    meter_batch = AverageMeter('Time', ':5.2f')
    meter_data = AverageMeter('Data', ':5.2f')
    meter_cls_loss = AverageMeter('Cls Loss', ':6.2f')
    meter_transfer_loss = AverageMeter('Transfer Loss', ':6.2f')
    meter_disc_loss = AverageMeter('Discriminator Loss', ':6.2f')
    meter_cls_acc = AverageMeter('Cls Acc', ':3.1f')
    meter_domain_acc = AverageMeter('Domain Acc', ':3.1f')
    tracked = [
        meter_batch, meter_data, meter_cls_loss, meter_transfer_loss,
        meter_disc_loss, meter_cls_acc, meter_domain_acc
    ]
    progress = ProgressMeter(args.iters_per_epoch, tracked,
                             prefix="Epoch: [{}]".format(epoch))

    tick = time.time()
    for step in range(args.iters_per_epoch):
        src_images, src_labels = next(train_source_iter)
        tgt_images, _ = next(train_target_iter)

        src_images = src_images.to(device)
        tgt_images = tgt_images.to(device)
        src_labels = src_labels.to(device)

        # time spent fetching the batches and moving them to the device
        meter_data.update(time.time() - tick)

        # ---- Step 1: update the classifier, discriminator frozen ----
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        logits, features = model(torch.cat((src_images, tgt_images), dim=0))
        src_logits, _ = logits.chunk(2, dim=0)
        cls_loss = F.cross_entropy(src_logits, src_labels)

        # domain labels are swapped on purpose: the features should fool
        # the discriminator
        fool_src, fool_tgt = domain_discri(gl(features)).chunk(2, dim=0)
        transfer_loss = 0.5 * (domain_adv(fool_src, 'target') +
                               domain_adv(fool_tgt, 'source'))

        optimizer.zero_grad()
        objective = cls_loss + transfer_loss * args.trade_off
        objective.backward()
        optimizer.step()
        lr_scheduler.step()

        # ---- Step 2: update the discriminator, classifier frozen ----
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        # features are detached so no gradient reaches the classifier
        disc_src, disc_tgt = domain_discri(features.detach()).chunk(2, dim=0)
        disc_loss = 0.5 * (domain_adv(disc_src, 'source') +
                           domain_adv(disc_tgt, 'target'))

        optimizer_d.zero_grad()
        disc_loss.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        # ---- bookkeeping ----
        n_src = src_images.size(0)
        meter_cls_loss.update(cls_loss.item(), n_src)
        meter_transfer_loss.update(transfer_loss.item(), n_src)
        meter_disc_loss.update(disc_loss.item(), n_src)

        top1 = accuracy(src_logits, src_labels)[0]
        meter_cls_acc.update(top1.item(), n_src)
        # domain accuracy uses the discriminator outputs from step 2
        src_hit = binary_accuracy(disc_src, torch.ones_like(disc_src))
        tgt_hit = binary_accuracy(disc_tgt, torch.zeros_like(disc_tgt))
        domain_acc = 0.5 * (src_hit + tgt_hit)
        meter_domain_acc.update(domain_acc.item(), n_src)

        # wall-clock time of the whole iteration
        meter_batch.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def main(args: argparse.Namespace):
    """Entry point of the two-optimizer adversarial adaptation example.

    Builds transforms, datasets and loaders through ``utils``, creates the
    classifier and the domain discriminator with separate SGD optimizers and
    LambdaLR schedulers, then dispatches on ``args.phase``:

    * any phase other than ``'train'`` first loads the ``'best'`` checkpoint
      into the classifier;
    * ``'analysis'`` -- save a t-SNE plot of source/target features and
      print the A-distance, then return;
    * ``'test'`` -- print ``utils.validate`` on the test loader, then return;
    * otherwise -- train for ``args.epochs`` epochs, saving the latest and
      best classifier weights, then evaluate the best weights on the test
      loader.

    Args:
        args: parsed command-line options. ``args.class_names`` is written
            by this function from the result of ``utils.get_dataset``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(
        args.train_resizing,
        random_horizontal_flip=not args.no_hflip,
        random_color_jitter=False,
        resize_size=args.resize_size,
        norm_mean=args.norm_mean,
        norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing,
                                            resize_size=args.resize_size,
                                            norm_mean=args.norm_mean,
                                            norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    # ``None`` lets ImageClassifier pick its own pooling layer
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone,
                                 num_classes,
                                 bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer,
                                 finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define loss function
    domain_adv = DomainAdversarialLoss().to(device)
    gl = WarmStartGradientLayer(alpha=1.,
                                lo=0.,
                                hi=1.,
                                max_iters=1000,
                                auto_step=True)

    # define optimizer and lr scheduler (separate pairs for the classifier
    # and for the discriminator)
    optimizer = SGD(classifier.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=True)
    optimizer_d = SGD(domain_discri.get_parameters(),
                      args.lr_d,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay,
                      nesterov=True)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x: args.lr_d *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # get_last_lr() is the public accessor for the current learning
        # rates; get_lr() is meant to be called only from within step() and
        # warns otherwise.
        print("lr classifier:", lr_scheduler.get_last_lr())
        print("lr discriminator:", lr_scheduler_d.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_discri,
              domain_adv, gl, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
Beispiel #4
0
def main(args: argparse.Namespace):
    """Train a domain-adversarial classifier and report its test accuracy.

    Builds source/target loaders, an ``ImageClassifier`` and a
    ``DomainDiscriminator`` sharing one SGD optimizer, trains for
    ``args.epochs`` epochs while keeping an in-memory copy of the weights
    with the highest validation result, and finally evaluates those weights
    on the test loader.

    Args:
        args: parsed command-line options. The function reads ``seed``,
            ``data``, ``root``, ``source``, ``target``, ``batch_size``,
            ``workers``, ``arch``, ``lr``, ``momentum``, ``weight_decay``
            and ``epochs``.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    if args.data == 'DomainNet':
        # DomainNet gets a dedicated evaluation set; every other dataset
        # reuses the validation loader for testing.
        test_dataset = dataset(root=args.root,
                               task=args.target,
                               evaluate=True,
                               download=True,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone,
                                 train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() +
                    domain_discri.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=True)
    lr_scheduler = StepwiseLR(optimizer,
                              init_lr=args.lr,
                              gamma=0.001,
                              decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    # Seed the snapshot with the initial weights so ``best_model`` is always
    # bound, even when no epoch improves on 0.0 or ``args.epochs`` is 0.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and keep a copy of the matching weights
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(best_model)
    acc1 = validate(test_loader, classifier, args)
    print("test_acc1 = {:3.1f}".format(acc1))
Beispiel #5
0
def da_methods(discr, frame, ag, dim_x, dim_y, device, ckpt, discr_src=None):
    """Build the training loss closure for the adaptation mode ``ag.mode``.

    For modes outside ``MODES_GEN`` the returned ``lossfn(x, y, xt)`` adds a
    supervised term (from ``get_ce_or_bce_loss``) on the source batch to
    ``ag.wda`` times a mode-specific adaptation term; the supported modes
    are ``"dann"``, ``"cdan"``, ``"dan"``, ``"mdd"`` and ``"bnm"``. Any
    loss/discriminator modules created for the mode are put in training
    mode before returning.

    For modes in ``MODES_GEN`` the loss is obtained from
    ``frame.get_lossfn`` and combined with the supervised term through
    ``add_ce_loss``; no discriminator objects are created.

    Args:
        discr: discriminative model; called directly and through its
            ``ys1x`` / ``s1x`` methods and ``dim_s`` attribute.
        frame: only used for modes in ``MODES_GEN``.
        ag: options namespace (``mode``, ``reduction``, ``wda``, ...).
        dim_x: flattened input size, used only by the ``"mdd"`` branch.
        dim_y: output size; ``1`` selects the binary handling (labels cast
            to float, logits stacked against zeros).
        device: device the created modules are moved to.
        ckpt: passed to ``auto_load`` to restore the created modules.
        discr_src: optional alternative model used for ``*-da2`` modes.

    Returns:
        Tuple ``(lossfn, domdisc, dalossobj)``; the last two are ``None``
        when the mode does not create them.

    Raises:
        ValueError: if ``ag.mode`` is neither in ``MODES_GEN`` nor one of
            the supported adaptation modes.
    """
    if ag.mode not in MODES_GEN:
        celossobj, celossfn = get_ce_or_bce_loss(discr, dim_y, ag.reduction)
        if ag.mode == "dann":
            domdisc = DomainDiscriminator(
                in_feature=discr.dim_s, hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = DomainAdversarialLoss(
                domdisc, reduction=ag.reduction).to(device)
            # auto_load looks the objects up by name in locals(), so the
            # local variable names must stay 'domdisc' and 'dalossobj'.
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                # truncate the larger batch so both feature sets match
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit,
                                 y.float() if dim_y == 1 else
                                 y) + ag.wda * dalossobj(feat, featt)
        elif ag.mode == "cdan":
            # In both randomized and not randomized versions, the code has problems.
            # For randomized, the dim_s*num_cls is fed to domdisc.
            # For not rand, `tc.mm` receives the wrong input order in `RandomizedMultiLinearMap.forward()`
            num_classes = (2 if dim_y == 1 else dim_y)
            domdisc = DomainDiscriminator(
                in_feature=discr.dim_s *
                (1 if ag.cdan_rand else num_classes),  # confusing `in_feature`
                hidden_size=ag.domdisc_dimh).to(device)
            dalossobj = ConditionalDomainAdversarialLoss(
                domdisc,
                reduction=ag.reduction,
                randomized=ag.cdan_rand,
                num_classes=num_classes,
                features_dim=discr.dim_s,
                randomized_dim=discr.dim_s).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                logitt, featt = discr.ys1x(xt)
                # binary case: build a two-column logit [0, logit]
                logit_stack = tc.stack([tc.zeros_like(logit), logit],
                                       dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt],
                                        dim=-1) if dim_y == 1 else logitt
                return celossobj(
                    logit,
                    y.float() if dim_y == 1 else y) + ag.wda * dalossobj(
                        logit_stack, feat, logitt_stack, featt)
        elif ag.mode == "dan":
            domdisc = None
            dalossobj = MultipleKernelMaximumMeanDiscrepancy([
                GaussianKernel(alpha=alpha) for alpha in ag.ker_alphas
            ]).to(device)

            def lossfn(x, y, xt):
                logit, feat = discr.ys1x(x)
                featt = discr.s1x(xt)
                # truncate the larger batch so both feature sets match
                if feat.shape[0] < featt.shape[0]:
                    featt = featt[:feat.shape[0]]
                elif feat.shape[0] > featt.shape[0]:
                    feat = feat[:featt.shape[0]]
                return celossobj(logit,
                                 y.float() if dim_y == 1 else
                                 y) + ag.wda * dalossobj(feat, featt)
        elif ag.mode == "mdd":
            num_classes = (2 if dim_y == 1 else dim_y)
            domdisc = mlp.MLP([
                dim_x, ag.domdisc_dimh, ag.domdisc_dimh, num_classes
            ]).to(
                device
            )  # actually not domain discriminator but an auxiliary (adversarial) classifier
            dalossobj = MarginDisparityDiscrepancy(
                margin=ag.mdd_margin, reduction=ag.reduction).to(device)
            auto_load(locals(), ['domdisc', 'dalossobj'], ckpt)

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logit_adv, logitt_adv = domdisc(x.reshape(-1, dim_x)), domdisc(
                    xt.reshape(-1, dim_x))
                logit_stack = tc.stack([tc.zeros_like(logit), logit],
                                       dim=-1) if dim_y == 1 else logit
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt],
                                        dim=-1) if dim_y == 1 else logitt
                return celossobj(
                    logit,
                    y.float() if dim_y == 1 else y) + ag.wda * dalossobj(
                        logit_stack, logit_adv, logitt_stack, logitt_adv)
        elif ag.mode == "bnm":
            domdisc = None
            dalossobj = None

            def lossfn(x, y, xt):
                logit, logitt = discr(x), discr(xt)
                logitt_stack = tc.stack([tc.zeros_like(logitt), logitt],
                                        dim=-1) if dim_y == 1 else logitt
                softmax_tgt = logitt_stack.softmax(dim=1)
                _, s_tgt, _ = tc.svd(softmax_tgt)
                # BNM: negative mean singular value of the target softmax
                transfer_loss = -tc.mean(s_tgt)
                # Alternatives kept for reference:
                # BFM: transfer_loss = -tc.sqrt(tc.sum(s_tgt*s_tgt)/s_tgt.shape[0])
                # ENT: transfer_loss = -tc.mean(tc.sum(softmax_tgt*tc.log(softmax_tgt+1e-8),dim=1))/tc.log(softmax_tgt.shape[1])
                return celossobj(
                    logit,
                    y.float() if dim_y == 1 else y) + ag.wda * transfer_loss
        else:
            # Falling through here would leave lossfn/domdisc/dalossobj
            # unbound and fail below with an opaque UnboundLocalError.
            raise ValueError(
                "unknown domain adaptation mode: {!r}".format(ag.mode))
        for obj in [dalossobj, domdisc]:
            if obj is not None: obj.train()
    else:
        # pick the model that supplies the supervised term
        if ag.mode.endswith("-da2") and discr_src is not None:
            true_discr = discr_src
        elif ag.mode in MODES_TWIST and ag.true_sup:
            true_discr = partial(frame.logit_y1x_src, n_mc_q=ag.n_mc_q)
        else:
            true_discr = discr
        celossfn = get_ce_or_bce_loss(true_discr, dim_y, ag.reduction)[1]
        lossobj = frame.get_lossfn(ag.n_mc_q,
                                   ag.reduction,
                                   "defl",
                                   weight_da=ag.wda / ag.wgen,
                                   wlogpi=ag.wlogpi / ag.wgen)
        lossfn = add_ce_loss(lossobj, celossfn, ag)
        domdisc, dalossobj = None, None
    return lossfn, domdisc, dalossobj
def main(args: argparse.Namespace):
    """Entry point of the domain-adversarial regression example.

    Builds source/target loaders, a ``Regressor`` on top of a pre-trained
    backbone and a ``DomainDiscriminator`` sharing one SGD optimizer, then
    dispatches on ``args.phase``:

    * any phase other than ``'train'`` first loads the ``'best'`` checkpoint
      into the regressor;
    * ``'analysis'`` -- save a t-SNE plot of source/target features and
      print the A-distance, then return;
    * ``'test'`` -- print the result of ``validate`` on the validation
      loader, then return;
    * otherwise -- train for ``args.epochs`` epochs, saving the latest
      weights every epoch and copying them to ``'best'`` whenever the
      validation MAE improves.

    Args:
        args: parsed command-line options. The function reads ``log``,
            ``phase``, ``seed``, ``resize_size``, ``data``, ``root``,
            ``source``, ``target``, ``batch_size``, ``workers``, ``arch``,
            ``normalization``, ``lr``, ``momentum``, ``wd``, ``lr_gamma``,
            ``lr_decay`` and ``epochs``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code (training and validation use the same transform)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])
    val_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          split='test',
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    if args.normalization == 'IN':
        backbone = convert_model(backbone)
    num_factors = train_source_dataset.num_factors
    bottleneck = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                               nn.Flatten(),
                               nn.Linear(backbone.out_features, 256),
                               nn.ReLU())
    regressor = Regressor(backbone=backbone,
                          num_factors=num_factors,
                          bottleneck=bottleneck,
                          bottleneck_dim=256).to(device)
    print(regressor)
    domain_discri = DomainDiscriminator(in_feature=regressor.features_dim,
                                        hidden_size=1024).to(device)

    # define optimizer and lr scheduler (one optimizer drives both networks)
    optimizer = SGD(regressor.get_parameters() +
                    domain_discri.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # define loss function
    dann = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # fresh loaders with the same settings as the training loaders above
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone,
                                          regressor.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)
        print(mae)
        return

    # start training; lower MAE is better, so start from a large sentinel
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        # get_last_lr() is the public accessor for the current learning
        # rates; get_lr() is meant to be called only from within step() and
        # warns otherwise.
        print("lr", lr_scheduler.get_last_lr())
        train(train_source_iter, train_target_iter, regressor, dann, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))

    logger.close()
def main(args: argparse.Namespace):
    """Train a domain-adversarial image classifier and plot a t-SNE of its features.

    The classifier is trained on a labelled source task and an unlabelled
    target task, validated on the target task after every epoch, and the
    weights with the highest validation accuracy are kept. After training,
    features of the source and target validation images are projected to
    2-D with t-SNE and saved as a scatter plot coloured by domain.

    Args:
        args: Parsed command line options. The attributes read here are
            ``seed``, ``data``, ``root``, ``source``, ``target``,
            ``batch_size``, ``workers``, ``arch``, ``lr``, ``momentum``,
            ``weight_decay`` and ``epochs``; ``args`` is also forwarded
            unchanged to ``train`` and ``validate``.

    Side effects:
        Writes ``best_model.pth.tar`` and ``dann_<source>2<target>.pdf``
        to the current working directory.
    """
    import copy

    seeded = args.seed is not None
    if seeded:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN autotuning may select a different algorithm on every run, which
    # works against the reproducibility requested by seeding, so it is only
    # enabled for unseeded runs.
    cudnn.benchmark = not seeded

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, download=True, transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, download=True, transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # Validation uses the target task with the deterministic transform.
    val_dataset = dataset(root=args.root, task=args.target, download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = ImageClassifier(backbone, train_source_dataset.num_classes).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = StepwiseLR(optimizer, init_lr=args.lr, gamma=0.001, decay_rate=0.75)

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # start training
    best_acc1 = 0.
    # state_dict() hands back tensors that share storage with the live
    # parameters; deep-copy so the snapshot is not overwritten by later
    # optimizer steps.
    best_model = copy.deepcopy(classifier.state_dict())
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = copy.deepcopy(classifier.state_dict())
            torch.save(best_model, 'best_model.pth.tar')
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # visualize the results using T-SNE
    classifier.load_state_dict(best_model)
    classifier.eval()

    source_val_dataset = dataset(root=args.root, task=args.source, download=True, transform=val_transform)
    source_val_loader = DataLoader(source_val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Source batches are collected first, then target batches; the domain
    # indicator built below relies on this order (neither loader shuffles).
    feature_chunks = []
    with torch.no_grad():
        for loader in (source_val_loader, val_loader):
            for images, _ in loader:
                # The second value returned by the classifier is used as
                # the feature representation of the batch.
                _, f = classifier(images.to(device))
                feature_chunks.append(f.cpu().numpy())

    # 1 marks a source sample, 0 a target sample.
    domains = np.concatenate((np.ones(len(source_val_dataset)), np.zeros(len(val_dataset))))
    features = np.concatenate(feature_chunks).astype(np.float64)
    print("source:", len(source_val_dataset), "target:", len(val_dataset))
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)
    plt.figure(figsize=(10, 10))
    plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap(["r", "b"]), s=2)
    plt.savefig('{}_{}2{}.pdf'.format("dann", args.source, args.target))
Beispiel #8
0
                           tar_val_loader,
                           5,
                           device,
                           type='tar')
            global_logger.info(f'test on test set nmi: {nmi}')
            if nmi > best_nmi:
                global_logger.info(f"save best model to {model_dir}")
                torch.save(backbone.state_dict(),
                           os.path.join(model_dir, 'minst_best_model.pth'))
                best_nmi = nmi


if __name__ == "__main__":
    # setup model
    backbone = FeatureExtractor().to(device)
    domain_discri = DomainDiscriminator(in_feature=128,
                                        hidden_size=256).to(device)
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # TODO feature reconstruction loss
    # TODO feature transfer module

    src_domain_class = [0, 1, 2, 3, 4]
    tar_domain_class = [5, 6, 7, 8, 9]
    # setup dataloader
    src_train_loader = get_mnist_m_loader(
        dataset_root='./dataset/MNIST-M',
        label_filter=lambda x: x in src_domain_class,
        sample_per_cls=sample_num_per_cls,
        cls_num=cls_num)
    tar_train_loader = get_mnist_loader(
        dataset_root='./dataset/MNIST',