Example #1
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    """Train the classifier and domain discriminator adversarially for one epoch.

    Each iteration alternates two steps:
      1. update ``model`` while ``domain_discri`` is frozen, minimizing the
         source classification loss plus a transfer loss computed with
         *swapped* domain labels so the features learn to fool the
         discriminator;
      2. update ``domain_discri`` while ``model`` is frozen, using the
         correct domain labels on the detached features.

    Args:
        train_source_iter: endless iterator yielding labelled source batches.
        train_target_iter: endless iterator yielding target batches
            (target labels are discarded).
        model: classifier returning ``(logits, features)`` for a batch.
        domain_discri: discriminator mapping features to a domain score.
        domain_adv: adversarial loss over discriminator outputs, keyed by
            the domain name ``'source'`` / ``'target'``.
        gl: layer applied to features before the discriminator in step 1
            only (assumed to be a gradient-scaling layer — TODO confirm
            with the caller).
        optimizer: SGD optimizer for ``model``.
        lr_scheduler: LR schedule stepped once per iteration for ``model``.
        optimizer_d: SGD optimizer for ``domain_discri``.
        lr_scheduler_d: LR schedule stepped once per iteration for the
            discriminator.
        epoch: current epoch index (used for progress display only).
        args: expects ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    # Running statistics shown by the progress meter.
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        cls_accs, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)  # target labels are unused

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step 1: train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        # Forward source and target through the model as a single batch,
        # then split logits back into source/target halves.
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)

        # Adversarial training to fool the discriminator: domain labels are
        # deliberately swapped ('target' for source features and vice versa).
        d = domain_discri(gl(f))
        d_s, d_t = d.chunk(2, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') +
                               domain_adv(d_t, 'source'))

        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()

        # Step 2: train the discriminator on detached features with the
        # correct domain labels. NOTE(review): features are NOT passed
        # through ``gl`` here, unlike step 1 — confirm this is intended.
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        d = domain_discri(f.detach())
        d_s, d_t = d.chunk(2, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') +
                                    domain_adv(d_t, 'target'))

        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        # Bookkeeping, weighted by the source batch size.
        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # Discriminator accuracy: source features labelled 1, target 0.
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
          netD_T, criterion_gan, criterion_cycle, criterion_identity,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch: int,
          visualize, args: argparse.Namespace):
    """Train a CycleGAN-style translation model for one epoch.

    Per iteration: compute fake/reconstructed images in both directions,
    update both generators (GAN + cycle + identity losses) with the
    discriminators frozen, then update both discriminators against
    image-pool samples of past fakes.

    Args:
        train_source_iter: endless iterator over source-domain batches.
        train_target_iter: endless iterator over target-domain batches.
        netG_S2T: generator translating source -> target.
        netG_T2S: generator translating target -> source.
        netD_S: discriminator for the source domain.
        netD_T: discriminator for the target domain.
        criterion_gan: adversarial loss; called with logits and a
            ``real=True/False`` target flag.
        criterion_cycle: cycle-consistency loss (reconstruction vs. real).
        criterion_identity: identity-mapping loss.
        optimizer_G: joint optimizer for both generators.
        optimizer_D: joint optimizer for both discriminators.
        fake_S_pool: image pool buffering past fake source images.
        fake_T_pool: image pool buffering past fake target images.
        epoch: current epoch index (used for progress display only).
        visualize: callback saving a single image tensor under a name.
        args: expects ``iters_per_epoch``, ``trade_off_cycle``,
            ``trade_off_identity`` and ``print_freq``.
    """
    # Running statistics shown by the progress meter.
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients while the generators train
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)

        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(B))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T,
                                             real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S,
                                             real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator (gradients re-enabled)
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S, sampling fakes from the
        # image pool (detached so generator gradients are not recomputed)
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # update statistics (weighted by the source batch size)
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # dump the first image of every intermediate tensor for inspection
            for tensor, name in zip([
                    real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S,
                    identity_T
            ], [
                    "real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                    "identity_S", "identity_T"
            ]):
                visualize(tensor[0], "{}_{}".format(i, name))
Example #3
0
 def __init__(self, model, alpha):
     """Hold a student model plus a frozen deep-copied teacher.

     Args:
         model: student network; a deep copy of it becomes the teacher.
         alpha: smoothing coefficient stored for later use (presumably an
             EMA momentum for teacher updates — confirm with the caller).
     """
     self.alpha = alpha
     self.model = model
     # The teacher starts as an exact copy and never receives gradients.
     self.teacher = copy.deepcopy(model)
     set_requires_grad(self.teacher, False)
Example #4
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, netG_S2T, netG_T2S, netD_S,
          netD_T, siamese_net: spgan.SiameseNetwork,
          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss,
          criterion_cycle: nn.L1Loss, criterion_identity: nn.L1Loss,
          criterion_contrastive: spgan.ContrastiveLoss, optimizer_G: Adam,
          optimizer_D: Adam, optimizer_siamese: Adam, fake_S_pool: ImagePool,
          fake_T_pool: ImagePool, epoch: int, visualize,
          args: argparse.Namespace):
    """Train an SPGAN model (CycleGAN + siamese contrastive loss) for one epoch.

    Per iteration:
      * generators are updated every second iteration (``i % 2 == 0``) with
        GAN + cycle + identity losses, plus a contrastive loss once
        ``epoch > 1``;
      * the siamese network is updated once ``epoch > 0``, on detached fakes;
      * both discriminators are updated every iteration against image-pool
        samples of past fakes.

    Fix vs. the original: the two contrastive-loss meter updates now pass
    ``.item()`` (a Python float) instead of the live loss tensor, matching
    every other meter update here and avoiding retention of the autograd
    graph inside the meters.

    Args:
        train_source_iter: endless iterator over source batches (image first).
        train_target_iter: endless iterator over target batches (image first).
        netG_S2T / netG_T2S: generators for the two translation directions.
        netD_S / netD_T: per-domain discriminators.
        siamese_net: embedding network for the contrastive loss.
        criterion_gan: least-squares GAN loss (logits, real-flag).
        criterion_cycle: cycle-consistency L1 loss.
        criterion_identity: identity-mapping L1 loss.
        criterion_contrastive: contrastive loss over embedding pairs
            (label 0 = positive pair, 1 = negative pair).
        optimizer_G / optimizer_D / optimizer_siamese: per-group optimizers.
        fake_S_pool / fake_T_pool: image pools buffering past fakes.
        epoch: current epoch index; gates the contrastive/siamese phases.
        visualize: callback saving a single image tensor under a name.
        args: expects ``iters_per_epoch``, ``trade_off_cycle``,
            ``trade_off_identity``, ``trade_off_contrastive``, ``print_freq``.
    """
    # Running statistics shown by the progress meter.
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')
    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T, losses_contrastive_G, losses_contrastive_siamese
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _, _, _ = next(train_source_iter)
        real_T, _, _, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ===============================================
        # train the generators (every two iterations)
        # ===============================================
        if i % 2 == 0:
            # save memory: no gradients through the frozen networks
            set_requires_grad(netD_S, False)
            set_requires_grad(netD_T, False)
            set_requires_grad(siamese_net, False)
            # GAN loss D_T(G_S2T(S))
            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
            # GAN loss D_S(G_T2S(B))
            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
            # Cycle loss || G_T2S(G_S2T(S)) - S||
            loss_cycle_S = criterion_cycle(rec_S,
                                           real_S) * args.trade_off_cycle
            # Cycle loss || G_S2T(G_T2S(T)) - T||
            loss_cycle_T = criterion_cycle(rec_T,
                                           real_T) * args.trade_off_cycle
            # Identity loss
            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
            identity_T = netG_S2T(real_T)
            loss_identity_T = criterion_identity(
                identity_T, real_T) * args.trade_off_identity
            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
            identity_S = netG_T2S(real_S)
            loss_identity_S = criterion_identity(
                identity_S, real_S) * args.trade_off_identity

            # siamese network output (frozen; gradients flow to generators)
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T)
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S)

            # positive pair: an image and its translation should embed close
            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                   criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair: unrelated images should embed far apart
            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \
                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \
                                   criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_G = (
                loss_contrastive_p_G +
                0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive

            # combined loss and calculate gradients; the contrastive term is
            # only enabled after the siamese network has warmed up
            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
            if epoch > 1:
                loss_G += loss_contrastive_G
            netG_S2T.zero_grad()
            netG_T2S.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update corresponding statistics
            losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
            losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
            losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
            losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
            losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
            losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
            if epoch > 1:
                # FIX: record a float, not the live tensor (was holding the graph)
                losses_contrastive_G.update(loss_contrastive_G.item(),
                                            real_S.size(0))

        # ===============================================
        # train the siamese network (when epoch > 0)
        # ===============================================
        if epoch > 0:
            set_requires_grad(siamese_net, True)
            # siamese network output; fakes are detached so only the siamese
            # network receives gradients here
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T.detach())
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S.detach())

            # positive pair
            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                         criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_siamese = criterion_contrastive(
                f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_siamese = (loss_contrastive_p_siamese +
                                        2 * loss_contrastive_n_siamese) / 3

            # update siamese network
            siamese_net.zero_grad()
            loss_contrastive_siamese.backward()
            optimizer_siamese.step()

            # update corresponding statistics
            # FIX: record a float, not the live tensor (was holding the graph)
            losses_contrastive_siamese.update(loss_contrastive_siamese.item(),
                                              real_S.size(0))

        # ===============================================
        # train the discriminators
        # ===============================================

        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))

        # update discriminators
        netD_S.zero_grad()
        netD_T.zero_grad()
        loss_D_S.backward()
        loss_D_T.backward()
        optimizer_D.step()

        # update corresponding statistics
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # NOTE(review): identity_S/identity_T are only refreshed on even
            # iterations; odd print iterations show the previous even one.
            for tensor, name in zip([
                    real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S,
                    identity_T
            ], [
                    "real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                    "identity_S", "identity_T"
            ]):
                visualize(tensor[0], "{}_{}".format(i, name))
Example #5
0
def main(args: argparse.Namespace):
    """Entry point: source pretraining, adversarial adaptation, analysis/test.

    Behavior is selected by ``args.phase``:
      * ``'train'``    -- optionally pretrain a classifier on the source
                          domain, then adversarially adapt a copied target
                          classifier, validating each epoch and keeping the
                          best checkpoint; finishes with a test-set run.
      * ``'test'``     -- load the best checkpoint and evaluate on the test set.
      * ``'analysis'`` -- extract features from both domains, save a t-SNE
                          plot and print the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # Seeding makes CUDNN deterministic (slower, and restart-sensitive).
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators so training can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    source_classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                        pool_layer=pool_layer, finetune=not args.scratch).to(device)

    if args.phase == 'train' and args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)
        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum,
                                 weight_decay=args.weight_decay, nesterov=True)
        # NOTE(review): the lambda multiplies by args.pretrain_lr on top of
        # the optimizer's base lr -- confirm get_parameters() normalizes the
        # per-group base lr, otherwise the rate is scaled twice.
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (
                                             -args.lr_decay))
        # start pretraining
        for epoch in range(args.pretrain_epochs):
            print("lr:", pretrain_lr_scheduler.get_lr())
            # pretrain for one epoch
            utils.pretrain(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args,
                           device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    # Initialize both classifiers from the (pre)trained source weights.
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    source_classifier.load_state_dict(checkpoint)
    target_classifier = copy.deepcopy(source_classifier)

    # freeze source classifier (it serves as a fixed reference during adaptation)
    set_requires_grad(source_classifier, False)
    source_classifier.freeze_bn()

    domain_discri = DomainDiscriminator(in_feature=source_classifier.features_dim, hidden_size=1024).to(device)

    # define loss function
    grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=2., max_iters=1000, auto_step=True)
    domain_adv = DomainAdversarialLoss(domain_discri, grl=grl).to(device)

    # define optimizer and lr scheduler
    # note that we only optimize target feature extractor
    optimizer = SGD(target_classifier.get_parameters(optimize_head=False) + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        target_classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(target_classifier.backbone, target_classifier.pool_layer,
                                          target_classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, target_classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, source_classifier, target_classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, target_classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(target_classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    target_classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, target_classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()