예제 #1
0
 def test_dataset(self):
     """Exercise batching and the ``has_more_than``/``reset`` bookkeeping.

     A three-sample dataset is created with batch size 1. After one batch is
     drawn, the assertions expect ``has_more_than(1)`` to hold and
     ``has_more_than(2)`` not to. After ``reset`` the larger check holds again.
     """
     samples = np.random.normal(size=(3, 28, 28, 1))
     ds = GANDataset(samples, 100, 1)

     batch = ds.next_batch()
     self.assertEqual(batch.shape, (1, 28, 28, 1))

     # One of the three samples has been consumed at this point.
     self.assertTrue(ds.has_more_than(1))
     self.assertFalse(ds.has_more_than(2))

     # Rewinding should make all three samples available again.
     ds.reset()
     self.assertTrue(ds.has_more_than(2))
예제 #2
0
def main(_):
    """Train a ``GANModel`` on the MNIST training images.

    The single positional argument is ignored (presumably the argv list
    handed over by a ``tf.app.run``-style launcher — confirm with the caller).
    """
    model = GANModel()

    # validation_size=0: no images are split off into a validation set.
    mnist_sets = mnist.input_data.read_data_sets('./dataset/mnist', validation_size=0)
    # Reshape the training images to (N, 28, 28, 1) single-channel arrays.
    train_images = np.reshape(mnist_sets.train.images, (-1, 28, 28, 1))
    dataset = GANDataset(train_images, 100, 32)

    with tf.Session() as session:
        init_ops = [tf.global_variables_initializer()]
        session.run(init_ops)
        model.fit(session, dataset, 20, 1)
예제 #3
0
 def __load_file(self, train_filepath, train_csvfile, test_filepath, test_csvfile):
     """Create ``self.train_dataset`` and a shuffling ``self.train_data_loader``.

     All four train/test file and CSV locations are forwarded positionally to
     ``GANDataset``. The loader batch size comes from ``self.args.batch_size``.
     """
     # ToTensor, then per-channel (x - 0.5) / 0.5 on three channels.
     # If ToTensor yields values in [0, 1], this maps them to [-1, 1].
     preprocess = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     ])
     self.train_dataset = GANDataset(
         train_filepath, train_csvfile, test_filepath, test_csvfile,
         transform=preprocess,
     )
     self.train_data_loader = DataLoader(
         dataset=self.train_dataset,
         batch_size=self.args.batch_size,
         shuffle=True,
     )
예제 #4
0
 def test_fit(self):
     """Smoke-test ``GANModel.fit`` on a tiny random three-sample dataset."""
     gan = GANModel(100)
     samples = np.random.normal(size=(3, 28, 28, 1))
     ds = GANDataset(samples, 100, 1)

     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         # The trailing integers are passed through unchanged; their meaning
         # is defined by GANModel.fit.
         gan.fit(sess, ds, 1, 2)
예제 #5
0
import torch
import torch.nn as nn
import numpy as np
import sys
import os
from skimage.color import lab2rgb
from torch.autograd import Variable
import matplotlib.pyplot as plt
import getopt
from colorutils import NNEncode

from pix2pix import Generator
from dataset import GANDataset
from colorutils import modelimg2cvimg

# Non-training split (train=False) of the dataset rooted at ./SUN2012.
test_dataset = GANDataset(root='./SUN2012', train=False)

batch_size = 100
# Default for the -l/--location option declared in the getopt call below.
location = 'cpu'
# Five random sample indices in [0, len(test_dataset)) as an int array.
# Indices may repeat. NOTE(review): an empty dataset yields all zeros here,
# which would fail when indexed later — confirm the dataset is non-empty.
test_cases = np.floor(np.random.rand(5) * len(test_dataset)).astype(int)
try:
    # Short options: -h (flag), -l <value>, -c <value>;
    # long options: --location=<value>, --testcases=<value>.
    opts, args = getopt.getopt(sys.argv[1:], 'hl:c:',
                               ['location=', 'testcases='])
except getopt.GetoptError:
    # Unknown option or missing value: print usage and exit with status 2.
    print('python test.py -l <location> -c <testcases>')
    sys.exit(2)

# NOTE(review): only -h is handled in the lines shown; the -l/-c handlers
# presumably follow outside this excerpt — confirm.
for opt, arg in opts:
    if opt == '-h':
        print('python test.py -l <location> -c <testcases>')
        sys.exit(0)
예제 #6
0
    opts, args = getopt.getopt(sys.argv[1:], 'hl:c', [
                               'location=', 'continue='])
except getopt.GetoptError:
    print('python train.py -l <location> -c')
    sys.exit(2)

# Apply the parsed command-line options. ``location`` and
# ``continue_training`` are presumably initialised earlier in the script
# (not visible here); matching options overwrite them.
for opt, arg in opts:
    if opt == '-h':
        # In this script -c is a bare "resume training" flag (declared as 'c'
        # without a colon), not "-c <testcases>" as in the test script. The
        # usage line now matches the one printed on a getopt error.
        print('python train.py -l <location> -c')
        sys.exit(0)
    elif opt in ('-l', '--location'):
        location = arg
    elif opt in ('-c', '--continue'):
        # NOTE(review): the long option is declared as 'continue=', so
        # getopt requires a value (--continue=<x>) that is ignored here.
        # Consider declaring it as 'continue' to match the bare -c flag.
        continue_training = True

# Training split and a shuffling loader over it.
train_dset = GANDataset(root=dset_root, train=True)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dset,
    batch_size=batch_size,
    shuffle=True,
)

# Both networks take the same (input_channel, output_channel) configuration.
generator_G = Generator(input_channel, output_channel)
discriminator_D = Discriminator(input_channel, output_channel)

# Loss modules: mean absolute error (L1) and binary cross-entropy.
loss_L1 = nn.L1Loss()
loss_binaryCrossEntropy = nn.BCELoss()

# Identical Adam settings for both networks, created generator-first.
optimizer_G, optimizer_D = (
    torch.optim.Adam(
        net.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
        weight_decay=1e-5,
    )
    for net in (generator_G, discriminator_D)
)

if continue_training and os.path.isfile('generator.pkl') and os.path.isfile('discriminator.pkl'):
    generator_G.load_state_dict(torch.load(
        'generator.pkl', map_location=location))