コード例 #1
0
def main(args):
    """Build the data pipeline, networks and optimizers, then train or translate.

    The function constructs two generators (``netG_S2T``, ``netG_T2S``) and two
    discriminators (``netD_S``, ``netD_T``), optionally restores them from
    ``args.resume``, and then either

    * (``args.phase == 'test'``) calls ``translate`` on the source training
      dataset with a transform wrapping ``netG_S2T`` and returns, or
    * runs ``train`` for every epoch in
      ``range(args.start_epoch, args.epochs + args.epochs_decay)``, stepping
      both LR schedulers and saving a checkpoint after each epoch, and finally
      calls ``translate`` when ``args.translated_root`` is not ``None``.

    Args:
        args: parsed command-line namespace. Attributes read here: ``log``,
            ``phase``, ``seed``, ``train_size``, ``resize_ratio``, ``source``,
            ``source_root``, ``target``, ``target_root``, ``batch_size``,
            ``workers``, ``netG``, ``ngf``, ``netD``, ``ndf``, ``norm``,
            ``pool_size``, ``lr``, ``beta1``, ``epochs``, ``epochs_decay``,
            ``resume``, ``start_epoch``, ``test_input_size`` and
            ``translated_root``. ``args.start_epoch`` is overwritten when
            resuming. The namespace is also forwarded to ``train``.

    Relies on a module-level ``device`` being defined.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even when a seed was given;
    # confirm this is intended, since cuDNN autotuning may pick different
    # algorithms from run to run.
    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))

    def lr_decay_function(epoch):
        """LR multiplier: 1.0 up to ``args.epochs``, then a linear decrease.

        The factor would reach 0 at ``args.epochs + args.epochs_decay``.
        """
        if args.epochs_decay == 0:
            # No decay phase configured: keep the LR constant instead of
            # dividing by zero.
            return 1.0
        return 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)

    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        # NOTE(review): checkpoints written below also contain the pickled
        # ``args`` namespace, so only load files from a trusted source.
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        # The checkpoint stores the last finished epoch; continue with the next.
        args.start_epoch = checkpoint['epoch'] + 1

    def translate_source():
        """Call ``translate`` on the source training set using ``netG_S2T``."""
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    if args.phase == 'test':
        translate_source()
        # Close the logger on this early exit as well, mirroring the
        # training path below.
        logger.close()
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose(
        [Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ToPILImage()])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        # get_last_lr() reports the LR currently in use without the
        # deprecation warning that get_lr() emits outside of step().
        print(lr_scheduler_G.get_last_lr())

        # train for one epoch
        # NOTE(review): make sure the ``train`` in scope accepts exactly this
        # argument list.
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
              netD_T, criterion_gan, criterion_cycle, criterion_identity,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch,
              visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        translate_source()

    logger.close()
コード例 #2
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, netG_S2T, netG_T2S, netD_S,
          netD_T, siamese_net: spgan.SiameseNetwork,
          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss,
          criterion_cycle: nn.L1Loss, criterion_identity: nn.L1Loss,
          criterion_contrastive: spgan.ContrastiveLoss, optimizer_G: Adam,
          optimizer_D: Adam, optimizer_siamese: Adam, fake_S_pool: ImagePool,
          fake_T_pool: ImagePool, epoch: int, visualize,
          args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of training.

    Per iteration:

    * the generators are updated on even iterations only; the contrastive
      term is added to their loss only when ``epoch > 1``;
    * the siamese network is updated when ``epoch > 0``;
    * both discriminators are updated every iteration, using fakes drawn
      from the image pools.

    Every ``args.print_freq`` iterations the meters are displayed and the first
    image of each batch is passed to ``visualize``.

    Args:
        train_source_iter: iterator whose ``next()`` yields a 4-tuple; only the
            first element (the source image batch) is used.
        train_target_iter: same, for the target domain.
        netG_S2T, netG_T2S: source-to-target / target-to-source generators.
        netD_S, netD_T: discriminators for the source / target domain.
        siamese_net: network producing the features fed to
            ``criterion_contrastive``.
        criterion_gan: adversarial loss, called as ``criterion_gan(pred, real)``.
        criterion_cycle: cycle-consistency loss.
        criterion_identity: identity-mapping loss.
        criterion_contrastive: called as ``criterion_contrastive(a, b, label)``
            with label 0 for positive pairs and 1 for negative pairs.
        optimizer_G, optimizer_D, optimizer_siamese: optimizers for the
            generators, the discriminators and the siamese network.
        fake_S_pool, fake_T_pool: pools queried for the fakes shown to the
            discriminators.
        epoch: current epoch index; gates the contrastive/siamese training.
        visualize: callable ``visualize(image, name)``.
        args: namespace; reads ``iters_per_epoch``, ``trade_off_cycle``,
            ``trade_off_identity``, ``trade_off_contrastive``, ``print_freq``.

    Relies on a module-level ``device`` being defined.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')
    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T, losses_contrastive_G, losses_contrastive_siamese
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _, _, _ = next(train_source_iter)
        real_T, _, _, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)
        # All meters are weighted by the source batch size.
        batch_size = real_S.size(0)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ===============================================
        # train the generators (every two iterations)
        # ===============================================
        if i % 2 == 0:
            # save memory: no gradients are needed for these networks while
            # the generators are being updated
            set_requires_grad(netD_S, False)
            set_requires_grad(netD_T, False)
            set_requires_grad(siamese_net, False)
            # GAN loss D_T(G_S2T(S))
            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
            # GAN loss D_S(G_T2S(T))
            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
            # Cycle loss || G_T2S(G_S2T(S)) - S||
            loss_cycle_S = criterion_cycle(rec_S,
                                           real_S) * args.trade_off_cycle
            # Cycle loss || G_S2T(G_T2S(T)) - T||
            loss_cycle_T = criterion_cycle(rec_T,
                                           real_T) * args.trade_off_cycle
            # Identity loss
            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
            identity_T = netG_S2T(real_T)
            loss_identity_T = criterion_identity(
                identity_T, real_T) * args.trade_off_identity
            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
            identity_S = netG_T2S(real_S)
            loss_identity_S = criterion_identity(
                identity_S, real_S) * args.trade_off_identity

            # siamese network output
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T)
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S)

            # positive pair
            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                   criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \
                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \
                                   criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_G = (
                loss_contrastive_p_G +
                0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive

            # combined loss and calculate gradients
            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
            if epoch > 1:
                loss_G += loss_contrastive_G
            netG_S2T.zero_grad()
            netG_T2S.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update corresponding statistics
            losses_G_S2T.update(loss_G_S2T.item(), batch_size)
            losses_G_T2S.update(loss_G_T2S.item(), batch_size)
            losses_cycle_S.update(loss_cycle_S.item(), batch_size)
            losses_cycle_T.update(loss_cycle_T.item(), batch_size)
            losses_identity_S.update(loss_identity_S.item(), batch_size)
            losses_identity_T.update(loss_identity_T.item(), batch_size)
            if epoch > 1:
                # Record a plain float, like the other meters, so the meter
                # does not keep a reference to the loss tensor.
                losses_contrastive_G.update(loss_contrastive_G.item(),
                                            batch_size)

        # ===============================================
        # train the siamese network (when epoch > 0)
        # ===============================================
        if epoch > 0:
            set_requires_grad(siamese_net, True)
            # siamese network output; fakes are detached so that no gradient
            # flows back into the generators
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T.detach())
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S.detach())

            # positive pair
            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                         criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_siamese = criterion_contrastive(
                f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_siamese = (loss_contrastive_p_siamese +
                                        2 * loss_contrastive_n_siamese) / 3

            # update siamese network
            siamese_net.zero_grad()
            loss_contrastive_siamese.backward()
            optimizer_siamese.step()

            # update corresponding statistics (plain float, see above)
            losses_contrastive_siamese.update(loss_contrastive_siamese.item(),
                                              batch_size)

        # ===============================================
        # train the discriminators
        # ===============================================

        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))

        # update discriminators
        netD_S.zero_grad()
        netD_T.zero_grad()
        loss_D_S.backward()
        loss_D_T.backward()
        optimizer_D.step()

        # update corresponding statistics
        losses_D_S.update(loss_D_S.item(), batch_size)
        losses_D_T.update(loss_D_T.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # NOTE(review): identity_S / identity_T are only recomputed on even
            # iterations, so when this runs on an odd ``i`` they come from the
            # previous iteration's batch.
            named_images = [
                ("real_S", real_S), ("real_T", real_T),
                ("fake_S", fake_S), ("fake_T", fake_T),
                ("rec_S", rec_S), ("rec_T", rec_T),
                ("identity_S", identity_S), ("identity_T", identity_T),
            ]
            for name, tensor in named_images:
                # Only the first image of each batch is saved.
                visualize(tensor[0], "{}_{}".format(i, name))