Esempio n. 1
0
    def stylelistClick(self):
        """Run the currently selected style model on the loaded source image.

        Does nothing when no source image has been loaded yet or when the
        style list has no valid current row.  If ``self.style_mode`` holds
        ``None`` for the selected style, the generator is built from the
        parameter file in ``self.style_model_path`` instead.
        """
        if self.src_img_path is None:
            return

        style_index = self.ui.list_style.currentIndex().row()
        # An invalid Qt model index reports row -1; indexing the lists with it
        # would silently pick the *last* style instead of doing nothing.
        if style_index < 0:
            return

        model = self.style_mode[style_index]
        if model is None:
            # The parameters are a pickled dict stored in a .npy file; NumPy
            # >= 1.16.3 refuses to unpickle unless explicitly allowed.
            # NOTE(review): pickle can execute code - only load trusted files.
            # NOTE(review): the built model is not cached back into
            # self.style_mode, so it is rebuilt on every click - confirm intent.
            parameter_dict = np.load(self.style_model_path[style_index],
                                     allow_pickle=True).item()
            model = Generator(in_chanel=3,
                              out_chanel=3,
                              parameter_dict=parameter_dict)

        self.process_src_img(model, self.src_img_path)
Esempio n. 2
0
def make_model(src_vocab,
               tar_vocab,
               N=6,
               d_model=512,
               d_ff=2014,
               h=8,
               dropout=0.1):
    """Build an encoder-decoder Transformer and initialise its weight matrices.

    Args:
        src_vocab: source vocabulary size, passed to the source ``Embedding``.
        tar_vocab: target vocabulary size, passed to the target ``Embedding``
            and to the output ``Generator``.
        N: layer count passed to both ``Encoder`` and ``Decoder``.
        d_model: model width shared by every sub-module.
        d_ff: hidden width of the position-wise feed-forward sub-layers.
            NOTE(review): the usual Transformer value is 2048; 2014 looks like
            a typo but is kept so existing checkpoints still load - confirm.
        h: number of attention heads.
        dropout: dropout rate passed to the sub-modules that take one.

    Returns:
        The assembled ``GeneralEncoderDecoder``.  Every parameter with more
        than one dimension has been re-initialised with Xavier-uniform.
    """
    clone = copy.deepcopy

    # Prototype sub-modules; every use below receives its own deep copy, so
    # the prototypes' parameters are never shared between sub-layers.
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = PositionalEncoding(d_model, dropout)

    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout),
        N)
    decoder = Decoder(
        DecoderLayer(d_model, clone(attention), clone(attention),
                     clone(feed_forward), dropout),
        N)
    src_embed = nn.Sequential(Embedding(d_model, src_vocab),
                              clone(pos_encoding))
    tgt_embed = nn.Sequential(Embedding(d_model, tar_vocab),
                              clone(pos_encoding))
    generator = Generator(d_model, tar_vocab)

    model = GeneralEncoderDecoder(encoder, decoder, src_embed, tgt_embed,
                                  generator)

    # Random initialisation of the parameters is very important: re-draw
    # every weight matrix; 1-D parameters (e.g. biases) are left untouched.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return model
# Shuffled loaders for the two image domains, plus live iterators over them.
loader_A = torch.utils.data.DataLoader(dataset=datasetA,
                                       batch_size=opt.batchSize,
                                       shuffle=True,
                                       num_workers=2)
loaderA = iter(loader_A)
loader_B = torch.utils.data.DataLoader(dataset=datasetB,
                                       batch_size=opt.batchSize,
                                       shuffle=True,
                                       num_workers=2)
loaderB = iter(loader_B)
###########   MODEL   ###########
ndf = opt.ndf
ngf = opt.ngf
nc = 3

G_AB = Generator(opt.input_nc, opt.output_nc, opt.ngf)
G_BA = Generator(opt.output_nc, opt.input_nc, opt.ngf)

# Both generators must come from pre-trained weights.  Continuing with
# randomly initialised networks would silently produce garbage, so stop.
if opt.G_AB == '' or opt.G_BA == '':
    raise SystemExit('ERROR! G_AB and G_BA must be provided!')
print('Warning! Loading pre-trained weights.')
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
# load_state_dict copies the tensors into the models' own parameters, which
# are moved to the GPU just below when opt.cuda is set.
G_AB.load_state_dict(torch.load(opt.G_AB, map_location='cpu'))
G_BA.load_state_dict(torch.load(opt.G_BA, map_location='cpu'))

if opt.cuda:
    G_AB.cuda()
    G_BA.cuda()

###########   GLOBAL VARIABLES   ###########
input_nc = opt.input_nc
Esempio n. 4
0
    nrow=6,
    normalize=True)

print(data_new.size())
# TensorDataset takes its tensors positionally; the data_tensor=/target_tensor=
# keywords only existed in PyTorch < 0.4 and raise TypeError on current releases.
dataset = Data.TensorDataset(data_new, train_y)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

###############   MODEL   ####################
ndf = opt.ndf
ngf = opt.ngf
nc = 1

netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
# Only move the networks to the GPU when one exists; an unconditional
# .cuda() crashes on CPU-only machines.
if torch.cuda.is_available():
    netD.cuda()
    netG.cuda()

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
##########   GLOBAL VARIABLES   ###########
#noise_all = torch.FloatTensor(20,opt.nz,1,1)
Esempio n. 5
0
def weights_init(m):
    """DCGAN-style initialisation; intended for ``net.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02).  Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02).  In both cases the bias is filled with 0.  Any
    other module is left unchanged.

    Layers built without a bias (``bias=False``) or without affine
    parameters (``affine=False``) expose those attributes as ``None``.  They
    are skipped instead of raising ``AttributeError``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    # getattr also tolerates container modules whose name merely contains
    # "Conv"/"BatchNorm" but which own no weight/bias of their own.
    if getattr(m, 'weight', None) is not None:
        m.weight.data.normal_(weight_mean, 0.02)
    if getattr(m, 'bias', None) is not None:
        m.bias.data.fill_(0)


ndf, ngf, nc = opt.ndf, opt.ngf, 3

# Build both networks, optionally on the GPU.
netD = Discriminator(opt.input_nc, opt.output_nc, ndf)
netG = Generator(opt.input_nc, opt.output_nc, ngf)
if opt.cuda:
    netD.cuda()
    netG.cuda()

# Re-initialise the generator, then the discriminator, and show both.
netG.apply(weights_init)
netD.apply(weights_init)
print(netD)
print(netG)

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()
criterionL1 = nn.L1Loss()
optimizerD = torch.optim.Adam(
    netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
Esempio n. 6
0
def _mean(values):
    """Return the arithmetic mean of ``values`` (NaN when it is empty)."""
    return sum(values) / len(values) if values else float('nan')


def main(args):
    """Train a GAN on CelebA, saving results and a checkpoint every epoch.

    Args:
        args: parsed command-line namespace.  Fields read here: ``model``,
            ``dataset``, ``epochs``, ``obj_step``, ``batch_size``, ``lr``,
            ``flag``, ``num_workers``, ``resume``, ``clipping`` and
            ``num_critic``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.model`` is unsupported.
        FileNotFoundError: if ``args.resume`` is set but the checkpoint file
            does not exist.
    """
    #===========================================================================
    # Set the file name format
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step, args.batch_size,
        args.lr, args.flag)

    # Set the results file path
    RESULT_FILE_NAME = FILE_NAME_FORMAT + '_results.pkl'
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH, RESULT_FILE_NAME)
    # Set the checkpoint file path
    CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '.ckpt'
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH, CHECKPOINT_FILE_NAME)
    BEST_CHECKPOINT_FILE_NAME = FILE_NAME_FORMAT + '_best.ckpt'
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             BEST_CHECKPOINT_FILE_NAME)

    # Set the random seed same for reproducibility
    random.seed(190811)
    torch.manual_seed(190811)
    torch.cuda.manual_seed_all(190811)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        # `raise` rather than `assert False`: asserts vanish under `python -O`.
        raise ValueError("Please select the proper dataset.")

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model (the *_noBN variants build both networks without BN)
    if args.model in ['WGAN', 'DCGAN']:
        use_bn = True
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        use_bn = False
    else:
        raise ValueError("Please select the proper model.")
    generator = Generator(BN=use_bn)
    discriminator = Discriminator(BN=use_bn)

    # Check DataParallel available
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Check CUDA available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Set loss function and optimizer.  Only the DCGAN variants use BCE; the
    # WGAN variants pass criterion=None to train().
    if args.model in ['DCGAN', 'DCGAN_noBN']:
        criterion = nn.BCELoss()
    else:
        criterion = None
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Train and validate the model
    start_epoch = 0
    best_metric = float("inf")
    # Fixed noise reused for every validation image.
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)

    # Initialize the result lists
    train_loss_G = []
    train_loss_D = []
    train_distance = []

    if args.resume:
        if not os.path.exists(CHECKPOINT_FILE_PATH):
            raise FileNotFoundError('No checkpoint file!')
        checkpoint = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch']
        step_counter.current_step = checkpoint['current_step']
        train_loss_G = checkpoint['train_loss_G']
        train_loss_D = checkpoint['train_loss_D']
        train_distance = checkpoint['train_distance']
        best_metric = checkpoint['best_metric']

    # Save the training information
    result_data = {}
    result_data['model'] = args.model
    result_data['dataset'] = args.dataset
    result_data['target_epoch'] = args.epochs
    result_data['batch_size'] = args.batch_size

    # Make sure the output directories exist
    os.makedirs(os.path.dirname(RESULT_FILE_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(CHECKPOINT_FILE_PATH), exist_ok=True)

    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    # Resume right after the last checkpointed epoch.
    for epoch in range(start_epoch, args.epochs):
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Average this epoch's values (NaN instead of ZeroDivisionError when
        # train() returned no values, e.g. an empty loader).
        avg_loss_G = _mean(tloss_G)
        avg_loss_D = _mean(tloss_D)
        avg_distance = _mean(tdist)

        # Save the current result
        result_data['current_epoch'] = epoch
        result_data['train_loss_G'] = train_loss_G
        result_data['train_loss_D'] = train_loss_D
        result_data['train_distance'] = train_distance

        # Save result_data as pkl file
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the best checkpoint (currently disabled; best_metric is only
        # carried through the regular checkpoint)
        # if avg_distance < best_metric:
        #     best_metric = avg_distance
        #     torch.save({
        #         'epoch': epoch+1,
        #         'generator_state_dict': generator.state_dict(),
        #         'discriminator_state_dict': discriminator.state_dict(),
        #         'optimizer_G_state_dict': optimizer_G.state_dict(),
        #         'optimizer_D_state_dict': optimizer_D.state_dict(),
        #         'current_step': step_counter.current_step,
        #         'best_metric': best_metric,
        #         }, BEST_CHECKPOINT_FILE_PATH)

        # Save the current checkpoint ('epoch' is the next epoch to run)
        torch.save(
            {
                'epoch': epoch + 1,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'current_step': step_counter.current_step,
                'train_loss_G': train_loss_G,
                'train_loss_D': train_loss_D,
                'train_distance': train_distance,
                'best_metric': best_metric,
            }, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(
            avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        # NOTE(review): `start` is not defined in this function; it must be a
        # module-level timestamp set before main() runs - confirm.
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # If iteration step has been satisfied
        if step_counter.exit_signal:
            break

    print('==> Train done.')

    print(' '.join(['Results have been saved at', RESULT_FILE_PATH]))
    print(' '.join(['Checkpoints have been saved at', CHECKPOINT_FILE_PATH]))
Esempio n. 7
0
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable

from model.Discriminator import Discriminator
from model.Generator import Generator

from loaders.MNISTLoader import MNIST
from util.ImageUtil import ImageUtil

# Networks: the discriminator is built first, then the generator.
discriminator = Discriminator()
generator = Generator()

# MNIST data loaders (train / test).
dataLoader = MNIST()
train_loader = dataLoader.train_loader
test_loader = dataLoader.test_loader

# One Adam optimizer per network, both with learning rate 1e-3.
D_optimizer = Adam(lr=0.001, params=discriminator.parameters())
G_optimizer = Adam(lr=0.001, params=generator.parameters())

# Binary cross-entropy loss for each network (separate instances).
D_loss_function = nn.BCELoss()
G_loss_function = nn.BCELoss()

imageUtil = ImageUtil()

epoch_size = 10000
for epoch in range(epoch_size):
Esempio n. 8
0
# Positional defaults for process_img's ``params``:
# (min_factor, img_path, use_style, lerp_factor, save_dir, save_name)
_PROCESS_IMG_DEFAULTS = (16, "./source/", "as", 100, "./predict_img/", None)


def _collect_images(img_path):
    """Return ``(names, paths)`` for a single image file or for every regular
    file directly inside a directory."""
    if not os.path.exists(img_path):
        # NOTE(review): FileNotFoundError would be the accurate type; kept as
        # FileExistsError because callers may already catch it.
        raise FileExistsError("img_path do not exist")
    if os.path.isfile(img_path):
        return [os.path.basename(img_path)], [img_path]
    if os.path.isdir(img_path):
        names, paths = [], []
        for entry in os.listdir(img_path):
            entry_path = os.path.join(img_path, entry)
            if os.path.isfile(entry_path):
                names.append(entry)
                paths.append(entry_path)
        return names, paths
    raise TypeError("img_path must be file or dir")


def _load_style_model(use_style):
    """Build the generator for ``use_style`` (falling back to 'miyaziki')."""
    # Later entries with a duplicate name override earlier ones.
    model_path_dict = {pair['name']: pair['model']
                       for pair in loadDict.load_dict()}
    model_names_list = list(model_path_dict.keys())
    print(model_names_list)
    if use_style not in model_names_list:
        use_style = 'miyaziki'
    # The parameters are a pickled dict inside a .npy file; NumPy >= 1.16.3
    # refuses to unpickle unless explicitly allowed.
    # NOTE(review): pickle can execute code - only load trusted model files.
    parameter_dict = np.load(model_path_dict[use_style],
                             allow_pickle=True).item()
    return Generator(in_chanel=3, out_chanel=3, parameter_dict=parameter_dict)


def process_img(params=_PROCESS_IMG_DEFAULTS):
    """Stylise one image, or every file in a directory, and save the results.

    ``params`` is a positional sequence
    ``[min_factor, img_path, use_style, lerp_factor, save_dir, save_name]``.
    Entries that are missing (short sequence) or ``None`` fall back to the
    defaults ``(16, "./source/", "as", 100, "./predict_img/", None)``.

    - min_factor: padding granularity; reset to 16 unless a multiple of 16.
    - img_path: an image file or a directory of images.
    - use_style: style name from ``loadDict.load_dict()``; unknown names fall
      back to ``'miyaziki'``.
    - lerp_factor: forwarded to ``img_process.lerp_img``.
    - save_dir: output directory.
    - save_name: output file name.  Ignored when more than one image is
      processed; input file names are reused then.

    Raises:
        FileExistsError: if ``img_path`` does not exist.
        TypeError: if ``img_path`` exists but is neither file nor directory.
    """
    # Resolve every parameter, so no local is ever left unbound (previously a
    # short list or a None entry - including the default save_name - crashed).
    resolved = [
        params[i] if i < len(params) and params[i] is not None else default
        for i, default in enumerate(_PROCESS_IMG_DEFAULTS)
    ]
    min_factor, img_path, use_style, lerp_factor, save_dir, save_name = resolved
    min_factor = int(min_factor)

    if min_factor % 16 != 0:
        min_factor = 16

    img_names, img_path_list = _collect_images(img_path)

    ########################### Load the model ###########################
    model = _load_style_model(use_style)

    # A single explicit output name only makes sense for a single input image.
    if len(img_path_list) > 1:
        save_name = None

    for img_name, test_path in zip(img_names, img_path_list):
        print("%s is under processing" % test_path)
        time_start = time.time()
        # Map pixel values to [-1, 1] (assumes 8-bit input images).
        test_img = np.array(io.imread(test_path), dtype=np.float32) / 255.0 * 2.0 - 1
        if test_img.ndim == 2:
            # 2-D (e.g. grayscale) images get a channel axis so that the
            # (h, w, c) unpacking below works.
            test_img = test_img[:, :, np.newaxis]
        h, w, c = np.shape(test_img)
        print(h, w)
        if c == 1:
            test_img = np.concatenate([test_img, test_img, test_img], axis=2)
        elif c > 3:
            test_img = test_img[:, :, :3]
            c = 3
            print('chanel num larger 3,not rgb')

        ########### Padding ###########
        # Reflect-pad each side up to a multiple of min_factor plus a fixed
        # 48-pixel margin, split as evenly as possible between the edges.
        h_pad = np.ceil(h / min_factor) * min_factor - h + 48
        w_pad = np.ceil(w / min_factor) * min_factor - w + 48
        h_pad_up = int(h_pad // 2)
        h_pad_down = int(h_pad - h_pad_up)
        w_pad_left = int(w_pad // 2)
        w_pad_right = int(w_pad - w_pad_left)
        test_img = np.reshape(test_img, [1, h, w, 3])
        test_img = tf.pad(test_img,
                          [[0, 0], [h_pad_up, h_pad_down],
                           [w_pad_left, w_pad_right], [0, 0]],
                          mode="REFLECT")
        pre_img = model(test_img)[0]
        # Crop the padding off again and keep the input's channel count.
        pre_img = pre_img[h_pad_up:h_pad_up + h, w_pad_left:w_pad_left + w, :c]

        time_end = time.time()
        print('time cost', time_end - time_start, 's')

        ########################### Save the image ###########################
        # os.path.join also works when save_dir has no trailing separator.
        out_name = img_name if save_name is None else save_name
        save_path = os.path.join(save_dir, out_name)
        save_sample(pre_img, save_path)

        # Lerp the saved result with the source image; the result is only
        # written back when lerp_img reports ret == 0.
        img_add, ret = img_process.lerp_img(test_path, save_path, lerp_factor)
        if ret == 0:
            cv2.imwrite(save_path, img_add)
            file_dir, file_name = os.path.split(save_path)
            print("%s is save on %s dir" % (file_name, file_dir))
        else:
            print("generate failed when processing %s" % test_path)
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

ndf = ngf = 64
input_nc = output_nc = 3
fineSize = 128
batchSize = 1
cuda = 1

# Discriminators D_A / D_B and generators G_AB (A->B) / G_BA (B->A).
D_A = Discriminator(input_nc, ndf)
D_B = Discriminator(output_nc, ndf)
G_AB = Generator(input_nc, output_nc, ngf)
G_BA = Generator(output_nc, input_nc, ngf)

# Custom initialisation: generators first, then discriminators.
G_AB.apply(weights_init)
G_BA.apply(weights_init)
D_A.apply(weights_init)
D_B.apply(weights_init)

if cuda:
    D_A.cuda()
    D_B.cuda()
    G_AB.cuda()
    G_BA.cuda()

###########   LOSS & OPTIMIZER   ##########
Esempio n. 10
0
def _parse_model_type(ckpt_name):
    """Return the model-name prefix of a checkpoint file name."""
    parts = ckpt_name.split('.')[0].split('_')
    # The '*_noBN' model names contain an underscore themselves, so taking
    # only parts[0] would turn 'WGAN_noBN_...' into 'WGAN'.
    if len(parts) > 1 and parts[1] == 'noBN':
        return '_'.join(parts[:2])
    return parts[0]


def generate_fake_images(ckpt_list):
    """Render a grid of 64 generated samples for each checkpoint name.

    Each name is parsed for its model prefix (``WGAN``, ``DCGAN``,
    ``WGAN_noBN`` or ``DCGAN_noBN``).  The checkpoint is loaded from
    ``CHECKPOINT_PATH`` and the figure is saved as
    ``FIGURE_PATH/<ckpt stem>/test_fake_images_step<step>.png``.

    Raises:
        ValueError: if a checkpoint name does not start with a known model.
    """
    #===========================================================================
    for ckpt_name in ckpt_list:
        #=======================================================================
        # Parsing the hyper-parameters
        model_type = _parse_model_type(ckpt_name)

        # Step1 ================================================================
        # Make the model
        if model_type in ['WGAN', 'DCGAN']:
            use_bn = True
        elif model_type in ['WGAN_noBN', 'DCGAN_noBN']:
            use_bn = False
        else:
            # `raise` rather than `assert False`, which vanishes under -O.
            raise ValueError("Please select the proper model.")
        generator = Generator(BN=use_bn)

        # Check DataParallel available
        if torch.cuda.device_count() > 1:
            generator = nn.DataParallel(generator)

        # Check CUDA available
        if torch.cuda.is_available():
            generator.cuda()
        print('==> Model ready.')

        # Step2 ================================================================
        # Load the weights onto whatever device the generator lives on, so
        # GPU-saved checkpoints also work on CPU-only machines.
        device = next(generator.parameters()).device
        checkpoint = torch.load(os.path.join(CHECKPOINT_PATH, ckpt_name),
                                map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        train_step = checkpoint['current_step']

        generator.eval()

        # Set save path
        FILE_NAME_FORMAT = os.path.splitext(ckpt_name)[0]
        SAVE_IMG_PATH = os.path.join(FIGURE_PATH, FILE_NAME_FORMAT)
        os.makedirs(SAVE_IMG_PATH, exist_ok=True)

        IMAGE_NAME = 'test_fake_images_step{0}.png'.format(train_step)

        # test the model
        #-----------------------------------------------------------------------
        # 64 noise vectors shaped (64, 100, 1, 1)
        test_noise = torch.randn(64, 100, 1, 1).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            fake_images = generator(test_noise)
        fake_images = fake_images.cpu()

        fig = plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("fake images (step: {0:d})".format(train_step))
        grid = utils.make_grid(fake_images, padding=2, normalize=True)
        plt.imshow(np.transpose(grid, (1, 2, 0)))
        fig.savefig(os.path.join(SAVE_IMG_PATH, IMAGE_NAME),
                    bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()
        #-----------------------------------------------------------------------

        # Print the result on the console
        print("model                  : {}".format(model_type))
        print('-' * 50)
    print('==> Image generation done.')
Esempio n. 11
0
    def __init__(self, wi_dict_path, iw_dict_path, train_data, val_data=None):
        """Set up LeakGAN hyper-parameters, vocabulary, networks and savers.

        Args:
            wi_dict_path: pickle file with the word -> index dictionary.
            iw_dict_path: pickle file with the index -> word dictionary.
                If either dictionary file is missing, both are rebuilt from
                the data and written; otherwise both are loaded.
            train_data: training data passed to ``trunc_data`` and
                ``text_precess`` (presumably a text-file path - confirm).
            val_data: optional validation data, used like ``train_data``.
                When ``None``, no truncated validation file is written.
        """
        super().__init__()

        # vocab_size is overwritten below in both dictionary branches.
        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64

        self.input_length = 8
        self.sequence_length = 32
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'

        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'
        trunc_data(train_data, self.trunc_train_file, self.input_length)
        # val_data defaults to None; only truncate it when it was supplied.
        # NOTE(review): without val_data the BLEU scorers below still point at
        # trunc_val_file, which may be missing or left over from an old run.
        if val_data is not None:
            trunc_data(val_data, self.trunc_val_file, self.input_length)

        if not os.path.isfile(wi_dict_path) or not os.path.isfile(
                iw_dict_path):
            print('Building word/index dictionaries...')
            self.sequence_length, self.vocab_size, word_index_dict, index_word_dict = text_precess(
                train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            with open(wi_dict_path, 'wb') as f:
                pickle.dump(word_index_dict, f)
            with open(iw_dict_path, 'wb') as f:
                pickle.dump(index_word_dict, f)
        else:
            # NOTE(review): pickle.load can execute code - the dictionary
            # files must come from a trusted source.
            # NOTE(review): on this path sequence_length keeps its default of
            # 32 instead of the value text_precess would compute - confirm.
            print('Loading word/index dectionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            self.vocab_size = len(word_index_dict) + 1
            print('Vocab Size: %d' % self.vocab_size)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        # The discriminator is built first; the generator receives it as
        # D_model.
        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)

        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)

        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)
Esempio n. 12
0
    os.makedirs(opt.outf)
except OSError:
    pass

# Pick (or reuse) a random seed and apply it to every RNG in use.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

###########   Load netG   ###########
# map_location='cpu' lets GPU-saved weights load on CPU-only machines;
# load_state_dict copies them into the (CPU) model parameters, and the
# networks are moved to the GPU afterwards when opt.cuda is set.
netG1 = Generator(opt.input_nc, opt.output_nc, opt.ngf)
netG1.load_state_dict(torch.load(opt.netG1, map_location='cpu'))
netG2 = Generator(opt.input_nc, opt.output_nc, opt.ngf)
netG2.load_state_dict(torch.load(opt.netG2, map_location='cpu'))

###########   Generate   ###########
text_dataset = TestDataset(opt.data_path, opt.size_w, opt.size_h)
train_loader = torch.utils.data.DataLoader(dataset=text_dataset,
                                           batch_size=opt.batch_size,
                                           shuffle=False,
                                           num_workers=6)
loader = iter(train_loader)
number = len(text_dataset)

if opt.cuda:
    netG1.cuda()