Example #1
0
def main(_):
    """Build folders, logger, dataset, GAN model and solver, then dispatch to train() or test()."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Resume under the loaded model's timestamp, or stamp a fresh run id.
    run_id = (datetime.now().strftime("%Y%m%d-%H%M")
              if FLAGS.load_model is None else FLAGS.load_model)

    model_dir, log_dir, sample_dir, test_dir = utils.make_folders(
        is_train=FLAGS.is_train, cur_time=run_id)

    # Module-level logger that also writes into log_dir.
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    utils.init_logger(logger=logger, log_dir=log_dir,
                      is_train=FLAGS.is_train, name='main')
    utils.print_main_parameters(logger, flags=FLAGS, is_train=FLAGS.is_train)

    # Dataset wrapper; images are resized to a quarter of the original size.
    dataset = Dataset(name=FLAGS.dataset,
                      is_train=FLAGS.is_train,
                      resized_factor=0.25,
                      log_dir=log_dir)

    # Select the GAN variant requested on the command line.
    if FLAGS.method.lower() == 'wgan-gp':
        model = WGAN_GP()
    else:
        # Total iterations = epochs * batches-per-epoch, rounded up so a
        # final partial batch still counts as one iteration.
        total_iters = int(np.ceil(FLAGS.epoch * dataset.num_images /
                                  FLAGS.batch_size))
        model = DCGAN(image_shape=dataset.image_shape,
                      data_path=dataset(),
                      batch_size=FLAGS.batch_size,
                      z_dim=FLAGS.z_dim,
                      lr=FLAGS.learning_rate,
                      beta1=FLAGS.beta1,
                      total_iters=total_iters,
                      is_train=FLAGS.is_train,
                      log_dir=log_dir)

    # Solver drives the optimization and sampling loops.
    solver = Solver(model=model,
                    dataset_name=dataset.name,
                    batch_size=FLAGS.batch_size,
                    z_dim=FLAGS.z_dim,
                    log_dir=log_dir)

    # Keep only the most recent checkpoint on disk.
    saver = tf.train.Saver(max_to_keep=1)

    if FLAGS.is_train:
        train(solver, dataset, saver, logger, sample_dir, model_dir, log_dir)
    else:
        test(solver, saver, test_dir, model_dir, log_dir)
def main(_):
    """Set up folders, logging, dataset, Pix2pix model and solver, then train or test."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Initialize model and log folders: resume under the loaded model's
    # timestamp, otherwise stamp a brand-new run.
    if FLAGS.load_model is None:
        cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    else:
        cur_time = FLAGS.load_model

    model_dir, log_dir, sample_dir, _, test_dir = utils.make_folders(
        isTrain=FLAGS.is_train,
        curTime=cur_time,
        subfolder=os.path.join('generation', FLAGS.method))

    # Logger that also writes into log_dir.
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    utils.init_logger(logger=logger,
                      logDir=log_dir,
                      isTrain=FLAGS.is_train,
                      name='egmain')
    print_main_parameters(logger, flags=FLAGS, is_train=FLAGS.is_train)

    # Initialize dataset
    data = Dataset(name=FLAGS.dataset,
                   track='Generative_Dataset',
                   isTrain=FLAGS.is_train,
                   resizedFactor=FLAGS.resize_factor,
                   logDir=log_dir)

    # Initialize model; total_iters comes straight from the --iters flag.
    model = Pix2pix(decode_img_shape=data.decode_img_shape,
                    output_shape=data.single_img_shape,
                    num_classes=data.num_classes,
                    data_path=data(is_train=FLAGS.is_train),
                    batch_size=FLAGS.batch_size,
                    lr=FLAGS.learning_rate,
                    total_iters=FLAGS.iters,
                    is_train=FLAGS.is_train,
                    log_dir=log_dir,
                    resize_factor=FLAGS.resize_factor,
                    lambda_1=FLAGS.lambda_1)

    # Initialize solver
    solver = Solver(model=model, data=data, is_train=FLAGS.is_train)

    # Initialize saver; keep only the newest checkpoint.
    saver = tf.compat.v1.train.Saver(max_to_keep=1)

    # Fixed: was `FLAGS.is_train is True` — identity comparison against a
    # literal; a plain truthiness test is the idiomatic equivalent for a
    # boolean flag.
    if FLAGS.is_train:
        train(solver, saver, logger, model_dir, log_dir, sample_dir)
    else:
        test(solver, saver, model_dir, test_dir)
def main(_):
    """Create run folders and logger, then train or test the Pix2pix generator."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Initialize model and log folders: reuse the loaded model's timestamp
    # when resuming, otherwise stamp a new one.
    if FLAGS.load_model is None:
        cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    else:
        cur_time = FLAGS.load_model

    model_dir, log_dir, sample_dir, _, test_dir = utils.make_folders(is_train=FLAGS.is_train,
                                                                     cur_time=cur_time,
                                                                     subfolder='generation')

    # Logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    utils.init_logger(logger=logger, log_dir=log_dir, is_train=FLAGS.is_train, name='main')
    print_main_parameters(logger, flags=FLAGS, is_train=FLAGS.is_train)

    # Initialize Session (shared by model and solver)
    sess = tf.compat.v1.Session()

    # Initialize dataset (fixed stray double space before is_debug)
    data = eg_dataset.Dataset(name='generation', resize_factor=FLAGS.resize_factor,
                              is_train=FLAGS.is_train, log_dir=log_dir, is_debug=False)

    # Initialize model; total_iters is epochs * images / batch size, rounded
    # up so the final partial batch still counts as one iteration.
    pix2pix = Pix2pix(input_img_shape=data.input_img_shape,
                      gen_mode=FLAGS.gen_mode,
                      iden_model_dir=FLAGS.load_iden_model,
                      session=sess,
                      lr=FLAGS.learning_rate,
                      total_iters=int(np.ceil((FLAGS.epoch * data.num_train_imgs) / FLAGS.batch_size)),
                      is_train=FLAGS.is_train,
                      log_dir=log_dir,
                      lambda_1=FLAGS.lambda_1,
                      num_class=data.num_seg_class)

    # Initialize solver
    solver = Solver(data=data, gen_model=pix2pix, session=sess, flags=FLAGS, log_dir=log_dir)

    # Fixed: was `FLAGS.is_train is True` — identity comparison with a
    # literal; use plain truthiness for a boolean flag.
    if FLAGS.is_train:
        train(solver, logger, model_dir, log_dir, sample_dir)
    else:
        test(solver, model_dir, log_dir, test_dir)
Example #4
0
def main(_):
    """Select the dataset by model type and run the optimizer/dropout benchmark."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Grid of optimizers and dropout settings evaluated by train()/test().
    optimizer_options = [
        'SGDNesterov', 'Adagrad', 'RMSProp', 'AdaDelta', 'Adam'
    ]
    dropout_options = [False, True]

    # Initialize model and log folders
    if FLAGS.load_model is None:
        cur_time = datetime.now().strftime("%Y%m%d-%H%M")
    else:
        cur_time = FLAGS.load_model

    model_dir, log_dir = make_folders(is_train=FLAGS.is_train,
                                      base=FLAGS.model,
                                      cur_time=cur_time)
    init_logger(log_dir=log_dir, is_train=FLAGS.is_train)

    # Hoist the lowered model name once; a membership test replaces the
    # original awkwardly wrapped `or` chain of repeated .lower() calls.
    model_name = FLAGS.model.lower()
    if model_name in ('logistic', 'neural_network'):
        # MNIST for the shallow models. `bool(...)` replaces the redundant
        # `True if FLAGS.is_train else False` conditional.
        data = MNIST(log_dir=log_dir)
        data.info(use_logging=bool(FLAGS.is_train),
                  show_img=False)  # print basic information
    elif model_name == 'cnn':
        # CIFAR10 for the convolutional model.
        data = CIFAR10(log_dir=log_dir, is_train=FLAGS.is_train)
        data.info(use_logging=bool(FLAGS.is_train),
                  show_img=False,
                  smooth=True)
        # Data preprocessing: whitening or mean subtraction.
        data.preprocessing(use_whiten=FLAGS.is_whiten)
    else:
        # Give the error a message so the failure names the bad flag value.
        raise NotImplementedError('unsupported model: ' + FLAGS.model)

    if FLAGS.is_train:
        train(data, optimizer_options, dropout_options, model_dir, log_dir)
    else:
        test(data, optimizer_options, dropout_options, model_dir, log_dir)
Example #5
0
def main(_):
    """Prepare folders, data and the U-Net model, then run training or testing."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Resume under an existing run id, or stamp a new one.
    run_id = (datetime.now().strftime("%Y%m%d-%H%M")
              if FLAGS.load_model is None else FLAGS.load_model)

    model_dir, log_dir, sample_dir, test_dir = utils.make_folders(
        is_train=FLAGS.is_train, cur_time=run_id)
    init_logger(log_dir=log_dir, is_train=FLAGS.is_train)

    # Dataset plus a logged summary of its basic statistics.
    dataset = Dataset(name=FLAGS.dataset, log_dir=log_dir)
    dataset.info(use_logging=True, log_dir=log_dir)

    # TensorFlow session shared with the solver.
    sess = tf.Session()

    # U-Net model definition.
    model = Model(input_shape=dataset.input_shape,
                  output_shape=dataset.output_shape,
                  lr=FLAGS.learning_rate,
                  weight_decay=FLAGS.weight_decay,
                  total_iters=FLAGS.iters,
                  is_train=FLAGS.is_train,
                  log_dir=log_dir,
                  name='U-Net')

    # Solver and checkpoint saver (only the newest checkpoint is retained).
    solver = Solver(sess, model, dataset.mean_value)
    saver = tf.train.Saver(max_to_keep=1)

    if FLAGS.is_train:
        train(dataset, solver, saver, model_dir, log_dir, sample_dir)
    else:
        test(dataset, solver, saver, model_dir, test_dir)
Example #6
0
def main(_):
    """Build folders/logger/dataset, pick UNet or DenseUNet by method, then train or test."""
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index

    # Initialize model and log folders
    if FLAGS.load_model is None:
        curTime = datetime.now().strftime("%Y%m%d-%H%M%S")
    else:
        curTime = FLAGS.load_model

    modelDir, logDir, sampleDir, valDir, testDir = utils.make_folders(
        isTrain=FLAGS.is_train, curTime=curTime, subfolder=FLAGS.method)

    # Logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    utils.init_logger(logger=logger,
                      logDir=logDir,
                      isTrain=FLAGS.is_train,
                      name='main')
    utils.print_main_parameters(logger, flags=FLAGS, isTrain=FLAGS.is_train)

    # Initialize dataset
    data = Dataset(name=FLAGS.dataset,
                   isTrain=FLAGS.is_train,
                   resizedFactor=FLAGS.resize_factor,
                   logDir=logDir)

    # Initialize model. Fixed: `if not 'v5' in ...` -> idiomatic `not in`.
    # Method names containing 'v5' select the DenseUNet variant.
    if 'v5' not in FLAGS.method:
        model = UNet(decodeImgShape=data.decodeImgShape,
                     outputShape=data.singleImgShape,
                     numClasses=data.numClasses,
                     dataPath=data(isTrain=FLAGS.is_train),
                     batchSize=FLAGS.batch_size,
                     lr=FLAGS.learning_rate,
                     weightDecay=FLAGS.weight_decay,
                     totalIters=FLAGS.iters,
                     isTrain=FLAGS.is_train,
                     logDir=logDir,
                     method=FLAGS.method,
                     multi_test=FLAGS.multi_test,
                     advanced_multi_test=FLAGS.advanced_multi_test,
                     resize_factor=FLAGS.resize_factor,
                     use_dice_loss=FLAGS.use_dice_loss,
                     lambda_one=FLAGS.lambda_one,
                     name='UNet')
    else:
        model = DenseUNet(decodeImgShape=data.decodeImgShape,
                          outputShape=data.singleImgShape,
                          numClasses=data.numClasses,
                          dataPath=data(isTrain=FLAGS.is_train),
                          batchSize=FLAGS.batch_size,
                          lr=FLAGS.learning_rate,
                          weightDecay=FLAGS.weight_decay,
                          totalIters=FLAGS.iters,
                          isTrain=FLAGS.is_train,
                          logDir=logDir,
                          method=FLAGS.method,
                          multi_test=FLAGS.multi_test,
                          resize_factor=FLAGS.resize_factor,
                          use_dice_loss=FLAGS.use_dice_loss,
                          use_batch_norm=FLAGS.use_batch_norm,
                          lambda_one=FLAGS.lambda_one,
                          name='DenseUNet')

    # Initialize solver
    solver = Solver(model=model,
                    data=data,
                    is_train=FLAGS.is_train,
                    multi_test=FLAGS.multi_test)

    # Initialize saver; keep only the newest checkpoint.
    saver = tf.compat.v1.train.Saver(max_to_keep=1)

    # Fixed: was `FLAGS.is_train is True` — identity comparison with a
    # literal; plain truthiness is equivalent for a boolean flag.
    if FLAGS.is_train:
        train(solver, saver, logger, modelDir, logDir, sampleDir)
    else:
        test(solver, saver, modelDir, valDir, testDir, data)
Example #7
0
def train():
    """Train the sketch-to-image refinement GAN (net_ig generator, net_id
    discriminator) on top of a frozen pretrained autoencoder, logging losses
    and periodically computing FID against cached real-image features.

    All hyper-parameters come from `config`; checkpoints, sample images and a
    text log are written under SAVE_FOLDER/'GAN_'+TRIAL_NAME.
    """
    from benchmark import calc_fid, extract_feature_from_generator_fn, load_patched_inception_v3, real_image_loader, image_generator, image_generator_perm
    import lpips

    from config import IM_SIZE_GAN, BATCH_SIZE_GAN, NFC, NBR_CLS, DATALOADER_WORKERS, EPOCH_GAN, ITERATION_AE, GAN_CKECKPOINT
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, LOG_INTERVAL, SAVE_FOLDER, TRIAL_NAME, DATA_NAME, MULTI_GPU
    from config import FID_INTERVAL, FID_BATCH_NBR, PRETRAINED_AE_PATH
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    # Inception network used only for FID feature extraction (eval mode, no grads tracked here).
    real_features = None
    inception = load_patched_inception_v3().cuda()
    inception.eval()

    # LPIPS perceptual distance used as the reconstruction loss for non-'shoe' datasets.
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    # Create output folders and truncate the GAN log file for this run.
    saved_image_folder = saved_model_folder = None
    log_file_path = None
    if saved_image_folder is None:
        saved_image_folder, saved_model_folder = make_folders(
            SAVE_FOLDER, 'GAN_' + TRIAL_NAME)
        log_file_path = saved_image_folder + '/../gan_log.txt'
        log_file = open(log_file_path, 'w')
        log_file.close()

    # Paired dataset: one color image with three sketch renderings of it.
    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_GAN,
                                 rand_crop=True)
    print('the dataset contains %d images.' % len(dataset))
    # Infinite sampler + iter() turns the loader into an endless next()-able stream.
    dataloader = iter(
        DataLoader(dataset,
                   BATCH_SIZE_GAN,
                   sampler=InfiniteSamplerWrapper(dataset),
                   num_workers=DATALOADER_WORKERS,
                   pin_memory=True))

    from datasets import ImageFolder
    from datasets import trans_maker_augment as trans_maker

    # Separate RGB / sketch folders used only for the qualitative matrix images.
    dataset_rgb = ImageFolder(data_root_colorful, trans_maker(512))
    dataset_skt = ImageFolder(data_root_sketch_3, trans_maker(512))

    # Frozen autoencoder that produces the coarse image + style features.
    net_ae = AE(nfc=NFC, nbr_cls=NBR_CLS)

    if PRETRAINED_AE_PATH is None:
        PRETRAINED_AE_PATH = 'train_results/' + 'AE_' + TRIAL_NAME + '/models/%d.pth' % ITERATION_AE
    else:
        from config import PRETRAINED_AE_ITER
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER

    net_ae.load_state_dicts(PRETRAINED_AE_PATH)
    net_ae.cuda()
    net_ae.eval()

    # Generator architecture depends on the dataset family.
    RefineGenerator = None
    if DATA_NAME == 'celeba':
        from models import RefineGenerator_face as RefineGenerator
    elif DATA_NAME == 'art' or DATA_NAME == 'shoe':
        from models import RefineGenerator_art as RefineGenerator
    net_ig = RefineGenerator(nfc=NFC, im_size=IM_SIZE_GAN).cuda()
    net_id = Discriminator(nc=3).cuda(
    )  # we use the patch_gan, so the im_size for D should be 512 even if training image size is 1024

    if MULTI_GPU:
        net_ae = nn.DataParallel(net_ae)
        net_ig = nn.DataParallel(net_ig)
        net_id = nn.DataParallel(net_id)

    # EMA copy of the generator parameters used for sampling.
    net_ig_ema = copy_G_params(net_ig)

    opt_ig = optim.Adam(net_ig.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_id = optim.Adam(net_id.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # Optionally resume generator/discriminator/optimizers and the EMA params.
    if GAN_CKECKPOINT is not None:
        ckpt = torch.load(GAN_CKECKPOINT)
        net_ig.load_state_dict(ckpt['ig'])
        net_id.load_state_dict(ckpt['id'])
        net_ig_ema = ckpt['ig_ema']
        opt_ig.load_state_dict(ckpt['opt_ig'])
        opt_id.load_state_dict(ckpt['opt_id'])

    ## Running-average meters for the logged loss terms.
    losses_g_img = AverageMeter()
    losses_d_img = AverageMeter()
    losses_mse = AverageMeter()
    losses_rec_s = AverageMeter()

    losses_rec_ae = AverageMeter()

    # Fixed batches used for the periodic sample grids.
    fixed_skt = fixed_rgb = fixed_perm = None

    # History of [fid, fid_perm] pairs; seeded with zeros for the first log line.
    fid = [[0, 0]]

    for epoch in range(EPOCH_GAN):
        for iteration in tqdm(range(10000)):
            rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

            rgb_img = rgb_img.cuda()

            # Randomly pick one of the three sketch styles each iteration.
            # NOTE(review): randint(0, 3) is inclusive, so sketch_2 is chosen
            # for rd == 1 only while sketch_3 covers rd in {2, 3}.
            rd = random.randint(0, 3)
            if rd == 0:
                skt_img = skt_img_1.cuda()
            elif rd == 1:
                skt_img = skt_img_2.cuda()
            else:
                skt_img = skt_img_3.cuda()

            # Capture fixed visualization batches on the very first iteration.
            if iteration == 0:
                fixed_skt = skt_img_3[:8].clone().cuda()
                fixed_rgb = rgb_img[:8].clone()
                fixed_perm = true_randperm(fixed_rgb.shape[0], 'cuda')

            ### 1. train D: hinge loss on real vs. detached fake predictions.
            gimg_ae, style_feats = net_ae(skt_img, rgb_img)
            g_image = net_ig(gimg_ae, style_feats)

            pred_r = net_id(rgb_img)
            pred_f = net_id(g_image.detach())

            loss_d = d_hinge_loss(pred_r, pred_f)

            net_id.zero_grad()
            loss_d.backward()
            opt_id.step()

            # Track AE reconstruction quality (monitoring only; AE is frozen).
            loss_rec_ae = F.mse_loss(gimg_ae, rgb_img) + F.l1_loss(
                gimg_ae, rgb_img)
            losses_rec_ae.update(loss_rec_ae.item(), BATCH_SIZE_GAN)

            ### 2. train G: adversarial hinge loss + reconstruction term.
            pred_g = net_id(g_image)
            loss_g = g_hinge_loss(pred_g)

            # 'shoe' uses pixel losses; other datasets use LPIPS at 256px.
            if DATA_NAME == 'shoe':
                loss_mse = 10 * (F.l1_loss(g_image, rgb_img) +
                                 F.mse_loss(g_image, rgb_img))
            else:
                loss_mse = 10 * percept(
                    F.adaptive_avg_pool2d(g_image, output_size=256),
                    F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
            losses_mse.update(loss_mse.item() / BATCH_SIZE_GAN, BATCH_SIZE_GAN)

            loss_all = loss_g + loss_mse

            if DATA_NAME == 'shoe':
                ### the grey image reconstruction: with permuted (mismatched)
                ### styles, the generated image should still match in luminance.
                perm = true_randperm(BATCH_SIZE_GAN)
                img_ae_perm, style_feats_perm = net_ae(skt_img, rgb_img[perm])

                gimg_grey = net_ig(img_ae_perm, style_feats_perm)
                gimg_grey = gimg_grey.mean(dim=1, keepdim=True)
                real_grey = rgb_img.mean(dim=1, keepdim=True)
                loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
                loss_all += 10 * loss_rec_grey

            net_ig.zero_grad()
            loss_all.backward()
            opt_ig.step()

            # EMA update of the generator parameters (decay 0.999).
            for p, avg_p in zip(net_ig.parameters(), net_ig_ema):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            ### 3. logging
            losses_g_img.update(pred_g.mean().item(), BATCH_SIZE_GAN)
            losses_d_img.update(pred_r.mean().item(), BATCH_SIZE_GAN)

            if iteration % SAVE_IMAGE_INTERVAL == 0:  # save sample grids using the EMA generator
                with torch.no_grad():

                    # Swap EMA weights in, render samples, then restore.
                    backup_para_g = copy_G_params(net_ig)
                    load_params(net_ig, net_ig_ema)

                    gimg_ae, style_feats = net_ae(fixed_skt, fixed_rgb)
                    gmatch = net_ig(gimg_ae, style_feats)

                    # Mismatched style: same sketches, permuted RGB references.
                    gimg_ae_perm, style_feats = net_ae(fixed_skt,
                                                       fixed_rgb[fixed_perm])
                    gmismatch = net_ig(gimg_ae_perm, style_feats)

                    gimg = torch.cat([
                        F.interpolate(fixed_rgb, IM_SIZE_GAN),
                        F.interpolate(fixed_skt.repeat(1, 3, 1, 1),
                                      IM_SIZE_GAN), gmatch,
                        F.interpolate(gimg_ae, IM_SIZE_GAN), gmismatch,
                        F.interpolate(gimg_ae_perm, IM_SIZE_GAN)
                    ])

                    vutils.save_image(
                        gimg,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}.jpg',
                        normalize=True,
                        range=(-1, 1))
                    del gimg

                    make_matrix(
                        dataset_rgb, dataset_skt, net_ae, net_ig, 5,
                        f'{saved_image_folder}/img_iter_{epoch}_{iteration}_matrix.jpg'
                    )

                    load_params(net_ig, backup_para_g)

            if iteration % LOG_INTERVAL == 0:
                log_msg = 'Iter: [{0}/{1}] G: {losses_g_img.avg:.4f}  D: {losses_d_img.avg:.4f}  MSE: {losses_mse.avg:.4f}  Rec: {losses_rec_s.avg:.5f}  FID: {fid:.4f}'.format(
                    epoch,
                    iteration,
                    losses_g_img=losses_g_img,
                    losses_d_img=losses_d_img,
                    losses_mse=losses_mse,
                    losses_rec_s=losses_rec_s,
                    fid=fid[-1][0])

                print(log_msg)
                print('%.5f' % (losses_rec_ae.avg))

                # Append to the run's text log, then reset all meters.
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_file.write(log_msg + '\n')
                    log_file.close()

                losses_g_img.reset()
                losses_d_img.reset()
                losses_mse.reset()
                losses_rec_s.reset()
                losses_rec_ae.reset()

            if iteration % SAVE_MODEL_INTERVAL == 0 or iteration + 1 == 10000:
                print('Saving history model')
                torch.save(
                    {
                        'ig': net_ig.state_dict(),
                        'id': net_id.state_dict(),
                        'ae': net_ae.state_dict(),
                        'ig_ema': net_ig_ema,
                        'opt_ig': opt_ig.state_dict(),
                        'opt_id': opt_id.state_dict(),
                    }, '%s/%d.pth' % (saved_model_folder, epoch))

            if iteration % FID_INTERVAL == 0 and iteration > 1:
                print("calculating FID ...")
                fid_batch_images = FID_BATCH_NBR
                # Real-image features are computed once and cached to disk as a
                # pickled dict with 'feats'/'mean'/'cov' keys (despite the
                # .npy extension it is a pickle file).
                if real_features is None:
                    if os.path.exists('%s_fid_feats.npy' % (DATA_NAME)):
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))
                    else:
                        real_features = extract_feature_from_generator_fn(
                            real_image_loader(dataloader,
                                              n_batches=fid_batch_images),
                            inception)
                        real_mean = np.mean(real_features, 0)
                        real_cov = np.cov(real_features, rowvar=False)
                        pickle.dump(
                            {
                                'feats': real_features,
                                'mean': real_mean,
                                'cov': real_cov
                            }, open('%s_fid_feats.npy' % (DATA_NAME), 'wb'))
                        # Reload so real_features is the dict form in both branches.
                        real_features = pickle.load(
                            open('%s_fid_feats.npy' % (DATA_NAME), 'rb'))

                # FID of matched-style samples.
                sample_features = extract_feature_from_generator_fn(
                    image_generator(dataset,
                                    net_ae,
                                    net_ig,
                                    n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid = calc_fid(sample_features,
                                   real_mean=real_features['mean'],
                                   real_cov=real_features['cov'])
                # FID of permuted (mismatched) style samples.
                sample_features_perm = extract_feature_from_generator_fn(
                    image_generator_perm(dataset,
                                         net_ae,
                                         net_ig,
                                         n_batches=fid_batch_images),
                    inception,
                    total=fid_batch_images)
                cur_fid_perm = calc_fid(sample_features_perm,
                                        real_mean=real_features['mean'],
                                        real_cov=real_features['cov'])

                fid.append([cur_fid, cur_fid_perm])
                print('fid:', fid)
                if log_file_path is not None:
                    log_file = open(log_file_path, 'a')
                    log_msg = 'fid: %.5f, %.5f' % (fid[-1][0], fid[-1][1])
                    log_file.write(log_msg + '\n')
                    log_file.close()
Example #8
0
def train(netG, netD, opt_G, opt_D, opt_E):
	"""Adversarial training loop with a latent-reconstruction head.

	netG generates images from a representation r sampled given z; netD both
	discriminates real/fake and regresses z back from images. opt_G/opt_D/opt_E
	update generator, discriminator and encoder parameters respectively.
	Relies on module-level globals: device, Z_DIM, R_DIM, BATCH_SIZE,
	MAX_ITERATION, SAVE_*, LOG_INTERVAL, LAMBDA_G, BETA_KL, M_r, dataloader,
	loss_bce, loss_mse, KL_Loss, kl_divergence, and the save_* helpers.
	"""
	# Running sums of the logged quantities, reset every LOG_INTERVAL iterations.
	D_real = D_fake = D_z_kl = G_real = Z_recon = R_kl = 0
	# Fixed noise batch so saved sample images are comparable across iterations.
	fixed_z = torch.randn(64, Z_DIM).to(device)

	saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, TRIAL_NAME)

	for n_iter in tqdm.tqdm(range(0, MAX_ITERATION+1)):

		# Periodic sample images and model checkpoints (including iteration 0).
		if n_iter % SAVE_IMAGE_INTERVAL == 0:
			save_image_from_z(netG, fixed_z, pjoin(saved_image_folder, "z_%d.jpg"%n_iter))
			save_image_from_r(netG, R_DIM, pjoin(saved_image_folder, "r_%d.jpg"%n_iter))
		if n_iter % SAVE_MODEL_INTERVAL == 0:
			save_model(netG, netD, pjoin(saved_model_folder, "%d.pth"%n_iter))

		### 0. prepare data
		real_image = next(dataloader)[0].to(device)

		z = torch.randn(BATCH_SIZE, Z_DIM).to(device)
		# e(r|z) as the likelihood of r given z
		r_sampler = netG.r_sampler(z)
		g_image = netG.generate(r_sampler.sample())

		### 1. Train Discriminator on real and generated data
		netD.zero_grad()
		pred_f = netD.discriminate(g_image.detach())
		pred_r, rec_z = netD(real_image)
		# BCE on sigmoid outputs: real -> 1, fake -> 0.
		d_loss = loss_bce(torch.sigmoid(pred_r), torch.ones(pred_r.size()).to(device)) \
			+ loss_bce(torch.sigmoid(pred_f), torch.zeros(pred_f.size()).to(device))
		# KL regularizer on the z reconstructed from real images.
		q_loss = KL_Loss(rec_z)
		#d_loss.backward()
		total_loss = d_loss + q_loss
		total_loss.backward()
		opt_D.step()

		# record the loss values
		D_real += torch.sigmoid(pred_r).mean().item()
		D_fake += torch.sigmoid(pred_f).mean().item()
		D_z_kl += q_loss.item()

		### 2. Train Generator
		netD.zero_grad()
		netG.zero_grad()
		# q(z|x) as the posterior of z given x
		pred_g, z_posterior = netD(g_image)
		# GAN loss for generator
		g_loss = LAMBDA_G * loss_bce(torch.sigmoid(pred_g), torch.ones(pred_g.size()).to(device))
		# reconstruction loss of z
		## TODO
		## question here: as stated in the paper-algorithm-1: this part should be a - log(q(z|x)) instead of mse
		recon_loss = loss_mse(z_posterior, z)
		# kl loss between e(r|z) || m(r) as a variational inference
		#kl_loss = BETA_KL * torch.distributions.kl.kl_divergence(r_likelihood, M_r).mean()
		kl_loss = BETA_KL * kl_divergence(r_sampler, M_r).mean()
		total_loss = g_loss + recon_loss + kl_loss
		total_loss.backward()
		# Encoder and generator are stepped together on the combined loss.
		opt_E.step()
		opt_G.step()

		# record the loss values
		G_real += torch.sigmoid(pred_g).mean().item()
		Z_recon += recon_loss.item()
		R_kl += kl_loss.item()

		# Print interval-averaged metrics, then reset the accumulators.
		if n_iter % LOG_INTERVAL == 0 and n_iter > 0:
			print("D(x): %.5f    D(G(z)): %.5f    D_kl: %.5f    G(z): %.5f    Z_rec: %.5f    R_kl: %.5f"%\
				(D_real/LOG_INTERVAL, D_fake/LOG_INTERVAL, D_z_kl/LOG_INTERVAL, G_real/LOG_INTERVAL, Z_recon/LOG_INTERVAL, R_kl/LOG_INTERVAL))
			D_real = D_fake = D_z_kl = G_real = Z_recon = R_kl = 0
def train():
    """Stage-1 trainer for the sketch-to-image autoencoder.

    Jointly fits three networks — ``StyleEncoder``, ``ContentEncoder`` and
    ``Decoder`` — by alternating two phases every iteration:

    1. a self-supervised phase on ``SelfSupervisedDataset`` (pseudo-class
       prediction, style/content feature-consistency, reconstruction), and
    2. a plain autoencoder phase on ``PairedMultiDataset`` (reconstruct the
       photo from one of its sketches plus its own style vector).

    All hyper-parameters come from ``config``.  Side effects: writes sample
    images, model checkpoints and a text log under ``SAVE_FOLDER``.  Takes no
    arguments and returns nothing.
    """
    # Config is imported lazily (inside the function) so importing this module
    # does not require a fully populated config.
    from config import IM_SIZE_AE, BATCH_SIZE_AE, NFC, NBR_CLS, DATALOADER_WORKERS, ITERATION_AE
    from config import SAVE_IMAGE_INTERVAL, SAVE_MODEL_INTERVAL, SAVE_FOLDER, TRIAL_NAME, LOG_INTERVAL
    from config import DATA_NAME
    from config import data_root_colorful, data_root_sketch_1, data_root_sketch_2, data_root_sketch_3

    # Paired dataset: one color photo aligned with three sketch variants.
    dataset = PairedMultiDataset(data_root_colorful,
                                 data_root_sketch_1,
                                 data_root_sketch_2,
                                 data_root_sketch_3,
                                 im_size=IM_SIZE_AE,
                                 rand_crop=True)
    print(len(dataset))
    # InfiniteSamplerWrapper makes the iterator inexhaustible, so plain
    # next(dataloader) is safe inside the unbounded loop below.
    dataloader = iter(DataLoader(dataset, BATCH_SIZE_AE, \
        sampler=InfiniteSamplerWrapper(dataset), num_workers=DATALOADER_WORKERS, pin_memory=True))

    # Self-supervised dataset: serves NBR_CLS images at a time as pseudo-classes
    # for the style encoder's classification head.
    dataset_ss = SelfSupervisedDataset(data_root_colorful,
                                       data_root_sketch_3,
                                       im_size=IM_SIZE_AE,
                                       nbr_cls=NBR_CLS,
                                       rand_crop=True)
    print(len(dataset_ss), len(dataset_ss.frame))
    dataloader_ss = iter(DataLoader(dataset_ss, BATCH_SIZE_AE, \
        sampler=InfiniteSamplerWrapper(dataset_ss), num_workers=DATALOADER_WORKERS, pin_memory=True))

    # The three trainable modules, all on GPU.
    style_encoder = StyleEncoder(nfc=NFC, nbr_cls=NBR_CLS).cuda()
    content_encoder = ContentEncoder(nfc=NFC).cuda()
    decoder = Decoder(nfc=NFC).cuda()

    # One Adam optimizer per module (classic GAN-style betas).
    opt_c = optim.Adam(content_encoder.parameters(),
                       lr=2e-4,
                       betas=(0.5, 0.999))
    opt_s = optim.Adam(style_encoder.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(decoder.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # Build a fresh pseudo-class head and move it to GPU before any training.
    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()

    # Optionally resume modules and optimizer states from a previous AE run.
    from config import PRETRAINED_AE_PATH, PRETRAINED_AE_ITER
    if PRETRAINED_AE_PATH is not None:
        PRETRAINED_AE_PATH = PRETRAINED_AE_PATH + '/models/%d.pth' % PRETRAINED_AE_ITER
        ckpt = torch.load(PRETRAINED_AE_PATH)

        print(PRETRAINED_AE_PATH)

        style_encoder.load_state_dict(ckpt['s'])
        content_encoder.load_state_dict(ckpt['c'])
        decoder.load_state_dict(ckpt['d'])

        opt_c.load_state_dict(ckpt['opt_c'])
        opt_s.load_state_dict(ckpt['opt_s'])
        opt_d.load_state_dict(ckpt['opt_d'])
        print('loaded pre-trained AE')

    # Reset the classification head again (the checkpoint's head belongs to a
    # different pseudo-class set) and give it its own optimizer.
    style_encoder.reset_cls()
    style_encoder.final_cls.cuda()
    opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                           lr=2e-4,
                           betas=(0.5, 0.999))

    saved_image_folder, saved_model_folder = make_folders(
        SAVE_FOLDER, 'AE_' + TRIAL_NAME)
    # Truncate/create the log file up front; it is re-opened per log interval.
    log_file_path = saved_image_folder + '/../ae_log.txt'
    log_file = open(log_file_path, 'w')
    log_file.close()
    ## for logging
    losses_sf_consist = AverageMeter()
    losses_cf_consist = AverageMeter()
    losses_cls = AverageMeter()
    losses_rec_rd = AverageMeter()
    losses_rec_org = AverageMeter()
    losses_rec_grey = AverageMeter()

    # LPIPS perceptual loss (third-party, imported lazily).
    import lpips
    percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

    for iteration in tqdm(range(ITERATION_AE)):

        # Periodically rotate to the next pseudo-class subset; the style
        # encoder's classification head (and its optimizer) must be rebuilt to
        # match, and the base learning rates are halved from then on.
        if iteration % (
            (NBR_CLS * 100) // BATCH_SIZE_AE) == 0 and iteration > 1:
            dataset_ss._next_set()
            dataloader_ss = iter(
                DataLoader(dataset_ss,
                           BATCH_SIZE_AE,
                           sampler=InfiniteSamplerWrapper(dataset_ss),
                           num_workers=DATALOADER_WORKERS,
                           pin_memory=True))
            style_encoder.reset_cls()
            opt_s_cls = optim.Adam(style_encoder.final_cls.parameters(),
                                   lr=2e-4,
                                   betas=(0.5, 0.999))

            opt_s.param_groups[0]['lr'] = 1e-4
            opt_d.param_groups[0]['lr'] = 1e-4

        ### 1. train the encoder with self-supervision methods
        # Batch: two photo augmentations, four sketch variants, pseudo-label.
        rgb_img_rd, rgb_img_org, skt_org, skt_bold, skt_erased, skt_erased_bold, img_idx = next(
            dataloader_ss)
        rgb_img_rd = rgb_img_rd.cuda()
        rgb_img_org = rgb_img_org.cuda()
        img_idx = img_idx.cuda()

        # Sketches are resized to the decoder's working resolution (512).
        skt_org = F.interpolate(skt_org, size=512).cuda()
        skt_bold = F.interpolate(skt_bold, size=512).cuda()
        skt_erased = F.interpolate(skt_erased, size=512).cuda()
        skt_erased_bold = F.interpolate(skt_erased_bold, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        # Style vectors + pseudo-class logits for both photo augmentations.
        style_vector_rd, pred_cls_rd = style_encoder(rgb_img_rd)
        style_vector_org, pred_cls_org = style_encoder(rgb_img_org)

        # Content features for every sketch variant of the same image.
        content_feats = content_encoder(skt_org)
        content_feats_bold = content_encoder(skt_bold)
        content_feats_erased = content_encoder(skt_erased)
        content_feats_eb = content_encoder(skt_erased_bold)

        # Decode from one randomly chosen sketch variant per iteration.
        rd = random.randint(0, 3)
        gimg_rd = None
        if rd == 0:
            gimg_rd = decoder(content_feats, style_vector_rd)
        elif rd == 1:
            gimg_rd = decoder(content_feats_bold, style_vector_rd)
        elif rd == 2:
            gimg_rd = decoder(content_feats_erased, style_vector_rd)
        elif rd == 3:
            gimg_rd = decoder(content_feats_eb, style_vector_rd)


        # Content consistency: all sketch variants should yield the same
        # content features as the original sketch.
        loss_cf_consist = loss_for_list_perm(F.mse_loss, content_feats_bold, content_feats) +\
                            loss_for_list_perm(F.mse_loss, content_feats_erased, content_feats) +\
                                loss_for_list_perm(F.mse_loss, content_feats_eb, content_feats)

        # Style consistency: pull the two augmentations of the same photo
        # together, push apart styles of shuffled (mismatched) photos.
        loss_sf_consist = 0
        for loss_idx in range(3):
            loss_sf_consist += -F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx].detach()).mean() + \
                                    F.cosine_similarity(style_vector_rd[loss_idx], style_vector_org[loss_idx][torch.randperm(BATCH_SIZE_AE)].detach()).mean()

        # Pseudo-class prediction loss on both augmentations.
        loss_cls = F.cross_entropy(pred_cls_rd, img_idx) + F.cross_entropy(
            pred_cls_org, img_idx)
        # Reconstruction vs. the un-augmented photo; non-shoe datasets add an
        # LPIPS perceptual term at 256px, the shoe dataset uses L1 instead.
        loss_rec_rd = F.mse_loss(gimg_rd, rgb_img_org)
        if DATA_NAME != 'shoe':
            loss_rec_rd += percept(
                F.adaptive_avg_pool2d(gimg_rd, output_size=256),
                F.adaptive_avg_pool2d(rgb_img_org, output_size=256)).sum()
        else:
            loss_rec_rd += F.l1_loss(gimg_rd, rgb_img_org)

        loss_total = loss_cls + loss_sf_consist + loss_rec_rd + loss_cf_consist  #+ loss_kl_c + loss_kl_s
        loss_total.backward()

        # One backward pass updates all three modules plus the class head.
        opt_s.step()
        opt_s_cls.step()
        opt_c.step()
        opt_d.step()

        ### 2. train as AutoEncoder
        rgb_img, skt_img_1, skt_img_2, skt_img_3 = next(dataloader)

        rgb_img = rgb_img.cuda()

        # Pick one of the three sketch styles at random.
        # NOTE(review): rd==1 falls into the `else` branch, so skt_img_2 is
        # only chosen for rd==1 via... actually rd==1 selects skt_img_2 and
        # rd in {2, 3} both select skt_img_3 — skt_img_3 is sampled twice as
        # often; presumably intentional weighting, confirm with the authors.
        rd = random.randint(0, 3)
        if rd == 0:
            skt_img = skt_img_1
        elif rd == 1:
            skt_img = skt_img_2
        else:
            skt_img = skt_img_3

        skt_img = F.interpolate(skt_img, size=512).cuda()

        style_encoder.zero_grad()
        decoder.zero_grad()
        content_encoder.zero_grad()

        # Plain autoencoding: photo's own style + its sketch's content.
        style_vector, _ = style_encoder(rgb_img)
        content_feats = content_encoder(skt_img)
        gimg = decoder(content_feats, style_vector)

        loss_rec_org = F.mse_loss(gimg, rgb_img)
        if DATA_NAME != 'shoe':
            loss_rec_org += percept(
                F.adaptive_avg_pool2d(gimg, output_size=256),
                F.adaptive_avg_pool2d(rgb_img, output_size=256)).sum()
        #else:
        #    loss_rec_org += F.l1_loss(gimg, rgb_img)

        loss_rec = loss_rec_org
        if DATA_NAME == 'shoe':
            ### the grey image reconstruction
            # Decode with permuted (mismatched) styles but require the
            # greyscale structure to still match the real photo.
            perm = true_randperm(BATCH_SIZE_AE)
            gimg_perm = decoder(content_feats, [s[perm] for s in style_vector])
            gimg_grey = gimg_perm.mean(dim=1, keepdim=True)
            real_grey = rgb_img.mean(dim=1, keepdim=True)
            loss_rec_grey = F.mse_loss(gimg_grey, real_grey)
            loss_rec += loss_rec_grey
        loss_rec.backward()

        opt_s.step()
        opt_d.step()
        opt_c.step()

        ### Logging
        losses_cf_consist.update(loss_cf_consist.mean().item(), BATCH_SIZE_AE)
        losses_sf_consist.update(loss_sf_consist.mean().item(), BATCH_SIZE_AE)
        losses_cls.update(loss_cls.mean().item(), BATCH_SIZE_AE)
        losses_rec_rd.update(loss_rec_rd.item(), BATCH_SIZE_AE)
        losses_rec_org.update(loss_rec_org.item(), BATCH_SIZE_AE)
        if DATA_NAME == 'shoe':
            losses_rec_grey.update(loss_rec_grey.item(), BATCH_SIZE_AE)

        if iteration % LOG_INTERVAL == 0:
            # NOTE(review): losses_rec_grey is never updated for non-'shoe'
            # datasets, so its .avg is logged without ever receiving data —
            # harmless if AverageMeter defaults avg to 0, but verify.
            log_msg = 'Train Stage 1: AE: \nrec_rd: %.4f  rec_org: %.4f  cls: %.4f  style_consist: %.4f  content_consist: %.4f  rec_grey: %.4f'%(losses_rec_rd.avg, \
                    losses_rec_org.avg, losses_cls.avg, losses_sf_consist.avg, losses_cf_consist.avg, losses_rec_grey.avg)

            print(log_msg)

            if log_file_path is not None:
                log_file = open(log_file_path, 'a')
                log_file.write(log_msg + '\n')
                log_file.close()

            losses_sf_consist.reset()
            losses_cls.reset()
            losses_rec_rd.reset()
            losses_rec_org.reset()
            losses_cf_consist.reset()
            losses_rec_grey.reset()

        if iteration % SAVE_IMAGE_INTERVAL == 0:
            # NOTE(review): `range=` was renamed `value_range=` in newer
            # torchvision; keep as-is for the pinned version, but confirm.
            vutils.save_image(torch.cat([
                rgb_img_rd,
                F.interpolate(skt_org.repeat(1, 3, 1, 1), size=512), gimg_rd
            ]),
                              '%s/rd_%d.jpg' % (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))
            # For non-shoe data, also render a style-permuted batch; for shoe
            # data, gimg_perm from the grey-reconstruction branch is reused.
            if DATA_NAME != 'shoe':
                with torch.no_grad():
                    perm = true_randperm(BATCH_SIZE_AE)
                    gimg_perm = decoder([c for c in content_feats],
                                        [s[perm] for s in style_vector])
            vutils.save_image(torch.cat([
                rgb_img,
                F.interpolate(skt_img.repeat(1, 3, 1, 1), size=512), gimg,
                gimg_perm
            ]),
                              '%s/org_%d.jpg' %
                              (saved_image_folder, iteration),
                              normalize=True,
                              range=(-1, 1))

        if iteration % SAVE_MODEL_INTERVAL == 0:
            print('Saving history model')
            torch.save(
                {
                    's': style_encoder.state_dict(),
                    'd': decoder.state_dict(),
                    'c': content_encoder.state_dict(),
                    'opt_c': opt_c.state_dict(),
                    'opt_s_cls': opt_s_cls.state_dict(),
                    'opt_s': opt_s.state_dict(),
                    'opt_d': opt_d.state_dict(),
                }, '%s/%d.pth' % (saved_model_folder, iteration))

    # Final checkpoint after the loop completes.
    torch.save(
        {
            's': style_encoder.state_dict(),
            'd': decoder.state_dict(),
            'c': content_encoder.state_dict(),
            'opt_c': opt_c.state_dict(),
            'opt_s_cls': opt_s_cls.state_dict(),
            'opt_s': opt_s.state_dict(),
            'opt_d': opt_d.state_dict(),
        }, '%s/%d.pth' % (saved_model_folder, ITERATION_AE))
Exemple #10
0
def save_user_image(user_photo):
    """Download a user's photo into the input folder and return its paths.

    Ensures the working folders exist, saves the photo under
    ``input/<file_id>.jpg`` and returns the pair
    ``(input_image_path, output_image_path)``; the output path is where the
    processed result is expected to be written later.
    """
    make_folders()
    file_name = f'{user_photo.file_id}.jpg'
    input_image_path = os.path.join('input', file_name)
    output_image_path = os.path.join('output', file_name)
    user_photo.download(input_image_path)
    return input_image_path, output_image_path