コード例 #1
0
    def stylelistClick(self):
        """Apply the style selected in ``self.ui.list_style`` to the source image.

        Does nothing when no source image has been chosen
        (``self.src_img_path is None``) or when the list has no valid current
        row. Otherwise the model stored in ``self.style_mode`` for the
        selected row is used; if that slot holds ``None`` a ``Generator`` is
        built from the parameter file listed in ``self.style_model_path``.
        The model and the image path are then handed to
        ``self.process_src_img``.
        """
        if self.src_img_path is None:
            return

        style_index = self.ui.list_style.currentIndex().row()
        if style_index < 0:
            # A negative row means nothing is selected; indexing with it
            # would silently pick a style from the end of the list.
            return

        model = self.style_mode[style_index]
        if model is None:
            # The parameters are a dict wrapped in a 0-d object array, which
            # NumPy only unpickles when allow_pickle=True is passed.
            # NOTE(review): unpickling executes arbitrary code - only load
            # model files from a trusted location.
            parameter_dict = np.load(self.style_model_path[style_index],
                                     allow_pickle=True).item()
            model = Generator(in_chanel=3,
                              out_chanel=3,
                              parameter_dict=parameter_dict)
            # NOTE(review): the freshly built model is not written back to
            # self.style_mode, so it is rebuilt on every click - confirm
            # whether caching was intended.

        self.process_src_img(model, self.src_img_path)
コード例 #2
0
ファイル: utils.py プロジェクト: HeartFu/SimpleTransformer
def make_model(src_vocab,
               tar_vocab,
               N=6,
               d_model=512,
               d_ff=2014,
               h=8,
               dropout=0.1):
    """Assemble an encoder-decoder model and Xavier-initialize its weights.

    Args:
        src_vocab: Forwarded as the second argument of the source
            ``Embedding`` (presumably the source vocabulary size - confirm).
        tar_vocab: Forwarded as the second argument of the target
            ``Embedding`` and of ``Generator``.
        N: Passed to ``Encoder`` and ``Decoder`` together with one layer.
        d_model: Width shared by attention, feed-forward, positional
            encoding, embeddings and generator.
        d_ff: Second argument of ``PositionwiseFeedForward``.
            NOTE(review): 2014 looks like a typo for 2048; it is left as-is
            because changing it alters the layer shapes.
        h: First argument of ``MultiHeadedAttention``.
        dropout: Dropout value handed to every sub-module that accepts one.

    Returns:
        The ``GeneralEncoderDecoder`` instance, with every parameter of more
        than one dimension re-initialized by ``nn.init.xavier_uniform_``.
    """
    clone = copy.deepcopy
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = PositionalEncoding(d_model, dropout)

    # Each consumer gets its own deep copy of the prototype modules. The
    # pieces are built in the same order as the final constructor call
    # consumes them.
    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout),
        N)
    decoder = Decoder(
        DecoderLayer(d_model, clone(attention), clone(attention),
                     clone(feed_forward), dropout), N)
    src_embed = nn.Sequential(Embedding(d_model, src_vocab),
                              clone(pos_encoding))
    tar_embed = nn.Sequential(Embedding(d_model, tar_vocab),
                              clone(pos_encoding))
    generator = Generator(d_model, tar_vocab)
    model = GeneralEncoderDecoder(encoder, decoder, src_embed, tar_embed,
                                  generator)

    # Randomly initialize the parameters; this is very important.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return model
コード例 #3
0
# One shuffled DataLoader per dataset (A and B), each wrapped in an explicit
# iterator.
loader_A = torch.utils.data.DataLoader(dataset=datasetA,
                                       batch_size=opt.batchSize,
                                       shuffle=True,
                                       num_workers=2)
loaderA = iter(loader_A)
loader_B = torch.utils.data.DataLoader(dataset=datasetB,
                                       batch_size=opt.batchSize,
                                       shuffle=True,
                                       num_workers=2)
loaderB = iter(loader_B)
###########   MODEL   ###########
ndf = opt.ndf
ngf = opt.ngf
nc = 3

# Two generators with mirrored channel arguments
# (input_nc -> output_nc and output_nc -> input_nc).
G_AB = Generator(opt.input_nc, opt.output_nc, opt.ngf)
G_BA = Generator(opt.output_nc, opt.input_nc, opt.ngf)

# Only opt.G_AB is tested; opt.G_BA is loaded in the same branch without its
# own check.
# NOTE(review): the else branch only prints an error and execution continues
# with the generators as constructed - confirm whether it should abort.
if (opt.G_AB != ''):
    print('Warning! Loading pre-trained weights.')
    G_AB.load_state_dict(torch.load(opt.G_AB))
    G_BA.load_state_dict(torch.load(opt.G_BA))
else:
    print('ERROR! G_AB and G_BA must be provided!')

# Move both generators to the GPU when requested.
if (opt.cuda):
    G_AB.cuda()
    G_BA.cuda()

###########   GLOBAL VARIABLES   ###########
input_nc = opt.input_nc
コード例 #4
0
ファイル: dcgan.py プロジェクト: Angelowin/deeplearning
    nrow=6,
    normalize=True)

print(data_new.size())
# Pair the prepared data tensor with its labels and batch it with shuffling.
# NOTE(review): newer PyTorch releases define TensorDataset(*tensors) and
# reject the data_tensor/target_tensor keywords - confirm the targeted version.
dataset = Data.TensorDataset(data_tensor=data_new, target_tensor=train_y)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)

###############   MODEL   ####################
ndf = opt.ndf
ngf = opt.ngf
nc = 1

netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
# The opt.cuda guard below is commented out, so both networks are always
# moved to the GPU.
#if(opt.cuda):
netD.cuda()
netG.cuda()

###########   LOSS & OPTIMIZER   ##########
# Binary cross-entropy loss and one Adam optimizer per network, both using
# opt.lr and betas=(opt.beta1, 0.999).
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))

##########   GLOBAL VARIABLES   ###########
#noise_all = torch.FloatTensor(20,opt.nz,1,1)
コード例 #5
0
def weights_init(m):
    """Initialize convolution and batch-norm parameters in place.

    Written to be passed to ``module.apply(weights_init)``. The layer kind is
    detected from the class name:

    * names containing ``'Conv'``: weight ~ N(0.0, 0.02), bias = 0
    * names containing ``'BatchNorm'``: weight ~ N(1.0, 0.02), bias = 0

    Any other module is left untouched. Parameters that are missing or
    ``None`` (for example ``bias=False`` convolutions, ``affine=False``
    batch norm, or wrapper classes whose name merely contains "Conv") are
    skipped instead of raising ``AttributeError``.

    Args:
        m: The module to initialize.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if weight is not None:
        weight.data.normal_(weight_mean, 0.02)
    if bias is not None:
        bias.data.fill_(0)


# Feature-map widths taken from the options; nc is fixed to 3.
ndf = opt.ndf
ngf = opt.ngf
nc = 3

# Both networks receive the input and output channel counts plus their own
# width option.
netD = Discriminator(opt.input_nc, opt.output_nc, ndf)
netG = Generator(opt.input_nc, opt.output_nc, opt.ngf)
if (opt.cuda):
    netD.cuda()
    netG.cuda()

# Apply weights_init to every sub-module of both networks, then print their
# structure.
netG.apply(weights_init)
netD.apply(weights_init)
print(netD)
print(netG)

###########   LOSS & OPTIMIZER   ##########
# Binary cross-entropy and L1 criteria, plus an Adam optimizer for the
# discriminator.
criterion = nn.BCELoss()
criterionL1 = nn.L1Loss()
optimizerD = torch.optim.Adam(netD.parameters(),
                              lr=opt.lr,
                              betas=(opt.beta1, 0.999))
コード例 #6
0
ファイル: main.py プロジェクト: jungwon-choi/WGAN-pytorch
def main(args):
    """Run the whole GAN training pipeline described by ``args``.

    The function derives the output file names from the hyper-parameters,
    seeds the random number generators, builds the data loader, the
    generator/discriminator pair and their RMSprop optimizers, optionally
    restores a checkpoint, and then trains epoch by epoch. After every epoch
    it writes a results pickle and a checkpoint and prints a summary.

    Args:
        args: Parsed options. The attributes read here are ``model``,
            ``dataset``, ``epochs``, ``obj_step``, ``batch_size``, ``lr``,
            ``flag``, ``num_workers``, ``resume``, ``clipping`` and
            ``num_critic``.

    Raises:
        AssertionError: If the dataset or model name is not supported, or if
            ``args.resume`` is set but the checkpoint file does not exist.
    """
    #===========================================================================
    # Every output file shares one stem built from the hyper-parameters
    FILE_NAME_FORMAT = "{0}_{1}_{2:d}_{3:d}_{4:d}_{5:f}{6}".format(
        args.model, args.dataset, args.epochs, args.obj_step, args.batch_size,
        args.lr, args.flag)

    # Results pickle and checkpoint locations
    RESULT_FILE_PATH = os.path.join(RESULTS_PATH,
                                    FILE_NAME_FORMAT + '_results.pkl')
    CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                        FILE_NAME_FORMAT + '.ckpt')
    # Not used below: nothing in this function saves a "best" checkpoint
    BEST_CHECKPOINT_FILE_PATH = os.path.join(CHECKPOINT_PATH,
                                             FILE_NAME_FORMAT + '_best.ckpt')

    # Fixed seed and deterministic cuDNN settings for reproducibility
    seed = 190811
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Step1 ====================================================================
    # Load dataset
    if args.dataset == 'CelebA':
        dataloader = CelebA_Dataloader()
    else:
        assert False, "Please select the proper dataset."

    train_loader = dataloader.get_train_loader(batch_size=args.batch_size,
                                               num_workers=args.num_workers)
    print('==> DataLoader ready.')

    # Step2 ====================================================================
    # Make the model (with or without batch norm, depending on its name)
    if args.model in ['WGAN', 'DCGAN']:
        generator = Generator(BN=True)
        discriminator = Discriminator(BN=True)
    elif args.model in ['WGAN_noBN', 'DCGAN_noBN']:
        generator = Generator(BN=False)
        discriminator = Discriminator(BN=False)
    else:
        assert False, "Please select the proper model."

    # Wrap in DataParallel when more than one GPU is visible
    if torch.cuda.device_count() > 1:
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    # Move to the GPU when CUDA is available
    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()
    print('==> Model ready.')

    # Step3 ====================================================================
    # Only the DCGAN variants use a criterion; the others pass None
    criterion = (nn.BCELoss()
                 if args.model in ['DCGAN', 'DCGAN_noBN'] else None)
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=args.lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=args.lr)
    step_counter = StepCounter(args.obj_step)
    print('==> Criterion and optimizer ready.')

    # Step4 ====================================================================
    # Training state; replaced below when resuming from a checkpoint
    start_epoch = 0
    best_metric = float("inf")
    validate_noise = torch.randn(args.batch_size, 100, 1, 1)
    train_loss_G, train_loss_D, train_distance = [], [], []

    if args.resume:
        assert os.path.exists(CHECKPOINT_FILE_PATH), 'No checkpoint file!'
        restored = torch.load(CHECKPOINT_FILE_PATH)
        generator.load_state_dict(restored['generator_state_dict'])
        discriminator.load_state_dict(restored['discriminator_state_dict'])
        optimizer_G.load_state_dict(restored['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(restored['optimizer_D_state_dict'])
        start_epoch = restored['epoch']
        step_counter.current_step = restored['current_step']
        train_loss_G = restored['train_loss_G']
        train_loss_D = restored['train_loss_D']
        train_distance = restored['train_distance']
        best_metric = restored['best_metric']

    # Static part of the results record; per-epoch fields are added later
    result_data = {
        'model': args.model,
        'dataset': args.dataset,
        'target_epoch': args.epochs,
        'batch_size': args.batch_size,
    }

    # Create the output directories if they are missing
    for target_path in (RESULT_FILE_PATH, CHECKPOINT_FILE_PATH):
        parent_dir = os.path.dirname(target_path)
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)

    print('==> Train ready.')

    # Validate before training (step 0)
    val(generator, validate_noise, step_counter, FILE_NAME_FORMAT)

    # Epochs before the checkpoint epoch are skipped entirely
    for epoch in range(start_epoch, args.epochs):
        print("\n[Epoch: {:3d}/{:3d}]".format(epoch + 1, args.epochs))
        epoch_time = time.time()
        #=======================================================================
        # train the model (+ validate the model)
        tloss_G, tloss_D, tdist = train(generator, discriminator, train_loader,
                                        criterion, optimizer_G, optimizer_D,
                                        args.clipping, args.num_critic,
                                        step_counter, validate_noise,
                                        FILE_NAME_FORMAT)
        train_loss_G.extend(tloss_G)
        train_loss_D.extend(tloss_D)
        train_distance.extend(tdist)
        #=======================================================================
        current = time.time()

        # Per-epoch averages of the values returned by train()
        avg_loss_G = sum(tloss_G) / len(tloss_G)
        avg_loss_D = sum(tloss_D) / len(tloss_D)
        avg_distance = sum(tdist) / len(tdist)

        # Refresh the results record and write it out as a pickle
        result_data.update(current_epoch=epoch,
                           train_loss_G=train_loss_G,
                           train_loss_D=train_loss_D,
                           train_distance=train_distance)
        with open(RESULT_FILE_PATH, 'wb') as pkl_file:
            pickle.dump(result_data,
                        pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

        # Save the current checkpoint. best_metric is stored as-is: it is
        # only ever changed by resuming, never by the training loop.
        snapshot = {
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'current_step': step_counter.current_step,
            'train_loss_G': train_loss_G,
            'train_loss_D': train_loss_D,
            'train_distance': train_distance,
            'best_metric': best_metric,
        }
        torch.save(snapshot, CHECKPOINT_FILE_PATH)

        # Print the information on the console
        print("model                : {}".format(args.model))
        print("dataset              : {}".format(args.dataset))
        print("batch_size           : {}".format(args.batch_size))
        print("current step         : {:d}".format(step_counter.current_step))
        print("current lrate        : {:f}".format(args.lr))
        print("gen/disc loss        : {:f}/{:f}".format(
            avg_loss_G, avg_loss_D))
        print("distance metric      : {:f}".format(avg_distance))
        print("epoch time           : {0:.3f} sec".format(current -
                                                          epoch_time))
        # NOTE(review): `start` is not defined in this function; presumably a
        # module-level start timestamp - confirm.
        print("Current elapsed time : {0:.3f} sec".format(current - start))

        # Stop once the step counter signals the target step was reached
        if step_counter.exit_signal:
            break

    print('==> Train done.')

    print('Results have been saved at', RESULT_FILE_PATH)
    print('Checkpoints have been saved at', CHECKPOINT_FILE_PATH)
コード例 #7
0
ファイル: main.py プロジェクト: djkim1991/DCGAN
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable

from model.Discriminator import Discriminator
from model.Generator import Generator

from loaders.MNISTLoader import MNIST
from util.ImageUtil import ImageUtil

# create model objects
discriminator = Discriminator()
generator = Generator()

# set data loader
dataLoader = MNIST()
train_loader, test_loader = dataLoader.train_loader, dataLoader.test_loader

# optimizer
D_optimizer = Adam(params=discriminator.parameters(), lr=0.001)
G_optimizer = Adam(params=generator.parameters(), lr=0.001)

# loss function
D_loss_function = nn.BCELoss()  # Binary Cross Entropy loss
G_loss_function = nn.BCELoss()  # Binary Cross Entropy loss

imageUtil = ImageUtil()

epoch_size = 10000
for epoch in range(epoch_size):
コード例 #8
0
def process_img(params=(16, "./source/", "as", 100, "./predict_img/", None)):
    """Stylize one image, or every file in a directory, then blend and save it.

    Each input image is scaled to [-1, 1], reflect-padded, run through the
    selected ``Generator``, cropped back to its original size and written via
    ``save_sample``. ``img_process.lerp_img`` is then called with the source
    path, the saved path and the lerp factor; when it reports success
    (``ret == 0``) its image overwrites the saved file.

    Args:
        params: Positional settings
            ``[min_factor, img_path, use_style, lerp_factor, save_dir,
            save_name]``. Missing trailing entries and entries equal to
            ``None`` fall back to the defaults shown in the signature.

            * ``min_factor``: padding granularity; reset to 16 unless it is a
              positive multiple of 16.
            * ``img_path``: an image file or a directory of images.
            * ``use_style``: model name; unknown names fall back to
              ``'miyaziki'``.
            * ``lerp_factor``: forwarded to ``img_process.lerp_img``.
            * ``save_dir``: prefix that is string-concatenated with the file
              name, so it must end with a path separator.
            * ``save_name``: output file name; ignored when more than one
              image is processed, defaults to the input file name.

    Raises:
        FileExistsError: If ``img_path`` does not exist (the exception type
            is kept for backward compatibility).
        TypeError: If ``img_path`` exists but is neither a file nor a
            directory.
    """
    # Start from the defaults so that a short list or a None entry can never
    # leave one of the settings unbound.
    defaults = (16, "./source/", "as", 100, "./predict_img/", None)
    settings = list(defaults)
    for idx in range(min(len(params), len(defaults))):
        if params[idx] is not None:
            settings[idx] = params[idx]
    min_factor, img_path, use_style, lerp_factor, save_dir, save_name = settings
    min_factor = int(min_factor)

    # A non-positive factor would divide by zero or pad negatively below.
    if min_factor <= 0 or min_factor % 16 != 0:
        min_factor = 16

    # Collect the input paths and the matching output file names.
    img_path_list = []
    img_name = []
    if not os.path.exists(img_path):
        raise FileExistsError("img_path do not exist")
    if os.path.isfile(img_path):
        img_name.append(os.path.split(img_path)[1])
        img_path_list.append(img_path)
    elif os.path.isdir(img_path):
        for entry in os.listdir(img_path):
            entry_path = os.path.join(img_path, entry)
            if os.path.isfile(entry_path):
                img_name.append(entry)
                img_path_list.append(entry_path)
    else:
        raise TypeError("img_path must be file or dir")

    # ---- Load the model ----
    model_path_dict = {}
    for pair in loadDict.load_dict():
        model_path_dict[pair['name']] = pair['model']

    model_names_list = list(model_path_dict.keys())
    print(model_names_list)
    if use_style not in model_names_list:
        use_style = 'miyaziki'
    # The parameters are a dict wrapped in a 0-d object array, which NumPy
    # only unpickles when allow_pickle=True is passed.
    # NOTE(review): unpickling executes arbitrary code - only load model
    # files from a trusted location.
    parameter_dict = np.load(model_path_dict[use_style],
                             allow_pickle=True).item()
    model = Generator(in_chanel=3, out_chanel=3, parameter_dict=parameter_dict)

    # A single explicit output name cannot serve several images.
    if len(img_path_list) > 1:
        save_name = None

    for i, test_path in enumerate(img_path_list):
        print("%s is under processing" % test_path)
        time_start = time.time()
        # Scale pixel values from [0, 255] to [-1, 1].
        test_img = np.array(io.imread(test_path), dtype=np.float32) / 255.0 * 2.0 - 1
        # NOTE(review): this unpacking requires a 3-D array; an image read as
        # a 2-D (h, w) array raises ValueError here.
        h, w, c = np.shape(test_img)
        print(h, w)
        if c == 1:
            test_img = np.concatenate([test_img, test_img, test_img], axis=2)
        elif c > 3:
            test_img = test_img[:, :, :3]
            c = 3
            print('chanel num larger 3,not rgb')

        # ---- Padding ----
        # Round each side up to a multiple of min_factor, add a fixed extra
        # 48 pixels, and split the total between the two borders.
        h_pad = np.ceil(h / min_factor) * min_factor - h
        h_pad += 48
        w_pad = np.ceil(w / min_factor) * min_factor - w
        w_pad += 48
        h_pad_up = int(h_pad // 2)
        h_pad_down = int(h_pad - h_pad_up)
        w_pad_left = int(w_pad // 2)
        w_pad_right = int(w_pad - w_pad_left)
        test_img = np.reshape(test_img, [1, h, w, 3])
        test_img = tf.pad(test_img, [[0, 0], [h_pad_up, h_pad_down], [w_pad_left, w_pad_right], [0, 0]], mode="REFLECT")
        pre_img = model(test_img)
        pre_img = pre_img[0]
        # Crop the padding off again and keep the first c channels.
        pre_img = pre_img[h_pad_up:h_pad_up + h, w_pad_left:w_pad_left + w, :c]

        time_end = time.time()
        print('time cost', time_end - time_start, 's')

        # ---- Save the image ----
        if save_name is None:
            save_path = save_dir + img_name[i]
        else:
            save_path = save_dir + save_name
        save_sample(pre_img, save_path)

        # Perform the lerp (blend) between the source and the stylized image.
        img_add, ret = img_process.lerp_img(test_path, save_path, lerp_factor)
        if ret == 0:
            cv2.imwrite(save_path, img_add)
            file_dir, file_name = os.path.split(save_path)
            print("%s is save on %s dir" % (file_name, file_dir))
        else:
            print("generate failed when processing %s" % test_path)
コード例 #9
0
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Hard-coded configuration for this script.
ndf = 64
ngf = 64
input_nc = 3
output_nc = 3
fineSize = 128
batchSize = 1
cuda = 1

# One discriminator per channel setting and two generators with mirrored
# channel arguments (input_nc -> output_nc and output_nc -> input_nc).
D_A = Discriminator(input_nc,ndf)
D_B = Discriminator(output_nc,ndf)
G_AB = Generator(input_nc, output_nc, ngf)
G_BA = Generator(output_nc, input_nc, ngf)

# Run weights_init over every sub-module of all four networks.
G_AB.apply(weights_init)
G_BA.apply(weights_init)

D_A.apply(weights_init)
D_B.apply(weights_init)

# `cuda` is the constant 1 above, so the networks are always moved to the GPU.
if(cuda):
    D_A.cuda()
    D_B.cuda()
    G_AB.cuda()
    G_BA.cuda()

###########   LOSS & OPTIMIZER   ##########
コード例 #10
0
def generate_fake_images(ckpt_list):
    """Render and save a grid of generated images for each checkpoint.

    For every checkpoint file name the model type is parsed from the name, a
    matching ``Generator`` is built and loaded from ``CHECKPOINT_PATH``, 64
    noise vectors of shape (100, 1, 1) are pushed through it, and the
    resulting image grid is saved as
    ``FIGURE_PATH/<checkpoint stem>/test_fake_images_step<step>.png``.

    Args:
        ckpt_list: Iterable of checkpoint file names (not full paths). Names
            are expected to start with ``WGAN``, ``DCGAN``, ``WGAN_noBN`` or
            ``DCGAN_noBN`` followed by ``_``-separated fields.

    Raises:
        AssertionError: If the model type parsed from a name is unsupported.
    """
    #===========================================================================
    for ckpt_name in ckpt_list:
        #=======================================================================
        # Parsing the hyper-parameters
        parsing_list = ckpt_name.split('.')[0].split('_')

        # The '*_noBN' model names contain an underscore themselves, so the
        # first token alone can never match them; re-attach the suffix.
        model_type = parsing_list[0]
        if len(parsing_list) > 1 and parsing_list[1] == 'noBN':
            model_type = '_'.join(parsing_list[:2])

        # Step1 ================================================================
        # Make the model
        if model_type in ['WGAN', 'DCGAN']:
            generator = Generator(BN=True)
        elif model_type in ['WGAN_noBN', 'DCGAN_noBN']:
            generator = Generator(BN=False)
        else:
            # Raised explicitly so the check also runs under `python -O`;
            # otherwise the previous iteration's generator would be reused.
            raise AssertionError("Please select the proper model.")

        # Check DataParallel available
        if torch.cuda.device_count() > 1:
            generator = nn.DataParallel(generator)

        # Check CUDA available
        if torch.cuda.is_available():
            generator.cuda()
        print('==> Model ready.')

        # Whatever device the parameters ended up on (CPU or GPU)
        device = next(generator.parameters()).device

        # Step2 ================================================================
        # Test the model
        # map_location lets a CPU-only host load a checkpoint saved on a GPU.
        checkpoint = torch.load(os.path.join(CHECKPOINT_PATH, ckpt_name),
                                map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        train_step = checkpoint['current_step']

        generator.eval()

        # Set save path
        FILE_NAME_FORMAT = os.path.splitext(ckpt_name)[0]
        SAVE_IMG_PATH = os.path.join(FIGURE_PATH, FILE_NAME_FORMAT)

        # Make sure the output directory exists
        os.makedirs(SAVE_IMG_PATH, exist_ok=True)

        IMAGE_NAME = 'test_fake_images_step{0}.png'.format(train_step)

        # test the model
        #-----------------------------------------------------------------------
        # Make test noise on the CPU, then move it next to the model
        test_noise = torch.randn(64, 100, 1, 1).to(device)

        # Generate fake images from noise; no gradients are needed here
        with torch.no_grad():
            fake_images = generator(test_noise)

        # Save the fake images
        fake_images = fake_images.detach().cpu()

        fig = plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("fake images (step: {0:d})".format(train_step))
        plt.imshow(np.transpose(utils.make_grid(fake_images,
                                                padding=2,
                                                normalize=True),
                                (1, 2, 0)))
        # 'bbox_inches' was misspelled as 'bbox_inces', so the tight bounding
        # box was never applied.
        fig.savefig(os.path.join(SAVE_IMG_PATH, IMAGE_NAME),
                    bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close(fig)
        #-----------------------------------------------------------------------

        # Print the result on the console
        print("model                  : {}".format(model_type))
        print('-'*50)
    print('==> Image generation done.')
コード例 #11
0
ファイル: LeakGan.py プロジェクト: Liugawa/GAN_Poem_Generate
    def __init__(self, wi_dict_path, iw_dict_path, train_data, val_data=None):
        """Configure the model, its vocabulary, its networks and its savers.

        The constructor fixes the hyper-parameters and file locations, writes
        truncated copies of the training and validation data, builds or loads
        the word/index dictionaries, and then creates the discriminator, the
        generator, three ``tf.train.Saver`` objects and two validation BLEU
        metrics.

        Args:
            wi_dict_path: Pickle file holding the word-to-index dictionary.
            iw_dict_path: Pickle file holding the index-to-word dictionary.
                If either file is missing, both are rebuilt from the data
                and saved.
            train_data: Training data, forwarded to ``trunc_data`` and
                ``text_precess``.
            val_data: Validation data, forwarded the same way.
                NOTE(review): defaults to None yet is passed on
                unconditionally - confirm the helpers accept None.
        """
        super().__init__()

        # Vocabulary / network sizes (vocab_size and sequence_length may be
        # overwritten once the dictionaries are available)
        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64

        self.input_length = 8
        self.sequence_length = 32
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        # Checkpoint locations
        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'

        # Text files read and written by the model
        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

        # Truncated copies of both data sets, limited by input_length
        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'
        trunc_data(train_data, self.trunc_train_file, self.input_length)
        trunc_data(val_data, self.trunc_val_file, self.input_length)

        # Reuse the dictionaries only when both pickle files exist
        dicts_on_disk = (os.path.isfile(wi_dict_path)
                         and os.path.isfile(iw_dict_path))
        if dicts_on_disk:
            print('Loading word/index dectionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            self.vocab_size = len(word_index_dict) + 1
            print('Vocab Size: %d' % self.vocab_size)
        else:
            print('Building word/index dictionaries...')
            # text_precess also supplies the sequence length and vocab size
            (self.sequence_length, self.vocab_size, word_index_dict,
             index_word_dict) = text_precess(train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            for dict_path, mapping in ((wi_dict_path, word_index_dict),
                                       (iw_dict_path, index_word_dict)):
                with open(dict_path, 'wb') as f:
                    pickle.dump(mapping, f)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        # Shared by both networks: the total number of filters
        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)

        # The generator receives the discriminator built above as D_model
        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)

        # Three independent savers, created after both networks exist
        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

        # BLEU-1 and BLEU-2 against the truncated validation file
        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)
コード例 #12
0
ファイル: generate.py プロジェクト: Emmmmmaa/STIG
    os.makedirs(opt.outf)
except OSError:
    pass

# Pick a random seed when none was given, report it, and seed Python's and
# PyTorch's generators (plus all CUDA devices when opt.cuda is set).
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

###########   Load netG   ###########
# Two generators with identical constructor arguments, each restored from its
# own weight file.
# NOTE(review): torch.load is called without map_location - confirm that
# GPU-saved weights can be loaded on the hosts this script runs on.
netG1 = Generator(opt.input_nc, opt.output_nc, opt.ngf)
netG1.load_state_dict(torch.load(opt.netG1))
netG2 = Generator(opt.input_nc, opt.output_nc, opt.ngf)
netG2.load_state_dict(torch.load(opt.netG2))

###########   Generate   ###########
# Unshuffled loader over the test dataset, wrapped in an explicit iterator;
# `number` holds the dataset length.
text_dataset = TestDataset(opt.data_path, opt.size_w, opt.size_h)
train_loader = torch.utils.data.DataLoader(dataset=text_dataset,
                                           batch_size=opt.batch_size,
                                           shuffle=False,
                                           num_workers=6)
loader = iter(train_loader)
number = text_dataset.__len__()

if opt.cuda:
    netG1.cuda()