Example #1
0
def main():
    """Train a conditional DCGAN on CIFAR-10.

    Parses the training options, builds the CIFAR-10 loader, creates or
    resumes the model, then runs the epoch/iteration loop, writing
    display images and checkpoints through the Saver.
    """
    # command-line configuration
    opts = TrainOptions().parse()

    # dataset and loader
    print('\n--- load dataset ---')
    os.makedirs(opts.dataroot, exist_ok=True)

    preprocess = transforms.Compose([
        transforms.Resize(opts.img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.CIFAR10(
        opts.dataroot, train=True, download=True, transform=preprocess)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=opts.batch_size, shuffle=True,
        num_workers=opts.nThreads)

    # model: fresh initialization or checkpoint resume
    print('\n--- load model ---')
    model = CDCGAN(opts)
    model.setgpu(opts.gpu)
    if opts.resume is not None:
        ep0, total_it = model.resume(opts.resume)
    else:
        model.initialize()
        ep0, total_it = -1, 0
    ep0 += 1
    print('start the training at epoch %d'%(ep0))

    # saver handles display files, result images and checkpoints
    saver = Saver(opts)

    # training loop
    print('\n--- train ---')
    max_it = 200000  # hard cap on total iterations
    for ep in range(ep0, opts.n_ep):
        for it, (images, label) in enumerate(train_loader):
            # skip a ragged final batch so tensor shapes stay fixed
            if images.size(0) != opts.batch_size:
                continue
            images = images.cuda(opts.gpu).detach()

            # one discriminator step followed by one generator step
            model.update_D(images, label)
            model.update_G()

            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, max_it, model)
                break

        # per-epoch snapshot of images and weights
        saver.write_img(ep, model)
        saver.write_model(ep, total_it, model)
    return
Example #2
0
def main():
    """Entry point: select the CUDA device and run the trainer."""
    args = TrainOptions().parse()
    # NOTE(review): mode is forced to 'train' here, which makes the
    # 'verify-data' branch below unreachable -- confirm this is intended.
    args.mode = 'train'

    # Print CUDA version.
    print("Running code using CUDA {}".format(torch.version.cuda))
    # last character of args.device is assumed to be the GPU index
    device_index = int(args.device[-1])
    torch.cuda.set_device(device_index)
    print('Training on device cuda:{}'.format(device_index))

    trainer = Trainer(args)

    if args.mode == 'verify-data':
        trainer.verify_data()
    elif args.mode == 'train':
        trainer.train()
Example #3
0
def main():
    """Build train/val loaders and a model, then run the estimator
    matching the selected model type ('pix2pix' or 'cyclegan')."""
    logger = logging.getLogger('pixsty')

    from options import TrainOptions
    opt_parser = TrainOptions()
    opt_parser.parser.add_argument('--subjects', type=str, nargs='+')
    args = opt_parser.parse()

    logger.info('Create dataset')
    train_loader = CustomDataLoader(args, phase='train')
    val_loader = CustomDataLoader(args, phase='val')
    print('training images = %d' % len(train_loader.dataset))
    print('validation images = %d' % len(val_loader.dataset))

    print('===> Build model')
    models = create_model(args)

    # dispatch on model type; unknown values raise KeyError, matching
    # the behavior of a dict lookup
    if args.model == 'pix2pix':
        estimator_fn = core.training_estimator(models, args)
    elif args.model == 'cyclegan':
        estimator_fn = core.cyclegan_estimator(models, args)
    else:
        raise KeyError(args.model)
    estimator_fn(train_loader, val_loader, epochs=args.niter)
Example #4
0
def main():
    """Train the PSGAN model for sketch-based synthesis or editing.

    ``opts.model_task`` selects the task: 'SYN' (synthesize image I from
    sketch S) or 'EDT' (edit with sketch S, image I and mask M).
    Generators are trained progressively from level 1 (64x64) up to
    ``opts.max_level`` (256x256), optionally adapting to a pretrained
    netF and using a VGG19-based perceptual loss.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    if opts.model_task != 'SYN' and opts.model_task != 'EDT':
        print('%s:unsupported task!'%opts.model_task)
        return
    SYN = opts.model_task == 'SYN'

    # create model
    print('--- create model ---')
    G_channels = 3 if SYN else 7 # SYN: (S), EDT:(S,I,M)
    D_channels = 6 if SYN else 7 # SYN: (S,I), EDT:(S,I,M)
    # (img_size, max_level) should be (256, 3), (128, 2) or (64, 1)
    netG = PSGAN(G_channels, opts.G_nlayers, opts.G_nf, D_channels, opts.D_nf,
                 opts.D_nlayers, opts.max_dilate, opts.max_level, opts.img_size, opts.gpu!=0)
    if opts.gpu:
        netG.cuda()
    netG.init_networks(weights_init)
    netG.train()
    device = None if opts.gpu else torch.device('cpu')

    # to adapt G to the pretrained F
    if opts.use_F_level in [1,2,3]:
        # you could change the setting of netF based on your own pretraining setting
        netF_Norm = 'None' if SYN else 'BN'
        if opts.use_F_level == 1:
            netF = Pix2pix64(in_channels = G_channels, nef=64, useNorm=netF_Norm)
        elif opts.use_F_level == 2:
            netF = Pix2pix128(in_channels = G_channels, nef=64, useNorm=netF_Norm)
        else:
            netF = Pix2pix256(in_channels = G_channels, nef=64, useNorm=netF_Norm)
        netF.load_state_dict(torch.load(opts.load_F_name, map_location=device))
        # netF is a frozen reference network: no gradients needed
        for param in netF.parameters():
            param.requires_grad = False
        if opts.gpu:
            netF = netF.cuda()

    # for perceptual loss (frozen VGG19 feature extractor)
    VGGNet = models.vgg19(pretrained=True).features
    VGGfeatures = VGGFeature(VGGNet, opts.gpu)
    for param in VGGfeatures.parameters():
        param.requires_grad = False
    if opts.gpu:
        VGGfeatures.cuda()

    print('--- training ---')
    dataset = dset.ImageFolder(root=opts.train_path,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

    # sampling refinement level l from l_sample_range
    # we found that r=0.5 should be trained to discriminate the results from those under r=0.0
    # BUGFIX: in Python 3, range() is a lazy sequence and cannot be
    # concatenated with a list; wrap it in list() before appending [0.5].
    l_sample_range = [1.*r/(opts.max_dilate-1) for r in list(range(opts.max_dilate))+[0.5]]

    # We train 64*64 image (max_level = 1) by directly training netG.G64 on 64*64.
    # We train 256*256 images (max_level = 3) by first training netG.G64 on 64*64,
    # then fixing netG.G64 and training netG.G128 on 128*128,
    # finally fixing netG.G128 and training netG.G256 on 256*256, like pix2pixHD

    # NOTE: using large batch size at level1,2 (small image resolution) could improve results
    # However, this could make CUDA out of memory when training goes to level3.
    # Because the memory used in level1,2 is not released.
    # Solution is to call train.py three times,
    # each time only training one level and saving the model parameters to load for the next level
    # for example,
    # 'for level in range(1,1+opts.max_level)' should be changed to 'for level in [cur_level]'
    # and when cur_level=2, adding 'netG.G64.load_state_dict(torch.load(saved model at level 1))'
    # when cur_level=3, adding 'netG.G128.load_state_dict(torch.load(saved model at level 2))'
    batchsize_level = [opts.batchsize_level1, opts.batchsize_level2, opts.batchsize_level3]
    # progressively training from level1 to max_level
    for level in range(1,1+opts.max_level):
        print('--- Training at level %d. Image resolution: %d ---' % (level, 2**(5+level)))
        # fix the model parameters of the already-trained lower levels
        if level in [2,3]:
            for param in netG.G64.parameters():
                param.requires_grad = False
        if level in [3]:
            for param in netG.G128.parameters():
                param.requires_grad = False
        dataloader = DataLoader(dataset, batch_size=batchsize_level[level-1], shuffle=True, num_workers=4, drop_last=True)
        print_step = int(len(dataloader) / 10.0)
        if not SYN:
            # generate random 2**(13-level) masks for sampling
            masks_num = 2**(13-level)
            all_masks = get_mask(masks_num, opts.img_size)
            all_masks = to_var(all_masks) if opts.gpu else all_masks
        # main iteration
        for epoch in range(opts.epoch_pre+opts.epoch):
            for i, data in enumerate(dataloader):
                # during pretraining epochs, we only train on the max refinement level l = 1.0
                # then we will train on random l in [0,1]
                l = 1.0 if epoch < opts.epoch_pre else random.choice(l_sample_range)
                data = to_var(data[0]) if opts.gpu else data[0]
                # if input image is arranged as (S,I) use AtoB = True
                # if input image is arranged as (I,S) use AtoB = False
                if opts.AtoB:
                    S = data[:,:,:,0:opts.img_size]
                    I = data[:,:,:,opts.img_size:opts.img_size*2]
                else:
                    S = data[:,:,:,opts.img_size:opts.img_size*2]
                    I = data[:,:,:,0:opts.img_size]
                # apply netF loss, this will drastically increase the CUDA memory usage
                netF_level = netF if level == opts.use_F_level else None
                if SYN:
                    losses = netG.synthesis_one_pass(S, I, l, level, VGGfeatures, netF=netF_level)
                else:
                    M = all_masks[torch.randint(masks_num, (batchsize_level[level-1],))]
                    losses = netG.editing_one_pass(S, I, M, l, level, VGGfeatures, netF=netF_level)

                if i % print_step == 0:
                    print('Epoch [%03d/%03d][%04d/%04d]' %(epoch+1, opts.epoch_pre+opts.epoch, i+1,
                                                                       len(dataloader)), end=': ')
                    print('l: %+.3f, LD: %+.3f, LGadv: %+.3f, Lperc: %+.3f, Lrec: %+.3f'%
                          (l, losses[0], losses[1], losses[2], losses[3]))
        print('--- Saving model at level %d ---' % level)
        netG.save_model(opts.save_model_path, opts.save_model_name, level=[level])
Example #5
0
def main():
    """Train the UID unpaired image translation model.

    Parses options, builds the unpaired dataset/loader, creates (or
    resumes) the model, then runs the epoch/iteration loop, periodically
    logging losses and writing result images and checkpoints.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model: fresh initialization or checkpoint resume
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        # resume returns (last finished epoch, total iterations so far)
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000  # hard cap on total iterations
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip ragged final batches so both domains keep a fixed size
            if images_a.size(0) != opts.batch_size or images_b.size(
                    0) != opts.batch_size:
                continue
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model: D every iteration; E/G only every second
            # iteration (and on the last batch of the epoch)
            model.update_D(images_a, images_b)
            if (it + 1) % 2 != 0 and it != len(train_loader) - 1:
                continue
            model.update_EG()

            # periodic console logging (only on iterations that ran EG)
            if (it + 1) % 48 == 0:
                print('total_it: %d (ep %d, it %d), lr %08f' %
                      (total_it + 1, ep, it + 1,
                       model.gen_opt.param_groups[0]['lr']))
                print(
                    'Dis_I_loss: %04f, Dis_B_loss %04f, GAN_loss_I %04f, GAN_loss_B %04f'
                    % (model.disA_loss, model.disB_loss, model.gan_loss_i,
                       model.gan_loss_b))
                print('B_percp_loss %04f, Recon_II_loss %04f' %
                      (model.B_percp_loss, model.l1_recon_II_loss))
            if (it + 1) % 200 == 0:
                saver.write_img(ep * len(train_loader) + (it + 1), model)

            # total_it only counts iterations that reached update_EG
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                # NOTE(review): 2-arg call here vs. 3-arg call at epoch end
                # below -- confirm Saver.write_model's signature
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # Save network weights
        saver.write_model(ep, total_it + 1, model)

    return
Example #6
0
def main():
    """Train the DRIT disentangled image-to-image translation model.

    Builds an unpaired (optionally multi-modal) dataset, creates or
    resumes the model, and alternates content-discriminator updates with
    full D/E/G updates inside the epoch loop.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')

    if opts.multi_modal:
        dataset = dataset_unpair_multi(opts)
    else:
        dataset = dataset_unpair(opts)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model: fresh initialization or checkpoint resume
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        # resume returns (last finished epoch, total iterations so far)
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000  # hard cap on total iterations
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip ragged final batches so both domains keep a fixed size
            if images_a.size(0) != opts.batch_size or images_b.size(
                    0) != opts.batch_size:
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update model: most iterations train only the content
            # discriminator; every opts.d_iter-th iteration (and the last
            # couple of batches of the epoch) runs the full D + E/G update
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img and not opts.multi_modal:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # saver.write_img(-1, model)
                # NOTE(review): 2-arg call here vs. 3-arg call at epoch end
                # below -- confirm Saver.write_model's signature
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        if not opts.multi_modal:
            saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
Example #7
0
def main():
    """Train TET-GAN for text-effects transfer.

    Depending on ``opts.progressive``, either trains progressively
    (level 1: 64x64 -> level 2: 128x128 -> level 3: 256x256) or directly
    on level 3, then saves the final network weights.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # training hyper-parameters
    print('--- load parameter ---')
    outer_iter = opts.outer_iter
    # number of iterations over which a new level is faded in (at least 1)
    fade_iter = max(1.0, float(outer_iter / 2))
    epochs = opts.epoch
    batchsize = opts.batchsize
    datasize = opts.datasize
    datarange = opts.datarange
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)
    tetGAN.train()

    print('--- training ---')
    stylenames = os.listdir(opts.train_path)
    print('List of %d styles:' % (len(stylenames)), *stylenames, sep=' ')

    if opts.progressive == 1:
        # progressive training. From level1 64*64, to level2 128*128, to level3 256*256
        # level 1: smallest images, so 4x the batch size
        for i in range(outer_iter):
            # jitter ramps 0 -> 1 over the first fade_iter iterations
            jitter = min(1.0, i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize * 4,
                                               datarange, datasize * 2)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 1, jitter,
                                                 centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0],
                                             None, 1, None)
                print('Level1, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        # level 2: fade weight w decays 1 -> 0 over fade_iter iterations
        for i in range(outer_iter):
            w = max(0.0, 1 - i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize * 2,
                                               datarange, datasize * 2)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 2, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], x[1], y[0], y[1], y_real[0],
                                             y_real[1], 2, w)
                print('Level2, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        # level 3: full resolution, base batch size
        for i in range(outer_iter):
            w = max(0.0, 1 - i / fade_iter)
            fnames = load_trainset_batchfnames(opts.train_path, batchsize,
                                               datarange, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], x[1], y[0], y[1], y_real[0],
                                             y_real[1], 3, w)
                print('Level3, Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
    else:
        # directly train on level3 256*256 (no fading, w = 0)
        for i in range(outer_iter):
            fnames = load_trainset_batchfnames(opts.train_path, batchsize,
                                               datarange, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 1, centercropratio,
                                                 augementratio, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0],
                                             None, 3, 0)
                print('Iter[%d/%d], Epoch [%d/%d]' %
                      (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))

    # persist final weights
    print('--- save ---')
    torch.save(tetGAN.state_dict(), opts.save_model_name)
Example #8
0
def main():
    """Train the SAVI2I multi-domain image translation model.

    Builds the multi-domain dataset, creates or resumes the model, and
    alternates content-discriminator updates with full D/E/F/G updates.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_multi(opts)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model: fresh initialization or checkpoint resume
    print('\n--- load model ---')
    model = SAVI2I(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        # resume returns (last finished epoch, total iterations so far)
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 1000000  # hard cap on total iterations
    for ep in range(ep0, opts.n_ep):
        for it, (images, c_org, c_org_mask,
                 c_org_id) in enumerate(train_loader):
            # input data: each item arrives as a sequence of tensors
            # (presumably one per domain -- confirm against dataset_multi);
            # concatenate along the batch dimension and move to the GPU
            images = torch.cat(images, dim=0)
            images = images.cuda(opts.gpu).detach()
            c_org = torch.cat(c_org, dim=0)
            c_org = c_org.cuda(opts.gpu).detach()
            c_org_mask = torch.cat(c_org_mask, dim=0)
            c_org_mask = c_org_mask.cuda(opts.gpu).detach()
            c_org_id = torch.cat(c_org_id, dim=0)
            c_org_id = c_org_id.cuda(opts.gpu).detach()

            # update model: most iterations train only the content
            # discriminator; every opts.d_iter-th iteration (and the last
            # couple of batches of the epoch) runs the full update
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images, c_org)
                continue
            else:
                model.update_D(images, c_org, c_org_mask, c_org_id)
                model.update_EFG()

            print('total_it: %d (ep %d, it %d), lr %08f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, max_it, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
Example #9
0
def main():
    """Train DRIT on an unpaired dataset (annotated study copy)."""

    # when True, the model and data are kept on the CPU
    debug_mode=False

    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True,
                                               num_workers=opts.nThreads)
    # (translation of the note below) Inspecting dataset_unpair shows:
    # images are first resized to 256x256, then a random 216x216 patch
    # is cropped (a center crop is used at test time).
    '''
        通过检查dataset_unpair,我们发现:
            图像是先缩放到256,256,然后再随机裁剪出216,216的patch,(测试时是从中心裁剪)
    '''

    # model: fresh initialization or checkpoint resume
    print('\n--- load model ---')
    model = DRIT(opts)
    if not debug_mode:
        model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        # resume returns (last finished epoch, total iterations so far)
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000  # hard cap on total iterations
    for ep in range(ep0, opts.n_ep):
        # (translation of the note below) images_a, images_b: 2,3,216,216
        '''
            images_a,images_b: 2,3,216,216
        '''
        for it, (images_a, images_b) in enumerate(train_loader):
            #   if we happen to draw a ragged leftover batch of one or two
            #   samples, skip it and resample
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue

            # input data
            if not debug_mode:
                images_a = images_a.cuda(opts.gpu).detach() #   detach here, probably to avoid computing unneeded gradients and save GPU memory
                images_b = images_b.cuda(opts.gpu).detach()

            # update model: with default settings, 1/3 of the iterations
            # update the content discriminator, 2/3 update D and EG
            if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
                model.update_D_content(images_a, images_b)
                continue
            else:
                model.update_D(images_a, images_b)
                model.update_EG()

            # save to display file
            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            sys.stdout.flush()
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                # NOTE(review): 2-arg call here vs. 3-arg call at epoch end
                # below -- confirm Saver.write_model's signature
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

    return
# PyTorch includes
import torch
from torch.autograd import Variable
# Custom includes
from options import TrainOptions
from dataset import createDataset
from models import createModel
from visualizer import ProgressVisualizer

# keep the otherwise-unused Variable import alive
# (presumably to silence unused-import linting -- confirm)
assert Variable

# parse training options
parser = TrainOptions()
opt = parser.parse()

# set dataloader: train and validation splits built from the same options
trainDataset = createDataset(opt, split='train', nInput=opt.nInput)
valDataset = createDataset(opt, split='val', nInput=opt.nInput)

trainDataLoader = torch.utils.data.DataLoader(trainDataset,
                                              batch_size=opt.batchSize,
                                              shuffle=True,
                                              num_workers=opt.nThreads)

valDataLoader = torch.utils.data.DataLoader(valDataset,
                                            batch_size=opt.batchSize,
                                            shuffle=True,
                                            num_workers=opt.nThreads)
# set model

model = createModel(opt)
model.setup(opt)
Example #11
0
def main():
    """Train a DCGAN: parse options, build the data pipeline, run the
    adversarial training loop, and save the generator weights.
    """
    parser = TrainOptions()
    opts = parser.parse()
    dataroot = opts.dataroot
    workers = opts.workers
    batch_size = opts.batch_size
    image_size = opts.image_size      # unused below; kept for option parity
    nc = opts.num_channels            # unused below; kept for option parity
    num_epochs = opts.num_epochs
    lr = opts.lr
    beta1 = opts.beta1
    ngpu = opts.gpu
    checkpoint_path = opts.checkpoint_path
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and ngpu > 0) else "cpu")
    nz = 128    # latent vector size
    ngf = 256   # generator feature maps (unused here)
    ndf = 256   # discriminator feature maps (unused here)

    def weights_init_normal(model):
        # Xavier-initialize every 2-D weight matrix in place.
        for param in model.parameters():
            if (len(param.size()) == 2):
                # xavier_normal (no trailing underscore) is a deprecated
                # alias removed in recent PyTorch; use the in-place variant
                torch.nn.init.xavier_normal_(param)

    dataset = datasets.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.RandomRotation(0.5),
                                       transforms.RandomAffine(0.5),
                                       transforms.ColorJitter(
                                           0, 0.1, 0.1, 0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    netG = Generator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    netG.apply(weights_init_normal)

    netD = Discriminator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))
    netD.apply(weights_init_normal)

    criterion = nn.BCELoss()
    fixed_noise = torch.randn(32, nz, 1, 1, device=device)
    # Use float labels: an int fill value makes torch.full produce an
    # integer tensor, which nn.BCELoss rejects.
    real_label = 1.
    fake_label = 0.
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Training Loop
    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # --- update D: maximize log(D(x)) + log(1 - D(G(z))) ---
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ), real_label, device=device)
            output = netD(real_cpu).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # all-fake batch
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            # --- update G: maximize log(D(G(z))) ---
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Output training stats
            print(
                '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                % (epoch, num_epochs, i, len(dataloader), errD.item(),
                   errG.item(), D_x, D_G_z1, D_G_z2))

    # persist only the generator weights
    torch.save(netG.state_dict(), os.path.join(checkpoint_path,
                                               'generator.pkl'))
def main():
    """Training entry point for the UID sequence model.

    Trains on paired blur/sharp .npz sequence datasets, periodically runs a
    held-out test pass, prints loss breakdowns, and saves example images and
    checkpoints through ``Saver``.

    NOTE(review): relies on module-level flags defined elsewhere in this file
    (z_delete_mode, sharp_whole_mode, throwing_mode, pred_mode,
    vel_secret_mode, encoder_tuning_mode, test_ratio, n_obj, data_version).
    """
    # for multi processing
    if torch.cuda.is_available():
        multiprocessing.set_start_method('spawn')
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # Pre-build all-zero latent codes used whenever the latent input is
    # deliberately disabled via z_delete_mode (separate sizes for the
    # train and test batch sizes).
    if (z_delete_mode):
        z_zero_train = torch.zeros(opts.batch_size, 32).cuda(opts.gpu)
        z_zero_test = torch.zeros(opts.batch_size_test, 32).cuda(opts.gpu)

    # data loader
    print('\n--- load dataset ---')
    sharp_whole_root = None
    if (sharp_whole_mode):
        sharp_whole_root = "../datasets_npz/datasets_sharp_whole_" + data_version
    dataset = dataset_pair_group_simple(
        opts,
        "../datasets_npz/datasets_blur_" + data_version,
        "../datasets_npz/datasets_sharp_" + data_version,
        "../datasets_npz/datasets_sharp_start_" + data_version,
        "../datasets_npz/datasets_gt_vel_" + data_version,
        "../datasets_npz/datasets_gt_pos_" + data_version,
        "../datasets_npz/datasets_shutter_speed_" + data_version,
        sharp_whole_root=sharp_whole_root)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # Same dataset roots, but test_mode=True selects the held-out split.
    dataset_test = dataset_pair_group_simple(
        opts,
        "../datasets_npz/datasets_blur_" + data_version,
        "../datasets_npz/datasets_sharp_" + data_version,
        "../datasets_npz/datasets_sharp_start_" + data_version,
        "../datasets_npz/datasets_gt_vel_" + data_version,
        "../datasets_npz/datasets_gt_pos_" + data_version,
        "../datasets_npz/datasets_shutter_speed_" + data_version,
        test_mode=True,
        sharp_whole_root=sharp_whole_root)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=opts.batch_size_test,
                                              shuffle=False,
                                              num_workers=opts.nThreads)
    test_iter = iter(test_loader)

    # model
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_b_s, gt_s, gt_pos_s, shutter_speed_s, images_s_whole_s,
                 images_a_s) in enumerate(train_loader):
            # Skip ragged final batches so every tensor has the full batch
            # size. NOTE(review): both sides of this `or` test the same
            # expression — the second clause looks like it was meant to check
            # a different tensor; confirm against the author's intent.
            if images_b_s.size(0) != opts.batch_size or images_b_s.size(
                    0) != opts.batch_size:
                continue
            images_b_s = images_b_s.cuda(opts.gpu).detach()
            for o in range(n_obj):
                gt_s[o] = gt_s[o].cuda(opts.gpu).detach()
                gt_pos_s[o] = gt_pos_s[o].cuda(opts.gpu).detach()
            shutter_speed_s = shutter_speed_s.cuda(opts.gpu).detach()
            if (sharp_whole_mode):
                images_s_whole_s = images_s_whole_s.cuda(opts.gpu).detach()
            else:
                images_a_s = images_a_s.cuda(opts.gpu).detach()

            # update model: slide a 3-frame window (s, s+1, s+2) over the
            # sequence, feeding previous/current/next blurred frames plus the
            # ground-truth velocity/position for the middle and next steps.
            s = 0
            z = None
            #last_gt = gt_s[:,opts.sequence_num - 1,:]
            last_gt = None
            loss_pred_last_sum = 0
            while (s < (opts.sequence_num - 2)):
                if (z_delete_mode):
                    z = z_zero_train
                # now: s+1, max: sequence_num-1
                left_steps = opts.sequence_num - s - 2  # how many more gt steps remain after the current (s+1) step
                images_prev = images_b_s[:, s, :, :, :]
                images_b = images_b_s[:, s + 1, :, :, :]
                images_post = images_b_s[:, s + 2, :, :, :]
                gt = []
                gt_pos = []
                for o in range(n_obj):
                    gt.append(gt_s[o][:, s + 1, :])
                    gt_pos.append(gt_pos_s[o][:, s + 1, :])
                next_gt = []
                next_gt_pos = []
                for o in range(n_obj):
                    next_gt.append(gt_s[o][:, s + 2, :])
                    next_gt_pos.append(gt_pos_s[o][:, s + 2, :])
                shutter_speed = shutter_speed_s
                if (sharp_whole_mode):
                    # 16 sharp sub-frames per blurred frame; slice the ones
                    # belonging to frame s+1.
                    images_s_whole = images_s_whole_s[:, (s + 1) * 16:(s + 2) *
                                                      16, :, :, :]
                    images_a = None
                else:
                    images_s_whole = None
                    images_a = images_a_s[:, s + 1, :, :, :]

                # In throwing mode, after the first step feed the model its
                # own previous velocity/position predictions instead of gt.
                if throwing_mode and s != 0:
                    given_vel = model.next_vel_pred
                    given_pos = model.next_pos_pred
                else:
                    given_vel = None
                    given_pos = None

                model.update(images_prev,
                             images_post,
                             images_b,
                             gt,
                             gt_pos,
                             next_gt=next_gt,
                             next_gt_pos=next_gt_pos,
                             z=z,
                             last_gt=last_gt,
                             left_steps=left_steps,
                             gt_pos_set=gt_pos_s,
                             shutter_speed=shutter_speed,
                             images_s_whole=images_s_whole,
                             images_a=images_a,
                             given_vel=given_vel,
                             given_pos=given_pos)
                s += 1
                # Carry the predicted latent forward to the next window when
                # the model runs in prediction mode.
                if (model.pred_mode):
                    z = model.z_next
                    loss_pred_last_sum += model.loss_pred_last_vel
                else:
                    z = None
                    loss_pred_last_sum += -1

            # save to display file
            # NOTE(review): `% 1 == 0` is always true, so this logs every
            # iteration; presumably a leftover configurable print interval.
            if (it + 1) % 1 == 0:
                print('total_it: %d (ep %d, it %d), lr %08f' %
                      (total_it + 1, ep, it + 1,
                       model.enc_c_opt.param_groups[0]['lr']))
                # NOTE(review): this checks the module-level `pred_mode`
                # flag, while the loop above checks `model.pred_mode` —
                # confirm these are meant to be the same switch.
                if pred_mode:
                    loss_pred_vel = model.loss_pred_vel[0]
                else:
                    loss_pred_vel = -1
                print(
                    'gen_loss: %04f, vel_loss %04f, vel_recons_loss %04f, gen_loss_gt_vel %04f, inverse_loss %04f, vel_dir_loss %04f, loss_content %04f, vel_pred_loss %04f, pos_loss %04f, last_vel_pred_loss %04f, whole_pos_pred %04f, loss_sharp %04f'
                    %
                    (model.loss_gen, model.loss_vel[0], model.loss_vel_recons,
                     model.loss_gen_gt_vel, model.loss_inverse,
                     model.loss_vel_dir[0], model.loss_content, loss_pred_vel,
                     model.loss_pos[0], loss_pred_last_sum / s,
                     model.loss_sharp_whole, model.loss_sharp))
                #print(model.next_gt)
                #print(model.next_gt_pred)

            # Periodic evaluation: every `test_ratio` training iterations,
            # pull one test batch and run the same sliding-window pass with
            # model.test(), accumulating averaged diagnostic losses.
            if (it + 1) % test_ratio == 0:
                n = 0
                total_dir_loss = 0
                total_speed_loss = 0
                total_vel_loss = 0
                total_pos_loss = 0
                total_dir_entropy = 0
                for rep in range(1):
                    # Re-create the iterator when the test loader is
                    # exhausted; also skip ragged final test batches.
                    try:
                        images_b_s, gt_s, gt_pos_s, shutter_speed_s, images_s_whole_s, images_a_s = next(
                            test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        images_b_s, gt_s, gt_pos_s, shutter_speed_s, _, _ = next(
                            test_iter)
                    while (images_b_s.size(0) != opts.batch_size_test):
                        try:
                            images_b_s, gt_s, gt_pos_s, shutter_speed_s, _, _ = next(
                                test_iter)
                        except StopIteration:
                            test_iter = iter(test_loader)
                            images_b_s, gt_s, gt_pos_s, shutter_speed_s, _, _ = next(
                                test_iter)

                    images_b_s = images_b_s.cuda(opts.gpu).detach()
                    for o in range(n_obj):
                        gt_s[o] = gt_s[o].cuda(opts.gpu).detach()
                        gt_pos_s[o] = gt_pos_s[o].cuda(opts.gpu).detach()
                    shutter_speed_s = shutter_speed_s.cuda(opts.gpu).detach()

                    # test model
                    s = 0
                    loss_pred_last_sum = 0
                    #last_gt = gt_s[:,opts.sequence_num - 1,:]
                    last_gt = None
                    z = None
                    while (s < (opts.sequence_num - 2)):
                        if (z_delete_mode):
                            z = z_zero_test
                        left_steps = opts.sequence_num - s - 2  # how many more gt steps remain after the current (s+1) step
                        images_prev = images_b_s[:, s, :, :, :]
                        images_b = images_b_s[:, s + 1, :, :, :]
                        images_post = images_b_s[:, s + 2, :, :, :]

                        gt = []
                        gt_pos = []
                        for o in range(n_obj):
                            gt.append(gt_s[o][:, s + 1, :])
                            gt_pos.append(gt_pos_s[o][:, s + 1, :])
                        next_gt = []
                        next_gt_pos = []
                        for o in range(n_obj):
                            next_gt.append(gt_s[o][:, s + 2, :])
                            next_gt_pos.append(gt_pos_s[o][:, s + 2, :])
                        loss_pred_last_sum += model.loss_pred_last_vel
                        shutter_speed = shutter_speed_s
                        if throwing_mode and s != 0:
                            given_vel = model.next_vel_pred
                            given_pos = model.next_pos_pred
                        else:
                            given_vel = None
                            given_pos = None
                        model.test(images_prev,
                                   images_post,
                                   images_b,
                                   gt,
                                   gt_pos,
                                   next_gt,
                                   next_gt_pos,
                                   z,
                                   last_gt=last_gt,
                                   left_steps=left_steps,
                                   gt_pos_set=gt_pos_s,
                                   shutter_speed=shutter_speed,
                                   given_vel=given_vel,
                                   given_pos=given_pos)
                        s += 1
                        total_dir_loss += model.loss_vel_dir[0]
                        total_speed_loss += model.loss_vel_speed[0]
                        total_vel_loss += model.loss_vel[0]
                        total_pos_loss += model.loss_pos[0]
                        total_dir_entropy += model.loss_dir_entropy[0]
                        n += 1

                        if (model.pred_mode):
                            z = model.z_next
                        else:
                            z = None
                #print("dir")
                #print(model.dir)
                # Average the accumulated diagnostics over the n windows.
                total_dir_loss /= n
                total_speed_loss /= n
                total_vel_loss /= n
                total_pos_loss /= n
                total_dir_entropy /= n
                print('=============================')
                print("n: " + str(n))
                if (vel_secret_mode):
                    # NOTE(review): in this branch `loss_pred_vel` is not
                    # recomputed — the value printed is whatever the earlier
                    # training-print block last assigned; confirm intended.
                    print(
                        'gen_loss: %04f, vel_loss %04f, vel_recons_loss %04f, gen_loss_gt_vel %04f, inverse_loss %04f, vel_dir_loss %04f, loss_content %04f, vel_pred_loss %04f, pos_loss %04f, dir_entropy_loss %04f, speed_loss %04f'
                        %
                        (model.loss_gen, total_vel_loss, model.loss_vel_recons,
                         model.loss_gen_gt_vel, model.loss_inverse,
                         total_dir_loss, model.loss_content, loss_pred_vel,
                         total_pos_loss, total_dir_entropy, total_speed_loss))
                else:
                    if pred_mode:
                        loss_pred_vel = model.loss_pred_vel[0]
                    else:
                        loss_pred_vel = -1
                    print(
                        'gen_loss: %04f, vel_loss %04f, vel_recons_loss %04f, gen_loss_gt_vel %04f, inverse_loss %04f, vel_dir_loss %04f, loss_content %04f, vel_pred_loss %04f, pos_loss %04f, dir_entropy_loss %04f, speed_loss %04f'
                        %
                        (model.loss_gen, total_vel_loss, model.loss_vel_recons,
                         model.loss_gen_gt_vel, model.loss_inverse,
                         total_dir_loss, model.loss_content, loss_pred_vel,
                         total_pos_loss, total_dir_entropy, total_speed_loss))
                print('=============================')
                # Visualization: which save_pos variant runs (and how often)
                # depends on the model's gen/pred mode combination.
                if (model.gen_mode):
                    if ep % 10 == 0 and (it + 1) % 60 == 0:
                        saver.write_img(ep * len(train_loader) + (it + 1),
                                        model)
                if vel_secret_mode or encoder_tuning_mode:
                    if (ep + 1) % 5 == 0:
                        save_pos(model.input_gt_vel[0],
                                 model.vel_pred[0],
                                 model.vel_pred[0],
                                 model.next_vel_pred[0],
                                 model.next_vel_encoded[0],
                                 opts.visualize_root,
                                 "example%06d.png" % (ep * len(train_loader) +
                                                      (it + 1)),
                                 gt_next_=model.next_vel_encoded[0],
                                 pred_next_=model.next_vel_pred[0],
                                 gt_pos_=model.next_pos_encoded[0],
                                 pred_pos_=model.next_pos_pred[0])
                elif (model.gen_mode and model.pred_mode):
                    if ep % 10 == 0 and (it + 1) % 60 == 0:
                        save_pos(model.input_gt_vel[0],
                                 model.vel_pred[0],
                                 model.vel_pred[0],
                                 model.input_gt_vel[0],
                                 model.vel_pred[0],
                                 opts.visualize_root,
                                 "example%06d.png" % (ep * len(train_loader) +
                                                      (it + 1)),
                                 gt_next_=model.next_gt_vel[0],
                                 pred_next_=model.next_vel_pred[0],
                                 gt_pos_=model.input_gt_pos[0],
                                 pred_pos_=model.pos_pred[0])
                elif (model.pred_mode):
                    if (it + 1) % 300 == 0:
                        save_pos(model.input_gt_vel[0],
                                 model.vel_pred,
                                 model.vel_pred,
                                 model.input_gt_vel[0],
                                 model.vel_pred,
                                 opts.visualize_root,
                                 "example%06d.png" % (ep * len(train_loader) +
                                                      (it + 1)),
                                 gt_next_=model.next_gt_vel,
                                 pred_next_=model.next_vel_pred,
                                 gt_pos_=model.input_gt_pos,
                                 pred_pos_=model.pos_pred)
                elif (model.gen_mode):
                    if ep % 10 == 0 and (it + 1) % 60 == 0:
                        save_pos(model.input_gt_vel[0],
                                 model.vel_pred[0],
                                 model.vel_pred[1],
                                 model.input_gt_vel[0],
                                 model.vel_pred[0],
                                 opts.visualize_root,
                                 "example%06d.png" % (ep * len(train_loader) +
                                                      (it + 1)),
                                 gt_next_=model.input_gt_vel[0],
                                 pred_next_=model.vel_pred[0],
                                 gt_pos_=model.input_gt_pos[0],
                                 pred_pos_=model.pos_pred[0],
                                 gt_pos_2=model.input_gt_pos[0])
                else:
                    if ep % 10 == 0 and (it + 1) % 20 == 0:
                        save_pos(model.input_gt_vel[0],
                                 model.vel_pred,
                                 model.vel_pred,
                                 model.input_gt_vel[0],
                                 model.vel_pred,
                                 opts.visualize_root,
                                 "example%06d.png" % (ep * len(train_loader) +
                                                      (it + 1)),
                                 gt_next_=model.input_gt_vel[0],
                                 pred_next_=model.vel_pred,
                                 gt_pos_=model.input_gt_pos[0],
                                 pred_pos_=model.pos_pred,
                                 gt_pos_2=model.input_gt_pos[0])
                #save_pos_for_next_gt(opts.visualize_root,"example%06d.png"%(ep*len(train_loader) + (it+1)), gt_next_ = model.next_gt, pred_next_ = model.next_gt_pred)

            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                # NOTE(review): called with 2 args here vs. 3 args
                # (ep, total_it, model) below — confirm Saver.write_model's
                # signature accepts both.
                saver.write_model(-1, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # Save network weights
        #if vel_secret_mode or encoder_tuning_mode:
        #saver.write_model(ep, total_it+1, model)
        if (ep + 1) % 5 == 0:
            saver.write_model(ep, total_it + 1, model)

    return
Exemple #13
0
def main():
    """Train an ICAM model on healthy/anomaly image pairs.

    Builds paired dataloaders, trains with alternating content-discriminator
    / discriminator / encoder-generator updates, logs classification scores,
    runs validation each epoch and a final test pass, and saves checkpoints
    and example images through ``Saver``.

    Uses module-level constants (RANDOM_SEED, IMAGE_SIZE, AGE_RANGE_0/1,
    RESIZE_IMAGE, RESIZE_SIZE, LATENT_2D/3D) and helpers (_load_dataloader,
    _validation, _test) defined elsewhere in this file.
    """
    global val_accuracy, val_f1, val_precision, val_recall, val_cross_corr_a, val_cross_corr_b, val_mse, val_mae, \
        saver, ep, save_opts, total_it, iter_counter, t0

    # initialise params
    parser = TrainOptions()
    opts = parser.parse()
    opts.random_seed = RANDOM_SEED
    opts.device = opts.device if torch.cuda.is_available(
    ) and opts.gpu else 'cpu'
    opts.name = opts.data_type + '_' + time.strftime("%d%m%Y-%H%M")
    opts.results_path = os.path.join(opts.result_dir, opts.name)
    opts.image_size = IMAGE_SIZE
    opts.age_range_0 = AGE_RANGE_0
    opts.age_range_1 = AGE_RANGE_1
    opts.resize_image = RESIZE_IMAGE
    opts.resize_size = RESIZE_SIZE
    ep0 = 0
    total_it = 0
    # Per-epoch validation metric buffers, filled by _validation via globals.
    val_accuracy = np.zeros(opts.n_ep)
    val_f1 = np.zeros(opts.n_ep)
    val_precision = np.zeros(opts.n_ep)
    val_recall = np.zeros(opts.n_ep)
    val_cross_corr_a = np.zeros(opts.n_ep)
    val_cross_corr_b = np.zeros(opts.n_ep)
    val_mse = np.zeros(opts.n_ep)
    val_mae = np.zeros(opts.n_ep)
    t0 = time.time()

    # saver for display and output: 2d and 3d variants have different
    # Saver implementations and latent sizes.
    if opts.data_dim == '3d':
        from saver_3d import Saver
        opts.nz = LATENT_3D
    else:
        from saver import Saver
        opts.nz = LATENT_2D

    print('\n--- load dataset ---')
    # add new dataloader in _load_dataloader(), and in dataloader_utils.py
    healthy_dataloader, healthy_val_dataloader, healthy_test_dataloader, \
    anomaly_dataloader, anomaly_val_dataloader, anomaly_test_dataloader = _load_dataloader(opts)

    print('\n--- load model ---')
    model = ICAM(opts)
    model.setgpu(opts.device)
    model.initialize()
    model.set_scheduler(opts, last_ep=ep0)
    save_opts = vars(opts)
    saver = Saver(opts)

    if not os.path.exists(opts.results_path):
        os.makedirs(opts.results_path)

    # Persist the run configuration next to the results.
    with open(opts.results_path + '/parameters.json', 'w') as file:
        json.dump(save_opts, file, indent=4, sort_keys=True)

    print('\n--- train ---')
    for ep in range(ep0, opts.n_ep):
        healthy_data_iter = iter(healthy_dataloader)
        anomaly_data_iter = iter(anomaly_dataloader)
        iter_counter = 0

        # Step both loaders in lockstep; stop when the shorter one runs out.
        while iter_counter < len(anomaly_dataloader) and iter_counter < len(
                healthy_dataloader):
            # output of iter dataloader: [tensor image, tensor label (regression), tensor mask]
            # FIX: Python 3 removed the iterator .next() method — use the
            # builtin next() instead (the original called .next()).
            healthy_images, healthy_label_reg, healthy_mask = next(
                healthy_data_iter)
            anomaly_images, anomaly_label_reg, anomaly_mask = next(
                anomaly_data_iter)
            # One-hot domain labels: column 0 = healthy, column 1 = anomaly.
            healthy_c_org = torch.zeros(
                (healthy_images.size(0), opts.num_domains)).to(opts.device)
            healthy_c_org[:, 0] = 1
            # FIX: size the anomaly labels from the anomaly batch (the
            # original used healthy_images.size(0), which is wrong whenever
            # the two loaders yield different final-batch sizes).
            anomaly_c_org = torch.zeros(
                (anomaly_images.size(0), opts.num_domains)).to(opts.device)
            anomaly_c_org[:, 1] = 1
            images = torch.cat((healthy_images, anomaly_images),
                               dim=0).type(torch.FloatTensor)
            c_org = torch.cat((healthy_c_org, anomaly_c_org),
                              dim=0).type(torch.FloatTensor)
            label_reg = torch.cat((healthy_label_reg, anomaly_label_reg),
                                  dim=0).type(torch.FloatTensor)

            # Masks are optional: a >2-dim mask tensor means real masks were
            # provided; otherwise pass None.
            if len(healthy_mask.size()) > 2:
                mask = torch.cat((healthy_mask, anomaly_mask),
                                 dim=0).type(torch.FloatTensor)
                mask = mask.to(opts.device).detach()
            else:
                mask = None

            iter_counter += 1
            # Skip ragged (smaller) combined batches.
            if images.size(0) != opts.batch_size:
                continue

            # input data
            images = images.to(opts.device).detach()
            c_org = c_org.to(opts.device).detach()
            label_reg = label_reg.to(opts.device).detach()

            # update model: most iterations only train the content
            # discriminator; every opts.d_iter-th iteration (and near the
            # end of the epoch) also runs the D and EG updates below.
            if (iter_counter % opts.d_iter) != 0 and iter_counter < len(
                    anomaly_dataloader) - opts.d_iter:
                model.update_D_content(opts, images, c_org)
                continue

            model.update_D(opts, images, c_org, label_reg, mask=mask)
            model.update_EG(opts)

            # Compute train metrics one step before they are printed below
            # (total_it is incremented in between, so the moduli line up).
            if ((total_it + 1) % opts.train_print_it) == 0:
                train_accuracy, train_f1, _, _ = model.classification_scores(
                    images, c_org)
                if opts.regression:
                    train_mse, train_mae, _ = model.regression(
                        images, label_reg)
            if total_it == 0:
                saver.write_img(ep, total_it, model)
            elif total_it % opts.display_freq == 0:
                saver.write_img(ep, total_it, model)
            total_it += 1

            # save to tensorboard
            saver.write_display(total_it, model)

            time_elapsed = time.time() - t0
            hours, rem = divmod(time_elapsed, 3600)
            minutes, seconds = divmod(rem, 60)

            if (total_it % opts.train_print_it) == 0:
                print(
                    'Total it: {:d} (ep {:d}, it {:d}), Accuracy: {:.2f}, F1 score: {:.2f}, '
                    'Elapsed time: {:0>2}:{:0>2}:{:05.2f}'.format(
                        total_it, ep, iter_counter, train_accuracy, train_f1,
                        int(hours), int(minutes), seconds))

        # save model
        if ep % opts.model_save_freq == 0:
            saver.write_model(ep, total_it, 0, model, epoch=True)
            saver.write_img(ep, total_it, model)

        # example validation
        try:
            _validation(opts, model, healthy_val_dataloader,
                        anomaly_val_dataloader)
        except Exception as e:
            print(f'Encountered error during validation - {e}')
            raise e

    # example test
    try:
        _test(opts, model, healthy_test_dataloader, anomaly_test_dataloader)
    except Exception as e:
        # FIX: this is the test phase — the original message said
        # "validation", which misattributed the failure.
        print(f'Encountered error during testing - {e}')
        raise e

    # save last model
    saver.write_model(ep,
                      total_it,
                      iter_counter,
                      model,
                      model_name='model_last')
    saver.write_img(ep, total_it, model)

    return
def main(n=3, input_n=10):
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # daita loader
    print('\n--- load dataset ---')
    if (origin_version):
        origin_root = "../datasets_npz/datasets_original_" + data_version
    else:
        origin_root = None

    dataset_test = dataset_pair_group(
        opts,
        "../datasets_npz/datasets_blur_" + data_version,
        "../datasets_npz/datasets_sharp_" + data_version,
        "../datasets_npz/datasets_sharp_start_" + data_version,
        "../datasets_npz/datasets_gt_vel_" + data_version,
        "../datasets_npz/datasets_gt_pos_" + data_version,
        "../datasets_npz/datasets_shutter_speed_" + data_version,
        test_split=test_split,
        origin_root=origin_root)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=opts.batch_size,
                                              shuffle=False,
                                              num_workers=opts.nThreads)
    test_iter = iter(test_loader)

    if z_delete_mode:
        z_zero = torch.zeros(opts.batch_size, 32).cuda(opts.gpu)

    # model
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the test at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # test
    print('\n--- test ---')

    # for validation
    total_loss_vel = 0
    total_loss_gen_gt_vel = 0
    total_loss_content = 0
    total_loss_pred_vel = 0
    total_loss_pos = 0
    total_loss_ssim = 0
    total_loss_mse = 0
    loss_pos_per_sequence = np.zeros([opts.sequence_num - 2])
    loss_vel_per_sequence = np.zeros([opts.sequence_num - 2])
    loss_ssim_per_sequence = np.zeros([opts.sequence_num - 2])
    loss_converted_pos_per_sequence = np.zeros([n_obj, opts.sequence_num - 1])
    loss_mse_per_sequence = np.zeros([opts.sequence_num - 2])
    loss_mse_cov = np.array([])
    for i in range(n):
        # for validation
        loss_ssim = 0
        loss_vel = 0
        loss_gen_gt_vel = 0
        loss_content = 0
        loss_pred_vel = 0
        loss_pos = 0
        loss_mse = 0

        if not origin_version:
            try:
                images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s = next(
                    test_iter)
            except StopIteration:
                test_iter = iter(test_loader)
                images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s = next(
                    test_iter)
            while (images_a_s.size(0) != opts.batch_size):
                try:
                    images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s = next(
                        test_iter)
                except StopIteration:
                    test_iter = iter(test_loader)
                    images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s = next(
                        test_iter)
        else:
            try:
                images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s, images_origin_s = next(
                    test_iter)
            except StopIteration:
                test_iter = iter(test_loader)
                images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s, images_origin_s = next(
                    test_iter)
            while (images_a_s.size(0) != opts.batch_size):
                try:
                    images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s, images_origin_s = next(
                        test_iter)
                except StopIteration:
                    test_iter = iter(test_loader)
                    images_a_s, images_b_s, gt_s, gt_pos_s, images_a_start_s, shutter_speed_s, images_origin_s = next(
                        test_iter)

        images_a_s = images_a_s.cuda(opts.gpu).detach()
        images_b_s = images_b_s.cuda(opts.gpu).detach()
        images_a_start_s = images_a_start_s.cuda(opts.gpu).detach()
        for o in range(n_obj):
            gt_s[o] = gt_s[o].cuda(opts.gpu).detach()
            gt_pos_s[o] = gt_pos_s[o].cuda(opts.gpu).detach()
        shutter_speed_s = shutter_speed_s.cuda(opts.gpu).detach()

        # for position prediction
        gt_pos_pred = np.zeros([n_obj, 2, opts.sequence_num - 1, 2])
        gt_vel_pred = np.zeros([n_obj, 2, opts.sequence_num - 1, 32])
        if (origin_version):
            images_origin_s = images_origin_s.cuda(opts.gpu).detach()
        # test model
        s = 0
        size = images_a_s.size()
        ret = torch.zeros(size)
        ret_gt = torch.zeros(size)
        ret_origin = torch.zeros(size)
        z = None
        ret[:, 0] = images_b_s[:, 0]  # 첫번째 blur
        #ret[:,1] = images_b_s[:,1] # 두번째 blur
        ret_gt[:, 0] = images_b_s[:, 0]  # 첫번째 blur
        ret_gt[:, 1] = images_b_s[:, 1]  # 두번째 blur
        if origin_version:
            ret_origin[:, 0] = images_origin_s[:, 0]  # 첫번째 blur
            ret_origin[:, 1] = images_origin_s[:, 1]  # 두번째 blur
        # for velocity estimation
        vel_pred = 0
        mse_cov = np.zeros([2])

        while (s < (opts.sequence_num - 2)):

            if (s < input_n):
                images_sharp_prev = images_a_start_s[:, s, :, :, :]
                images_prev = images_b_s[:, s, :, :, :]
            else:
                images_sharp_prev = images_a_start_s[:, s, :, :, :]
                images_prev = model.input_B

            if ((s + 1) < input_n):
                images_a = images_a_start_s[:, s + 1, :, :, :]
                images_b = images_b_s[:, s + 1, :, :, :]
                images_a_end = images_a_s[:, s + 1, :, :, :]
                given_vel = None
                given_pos = None

            else:
                images_a = images_a_start_s[:, s + 1, :, :, :]
                images_b = model.post_B_pred_recons
                images_a_end = images_a_s[:, s + 1, :, :, :]
                given_vel = model.next_vel_pred
                # Euler
                given_pos = model.next_pos_pred

            images_post = images_b_s[:, s + 2, :, :, :]

            data_random = images_a_start_s[:,
                                           random.
                                           randint(0, opts.sequence_num -
                                                   1), :, :, :]
            gt = []
            gt_pos = []
            for o in range(n_obj):
                gt.append(gt_s[o][:, s + 1, :])
                gt_pos.append(gt_pos_s[o][:, s + 1, :])
            next_gt = []
            next_gt_pos = []
            for o in range(n_obj):
                next_gt.append(gt_s[o][:, s + 2, :])
                next_gt_pos.append(gt_pos_s[o][:, s + 2, :])
            shutter_speed = shutter_speed_s

            if z_delete_mode:
                z = z_zero

            #model.test(images_sharp_prev, images_prev, images_post, images_a, images_b, gt, gt_pos, images_a_end, data_random, next_gt, next_gt_pos, z, given_vel = given_vel, given_pos = given_pos, shutter_speed=shutter_speed)
            model.test(images_prev,
                       images_post,
                       images_b,
                       gt,
                       gt_pos,
                       next_gt,
                       next_gt_pos,
                       z,
                       given_vel=given_vel,
                       given_pos=given_pos,
                       shutter_speed=shutter_speed)
            s += 1
            #z = None

            if ((s + 1) < input_n):
                ret[:, s + 1] = images_b_s[:, s + 1, :, :, :]
                ret_gt[:, s + 1] = images_b_s[:, s + 1, :, :, :]
                if (origin_version):
                    ret_origin[:, s + 1] = images_origin_s[:, s + 1, :, :, :]
                for o in range(n_obj):
                    gt_pos_pred[o, :, s - 1] = model.pos_pred[o].cpu().detach()
                    gt_pos_pred[o, :,
                                s] = model.next_pos_pred[o].cpu().detach()
                    gt_vel_pred[o, :, s -
                                1] = model.vel_pred[o].cpu().detach()  #[:,30:]
                    gt_vel_pred[
                        o, :,
                        s] = model.next_vel_pred[o].cpu().detach()  #[:,30:]
            else:
                if (s + 1) == input_n:
                    ret[:, s] = model.recons_B
                    #ret[:,s] = model.fake_B_encoded_with_gt[2]
                ret[:, s + 1] = model.post_B_pred_recons
                ret_gt[:, s + 1] = images_b_s[:, s + 1, :, :, :]
                if (origin_version):
                    ret_origin[:, s + 1] = images_origin_s[:, s + 1, :, :, :]
                for o in range(n_obj):
                    gt_pos_pred[o, :, s - 1] = model.pos_pred[o].cpu().detach()
                    gt_pos_pred[o, :,
                                s] = model.next_pos_pred[o].cpu().detach()
                    gt_vel_pred[o, :, s -
                                1] = model.vel_pred[o].cpu().detach()  #[:,30:]
                    gt_vel_pred[
                        o, :,
                        s] = model.next_vel_pred[o].cpu().detach()  #[:,30:]

            #for validation
            #loss_gen += model.loss_gen
            #loss_vel_recons += model.loss_vel_recons
            loss_gen_gt_vel += model.loss_gen_gt_vel
            #loss_inverse += model.loss_inverse
            #loss_vel_dir += model.loss_vel_dir
            #loss_content += model.loss_content
            loss_content = 0
            #loss_pred_vel = 0
            #loss_ssim += model.ssim_err
            for o in range(n_obj):
                loss_vel += model.loss_vel[o]
                loss_pred_vel += model.loss_pred_vel[o]
                loss_pos += model.loss_pos[o]

            #SSIM loss check version2 which is same method with in E3D
            #=============start from here======================
            pred_0 = model.post_B_pred_recons[0].cpu().detach().numpy(
            ).transpose([1, 2, 0]) * 0.5 + 0.5
            gt_0 = images_post[0].cpu().numpy().transpose([1, 2, 0
                                                           ]) * 0.5 + 0.5
            pred_1 = model.post_B_pred_recons[1].cpu().detach().numpy(
            ).transpose([1, 2, 0]) * 0.5 + 0.5
            gt_1 = images_post[1].cpu().numpy().transpose([1, 2, 0
                                                           ]) * 0.5 + 0.5

            pred_0_grey = pred_0[:, :, 0:
                                 1] * 0.2126 + pred_0[:, :, 1:
                                                      2] * 0.7152 + pred_0[:, :,
                                                                           2:] * 0.0722
            gt_0_grey = gt_0[:, :, 0:
                             1] * 0.2126 + gt_0[:, :, 1:
                                                2] * 0.7152 + gt_0[:, :,
                                                                   2:] * 0.0722
            pred_1_grey = pred_1[:, :, 0:
                                 1] * 0.2126 + pred_1[:, :, 1:
                                                      2] * 0.7152 + pred_1[:, :,
                                                                           2:] * 0.0722
            gt_1_grey = gt_1[:, :, 0:
                             1] * 0.2126 + gt_1[:, :, 1:
                                                2] * 0.7152 + gt_1[:, :,
                                                                   2:] * 0.0722

            loss_ssim += compare_ssim(
                pred_0_grey, gt_0_grey, multichannel=True, win_size=7) / 2
            loss_ssim += compare_ssim(
                pred_1_grey, gt_1_grey, multichannel=True, win_size=7) / 2

            loss_mse += np.sum((pred_0_grey - gt_0_grey)**2) / 2
            loss_mse += np.sum((pred_1_grey - gt_1_grey)**2) / 2
            mse_cov[0] += [np.sum((pred_0_grey - gt_0_grey)**2)]
            mse_cov[1] += [np.sum((pred_1_grey - gt_1_grey)**2)]
            #=============to here==============================

            loss_vel_per_sequence[s - 1] += loss_vel / n
            loss_pos_per_sequence[s - 1] += loss_pos / n
            loss_mse_per_sequence[s - 1] += (np.sum(
                (pred_0_grey - gt_0_grey)**2) / 2 + np.sum(
                    (pred_1_grey - gt_1_grey)**2) / 2) / n
            #loss_ssim_per_sequence[s-1] += model.ssim_err/n
            loss_ssim_per_sequence[s - 1] += (
                compare_ssim(pred_0, gt_0, multichannel=True, win_size=7) / 2 +
                compare_ssim(pred_1, gt_1, multichannel=True, win_size=7) /
                2) / n
            '''
            print("loss_vel_"+str(s)+": "+ str(model.loss_vel))
            print("loss_pos_"+str(s)+": "+ str(model.loss_pos))
            '''
        if data_version in convert_datasets:
            visited = [[], []]  # batch size
            for o in range(n_obj):
                min = [1000, 1000]
                argmin = [0, 0]  # batch size
                gt_pos_convert = gt_pos_s[o][:, 1:, 30:].cpu().numpy()
                gt_vel_convert = gt_s[o][:, 1:, :].cpu().numpy()
                #print(gt_pos_convert.shape)
                for b in range(2):  # batch size
                    for oo in range(n_obj):
                        if oo in visited[b]: continue
                        if np.mean(
                            (gt_pos_convert[b][0] - gt_pos_pred[oo][b][0])**
                                2) < min[b]:
                            min[b] = np.mean((gt_pos_convert[b][0] -
                                              gt_pos_pred[oo][b][0])**2)
                            argmin[b] = oo
                    visited[b].append(argmin[b])
                comp_pos_convert = np.zeros_like(gt_pos_pred[0])
                comp_vel_convert = np.zeros_like(gt_vel_pred[0])
                #print(comp_pos_convert.shape)
                for b in range(2):
                    comp_pos_convert[b] = gt_pos_pred[argmin[b], b]
                    comp_vel_convert[b] = gt_vel_pred[argmin[b], b]
                #print(gt_pos_convert)
                gt_pos_convert = conver_to_straight_moving(
                    gt_pos_convert, gt_vel_convert)
                comp_pos_convert = conver_to_straight_moving(
                    comp_pos_convert, comp_vel_convert)
                #print(gt_pos_convert)

                #print(gt_pos_convert[1])
                #print(comp_pos_convert[1])
                converted_pos_loss = np.mean(
                    (gt_pos_convert - comp_pos_convert)**2 * 100, axis=(0, -1))
                #print(converted_pos_loss)
                print("converted_pos_loss: " +
                      str(np.mean(converted_pos_loss)))
                #loss_converted_pos_per_sequence += converted_pos_loss/n
                loss_converted_pos_per_sequence[o] += converted_pos_loss / n
        if throwing_mode:
            trajectories_gt = []
            trajectories_pred = []
            for o in range(n_obj):
                trajectories_gt.append(gt_pos_s[o][:, 1:, 30:].cpu().numpy())
                trajectories_pred.append(gt_pos_pred[o])
            img_dir = '%s/prediction' % (opts.visualize_root)
            img_filename = '/trajectory_%05d_%02d.jpg' % (ep0, i)
            show_trajectory(img_dir, img_filename, trajectories_gt,
                            trajectories_pred)
        #conver_to_straight_moving
        #for validation
        #loss_gen /= (opts.sequence_num - 2)
        loss_vel /= ((opts.sequence_num - 2) * n_obj)
        #loss_vel_recons /= (opts.sequence_num - 2)
        loss_gen_gt_vel /= (opts.sequence_num - 2)
        #loss_inverse /= (opts.sequence_num - 2)
        #loss_vel_dir /= (opts.sequence_num - 2)
        #loss_content /= (opts.sequence_num - 2)
        loss_pred_vel /= ((opts.sequence_num - 2) * n_obj)
        loss_pos /= ((opts.sequence_num - 2) * n_obj)
        loss_ssim /= (opts.sequence_num - 2)
        #loss_mse /= (opts.sequence_num -2)
        print('=============================')
        print(
            'vel_loss %04f, gen_loss_gt_vel %04f, loss_content %04f, vel_pred_loss %04f, pos_loss %04f, ssim_loss %04f, loss_mse %04f'
            % (loss_vel, loss_gen_gt_vel, loss_content, loss_pred_vel,
               loss_pos, loss_ssim, loss_mse))
        print('=============================')
        saver.write_pred_img(ep0,
                             i,
                             input_n,
                             ret,
                             ret_gt,
                             ret_origin=ret_origin,
                             origin_version=origin_version)

        #total_loss_gen += loss_gen
        total_loss_vel += loss_vel
        #total_loss_vel_recons += loss_vel_recons
        total_loss_gen_gt_vel += loss_gen_gt_vel
        #total_loss_inverse += loss_inverse
        #total_loss_vel_dir += loss_vel_dir
        total_loss_content += loss_content
        total_loss_pred_vel += loss_pred_vel
        total_loss_pos += loss_pos
        total_loss_ssim += loss_ssim.item()
        total_loss_mse += loss_mse
        loss_mse_cov = np.append(loss_mse_cov, mse_cov[0])
        loss_mse_cov = np.append(loss_mse_cov, mse_cov[1])

    #total_loss_gen /= n
    total_loss_vel /= n
    #total_loss_vel_recons /= n
    total_loss_gen_gt_vel /= n
    #total_loss_inverse /= n
    #total_loss_vel_dir /= n
    total_loss_content /= n
    total_loss_pred_vel /= n
    total_loss_pos /= n
    total_loss_ssim /= n
    total_loss_mse /= n

    print('==========================================================')
    print('===========================TOTAL==========================')
    print(
        'vel_loss %04f, gen_loss_gt_vel %04f, loss_content %04f, vel_pred_loss %04f, pos_loss %04f, loss_ssim %04f, loss_mse %04f'
        %
        (total_loss_vel, total_loss_gen_gt_vel, total_loss_content,
         total_loss_pred_vel, total_loss_pos, total_loss_ssim, total_loss_mse))
    print("converted_pos_loss %04f" %
          (np.mean(loss_converted_pos_per_sequence)))
    print('==========================================================')
    print('==========================================================')
    print("loss_pos %04f" % total_loss_pos)
    for i in range(opts.sequence_num - 2):
        print(str(loss_pos_per_sequence[i]))
    print("loss_vel")
    for i in range(opts.sequence_num - 2):
        print(str(loss_vel_per_sequence[i]))
    print("loss_ssim %04f" % total_loss_ssim)
    for i in range(opts.sequence_num - 2):
        print(str(loss_ssim_per_sequence[i]))
    print("loss_mse %04f" % total_loss_mse)
    for i in range(opts.sequence_num - 2):
        print(str(loss_mse_per_sequence[i]))
    if data_version in convert_datasets:
        print("loss_converted_pos %04f" %
              (np.mean(loss_converted_pos_per_sequence)))
        loss_converted_pos_per_sequence = np.mean(
            loss_converted_pos_per_sequence, axis=0)
        for i in range(opts.sequence_num - 2):
            print(str(loss_converted_pos_per_sequence[i]))

    print("mse cov: " + str(np.cov(loss_mse_cov)))
    return
Exemple #15
0
def _set_requires_grad(nets, flag):
    """Toggle ``requires_grad`` for every parameter of each network in *nets*."""
    for net in nets:
        for p in net.parameters():
            p.requires_grad = flag


def _embed_test_pairs(subspace, model, loader, opts, rows, max_batches=None):
    """Encode image/caption pairs with the subspace and the DRIT encoder.

    Args:
        subspace: VSE model providing ``forward_emb``.
        model: DRIT model providing ``test_model2``.
        loader: data loader yielding (images, captions, lengths, ids).
        opts: parsed options (GPU id, etc.).
        rows: number of rows to allocate in the output matrices.
        max_batches: if given, stop after this many batches.

    Returns:
        Four CPU tensors of shape (rows, dim): subspace image embeddings,
        subspace caption embeddings, encoder image embeddings, encoder
        caption embeddings, scattered by sample id.
    """
    a = b = c = d = None
    for it, (images, captions, lengths, ids) in enumerate(loader):
        if max_batches is not None and it >= max_batches:
            break
        images = images.cuda(opts.gpu).detach()
        captions = captions.cuda(opts.gpu).detach()

        img_emb, cap_emb = subspace.forward_emb(images,
                                                captions,
                                                lengths,
                                                volatile=True)

        img = img_emb.view(images.size(0), -1, 32, 32)
        cap = cap_emb.view(images.size(0), -1, 32, 32)
        image1, text1 = model.test_model2(img, cap)
        img2 = image1.view(images.size(0), -1)
        cap2 = text1.view(images.size(0), -1)

        # Lazily allocate once the embedding dimensions are known.
        if a is None:
            a = np.zeros((rows, img_emb.size(1)))
            b = np.zeros((rows, cap_emb.size(1)))
            c = np.zeros((rows, img2.size(1)))
            d = np.zeros((rows, cap2.size(1)))

        # Scatter by sample id so row order matches the dataset order.
        a[ids] = img_emb.data.cpu().numpy().copy()
        b[ids] = cap_emb.data.cpu().numpy().copy()
        c[ids] = img2.data.cpu().numpy().copy()
        d[ids] = cap2.data.cpu().numpy().copy()

    return (torch.from_numpy(a), torch.from_numpy(b), torch.from_numpy(c),
            torch.from_numpy(d))


def _report_retrieval(tag, aa, bb, cc, dd, opts):
    """Print i2t/t2i retrieval metrics for subspace and encoder embeddings.

    Returns the summed encoder recalls (r1+r5+r10 in both directions),
    used by the caller as the model-selection score.
    """
    (r1, r5, r10, medr, meanr) = i2t(aa, bb, measure=opts.measure)
    print('{}: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, medr, r1, r5, r10))

    (r1i, r5i, r10i, medri, meanr) = t2i(aa, bb, measure=opts.measure)
    print('{}: subspace: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, medri, r1i, r5i, r10i))

    (r2, r3, r4, m1, m2) = i2t(cc, dd, measure=opts.measure)
    print('{}: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, m1, r2, r3, r4))

    (r2i, r3i, r4i, m1i, m2i) = t2i(cc, dd, measure=opts.measure)
    print('{}: encoder: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, m1i, r2i, r3i, r4i))

    return r2 + r3 + r4 + r2i + r3i + r4i


def main():
    """Jointly train the DRIT model and a VSE subspace on image-caption data.

    Pipeline: load vocabulary and data loaders, build/resume the models,
    pretrain the autoencoder, then alternate discriminator and encoder
    updates.  Checkpoints are written periodically and the encoder with the
    best retrieval recall on the quick validation split is kept as
    ``best_model.pth``.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    vocab = pickle.load(
        open(os.path.join(opts.vocab_path, '%s_vocab.pkl' % opts.data_name),
             'rb'))
    opts.vocab_size = len(vocab)
    torch.backends.cudnn.enabled = False
    # Load data loaders
    train_loader, val_loader = data.get_loaders(opts.data_name, vocab,
                                                opts.crop_size,
                                                opts.batch_size, opts.workers,
                                                opts)
    test_loader = data.get_test_loader('test', opts.data_name, vocab,
                                       opts.crop_size, opts.batch_size,
                                       opts.workers, opts)

    # model
    print('\n--- load subspace ---')
    subspace = model_2.VSE(opts)
    subspace.setgpu()
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:  # no checkpoint saved previously
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    score = 0.0  # best retrieval score seen so far
    subspace.train_start()

    # Autoencoder pretraining phase.
    for ep in range(ep0, opts.pre_iter):
        print('-----ep:{} --------'.format(ep))
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            if it >= opts.train_iter:
                break
            # input data
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()

            img, cap = subspace.train_emb(images,
                                          captions,
                                          lengths,
                                          ids,
                                          pre=True)  # [b,1024]

            subspace.pre_optimizer.zero_grad()
            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)

            model.pretrain_ae(img, cap)

            if opts.grad_clip > 0:
                clip_grad_norm(subspace.params, opts.grad_clip)

            subspace.pre_optimizer.step()

    discriminators = [model.disA, model.disB, model.disA_attr,
                      model.disB_attr]

    # Adversarial training phase.
    for ep in range(ep0, opts.n_ep):
        subspace.train_start()
        adjust_learning_rate(opts, subspace.optimizer, ep)
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            if it >= opts.train_iter:
                break
            # input data
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()

            img, cap = subspace.train_emb(images, captions, lengths,
                                          ids)  # [b,1024]

            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)

            subspace.optimizer.zero_grad()

            # Discriminator updates: only the discriminators get gradients.
            _set_requires_grad(discriminators, True)
            for _ in range(opts.niters_gan_d):  # 5
                model.update_D(img, cap)
            _set_requires_grad(discriminators, False)

            # Encoder updates using the new content loss.
            for _ in range(opts.niters_gan_enc):
                model.update_E(img, cap)

            subspace.optimizer.step()

            print('total_it: %d (ep %d, it %d), lr %09f' %
                  (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # Save network weights.
        if (ep + 1) % opts.n_ep == 0:
            print('save model')
            filename = os.path.join(opts.result_dir, opts.name)
            model.save('%s/final_model.pth' % (filename), ep, total_it)
            torch.save(subspace.state_dict(),
                       '%s/final_subspace.pth' % (filename))
        elif (ep + 1) % 10 == 0:
            print('save model')
            filename = os.path.join(opts.result_dir, opts.name)
            model.save('%s/%s_model.pth' % (filename, str(ep + 1)), ep,
                       total_it)
            torch.save(subspace.state_dict(),
                       '%s/%s_subspace.pth' % (filename, str(ep + 1)))

        # Periodic retrieval evaluation.
        if (ep + 1) % opts.model_save_freq == 0:
            subspace.val_start()

            # Quick pass over the first val_iter batches.
            aa, bb, cc, dd = _embed_test_pairs(
                subspace, model, test_loader, opts,
                rows=opts.val_iter * opts.batch_size,
                max_batches=opts.val_iter)
            curr = _report_retrieval('test640', aa, bb, cc, dd, opts)

            # Keep the checkpoint with the best encoder recall so far.
            if curr > score:
                score = curr
                print('save model')
                filename = os.path.join(opts.result_dir, opts.name)
                model.save('%s/best_model.pth' % (filename), ep, total_it)
                torch.save(subspace.state_dict(),
                           '%s/subspace.pth' % (filename))

            # Full pass over the whole test set.
            aa, bb, cc, dd = _embed_test_pairs(
                subspace, model, test_loader, opts,
                rows=len(test_loader.dataset))
            _report_retrieval('test5000', aa, bb, cc, dd, opts)

    return
Exemple #16
0
        pbar = tqdm.tqdm(train_loader, desc=f'Epoch#{epoch}')
        for i, data in enumerate(pbar, 1):
            loss_terms, images = trainer.optimize_parameters(
                prepare_input(data), update_g=i % 5 == 0, update_d=True)
            pbar.set_postfix(history.add(loss_terms))
            logger.image(images, epoch=epoch, prefix='train_')
        ''' validate '''
        trainer.netG.eval()
        for data in val_loader:
            loss_terms, images = trainer.optimize_parameters(
                prepare_input(data), update_g=False, update_d=False)
            history.add(loss_terms, log_suffix='_val')
            logger.image(images, epoch=epoch, prefix='val_')

        logger.scalar(history.metric(), epoch)
        if epoch % opt.save_epoch_freq == 0:
            print(f'saving the model at the end of epoch {epoch}')
            trainer.save('latest')
            trainer.save(epoch)
        trainer.update_learning_rate()

        # clean the state of extensions
        history.clear()
        logger.clear()


if __name__ == '__main__':
    cli = TrainOptions()
    # Extend the stock options with a per-subject filter before parsing.
    cli.parser.add_argument('--subjects', type=str, nargs='+')
    main(cli.parse())
Exemple #17
0
def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('--- load parameter ---')
    # outer_iter = opts.outer_iter
    # fade_iter = max(1.0, float(outer_iter / 2))
    epochs = opts.epoch
    batchsize = opts.batchsize
    # datasize = opts.datasize
    # datarange = opts.datarange
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)

    num_params = 0
    for param in tetGAN.parameters():
        num_params += param.numel()
    print('Total number of parameters in TET-GAN: %.3f M' % (num_params / 1e6))

    print('--- training ---')
    texture_class = 'base_gray_texture' in opts.dataset_class or 'skeleton_gray_texture' in opts.dataset_class
    if texture_class:
        tetGAN.load_state_dict(torch.load(opts.model))
        dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                    'style')
        val_dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                        'val')
        if 'base_gray_texture' in opts.dataset_class:
            few_size = 6
            batchsize = 2
            epochs = 1500
        elif 'skeleton_gray_texture' in opts.dataset_class:
            few_size = 30
            batchsize = 10
            epochs = 300
        fnames = load_trainset_batchfnames_dualnet(dataset_path,
                                                   batchsize,
                                                   few_size=few_size)
        val_fnames = sorted(os.listdir(val_dataset_path))
        style_fnames = sorted(os.listdir(dataset_path)[:few_size])
    else:
        dataset_path = os.path.join(opts.train_path, opts.dataset_class,
                                    'train')
        fnames = load_trainset_batchfnames_dualnet(dataset_path, batchsize)

    tetGAN.train()

    train_size = os.listdir(dataset_path)
    print('List of %d styles:' % (len(train_size)))

    result_dir = os.path.join(opts.result_dir, opts.dataset_class)
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    for epoch in range(epochs):
        for idx, fname in enumerate(fnames):
            x, y_real, y = prepare_batch(fname, 1, 1, centercropratio,
                                         augementratio, opts.gpu)
            losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None,
                                     1, 0)
            if (idx + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d]' %
                      (epoch + 1, epochs, idx + 1, len(fnames)))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
        if texture_class and ((epoch + 1) % (epochs / 20)) == 0:
            outname = 'save/' + 'val_epoch' + str(
                epoch +
                1) + '_' + opts.dataset_class + '_' + opts.save_model_name
            print('--- save model Epoch [%d/%d] ---' % (epoch + 1, epochs))
            torch.save(tetGAN.state_dict(), outname)

            print('--- validating model [%d/%d] ---' % (epoch + 1, epochs))
            for val_idx, val_fname in enumerate(val_fnames):
                v_fname = os.path.join(val_dataset_path, val_fname)
                random.shuffle(style_fnames)
                sty_fname = style_fnames[0]
                s_fname = os.path.join(dataset_path, sty_fname)
                with torch.no_grad():
                    val_content = load_image_dualnet(v_fname, load_type=1)
                    val_sty = load_image_dualnet(s_fname, load_type=0)
                    if opts.gpu != 0:
                        val_content = val_content.cuda()
                        val_sty = val_sty.cuda()
                    result = tetGAN(val_content, val_sty)
                    if opts.gpu != 0:
                        result = to_data(result)
                    result_filename = os.path.join(
                        result_dir,
                        str(epoch) + '_' + val_fname)
                    print(result_filename)
                    save_image(result[0], result_filename)
        elif not texture_class and ((epoch + 1) % 2) == 0:
            outname = 'save/' + 'epoch' + str(epoch +
                                              1) + '_' + opts.save_model_name
            print('--- save model ---')
            torch.save(tetGAN.state_dict(), outname)