コード例 #1
0
def main():
    """Entry point: train the GAN variant selected on the command line.

    Reads CLI arguments via get_args(), instantiates the requested model,
    trains it, and (for AnoGAN only) runs anomaly detection afterwards.

    Raises:
        ValueError: if args.model names a variant with no active branch.
    """
    args = get_args()

    # Training Generator/Discriminator
    if args.model == 'GAN':
        model = GAN()
    # elif args.model == 'LSGAN':
    #     model = LSGAN()
    # elif args.model == 'WGAN':
    #     model = WGAN()
    # elif args.model == 'WGAN_GP':
    #     model = WGAN_GP()
    # elif args.model == 'DRAGAN':
    #     model = DRAGAN()
    # elif args.model == 'EBGAN':
    #     model = EBGAN()
    # elif args.model == 'BEGAN':
    #     model = BEGAN()
    # elif args.model == 'SNGAN':
    #     model = SNGAN()
    elif args.model == 'AnoGAN':
        model = AnoGAN()
    else:
        # Bug fix: an unrecognized model previously fell through and crashed
        # with a NameError on `model.train()`; fail fast with a clear message.
        raise ValueError('Unsupported model: {}'.format(args.model))
    model.train()

    # Anomaly Detection
    if args.model == 'AnoGAN':
        model.anomaly_detect()
コード例 #2
0
ファイル: simple_test.py プロジェクト: cheng-xie/pytorchGANs
def test_gaussian(mean, std, num_data, make_gif=False, use_gpu=True):
    """Train a 1-D GAN to fit N(mean, std^2) and visualize its progress.

    Per outer iteration, plots: the histogram of the real data, the
    histogram of generator samples, and the discriminator's output over
    the current x-axis range.  When make_gif is True, frames are saved
    under ./figs and assembled into ./figs/output.gif with ImageMagick's
    `convert` command.

    Args:
        mean, std: parameters of the target Gaussian.
        num_data: number of real samples to draw.
        make_gif: save per-iteration frames and build a GIF at the end.
        use_gpu: move tensors to CUDA and pin DataLoader memory.
    """
    num_iters = 30
    noise_size = 1
    sample_size = 1
    batch_size = 512

    # Sample some data from the target Gaussian.
    data = torch.randn(num_data, sample_size) * std + mean
    data = data.resize_(num_data, sample_size)  # already this shape; no-op kept
    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=use_gpu,
                             num_workers=4,
                             drop_last=True)
    # data_iter = iter(data_loader)

    # Construct a GAN
    gen = MLPGenerator(noise_size, sample_size)
    dis = MLPDiscriminator(sample_size)
    gan = GAN(gen, dis, data_loader, use_gpu=use_gpu)

    plt.ion()
    for ii in range(num_iters):
        gan.train(20, batch_size)

        # True distribution: histogram of the sampled data.
        x = data.numpy().reshape(num_data)
        # x = np.random.randn(20000) * std + mean
        y, bin_edges = np.histogram(x, bins=200, density=True)
        bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
        plt.plot(bin_centers, y, '-')

        # Generator approximation: histogram of generated samples.
        x = gan.sample_gen(20000).data.cpu().numpy()
        y, bin_edges = np.histogram(x, bins=200, density=True)
        bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
        plt.plot(bin_centers, y, '-')

        # Discriminator probability over the current x-axis range.
        axes = plt.gca()
        # Fix: x_lim was previously computed but ignored while
        # axes.get_xlim() was queried twice more; use it instead.
        x_lim = axes.get_xlim()
        x = torch.linspace(x_lim[0], x_lim[1], 200).resize_(200, sample_size)
        if use_gpu:
            x = x.cuda()
        y = dis.forward(Variable(x))
        plt.plot(x.cpu().numpy(), y.data.cpu().numpy(), '-')

        if make_gif:
            plt.savefig('./figs/pic' + str(ii).zfill(3))

        plt.pause(0.01)
        plt.cla()

    if make_gif:
        # Assemble saved frames into a looping GIF (requires ImageMagick).
        subprocess.call([
            'convert', '-loop', '0', '-delay', '50', './figs/pic*.png',
            './figs/output.gif'
        ])
コード例 #3
0
ファイル: gridSearch.py プロジェクト: ibiscp/Synthetic-Plants
    def fit(self):
        """Run a grid search: train one GAN per parameter combination.

        Iterates over every combination in self.parameters, trains a
        model, records the score in self.results, and appends a summary
        line to resources/results.txt after each run.
        """
        grid = ParameterGrid(self.parameters)
        total = len(grid)  # hoisted: the original rebuilt the grid just for len()

        for g in grid:
            self.iter += 1
            print('\nTraining:', '{}/{}'.format(self.iter, total),
                  '- Parameters:', g)

            # Model
            model = GAN(self.name, self.train_dataset, self.test_dataset,
                        self.shape, **g)

            score = model.train()

            # One summary string, reused for both console and file output.
            summary = (
                'emd: %f\t fid: %f\t inception: %f\t knn: %f\t mmd: %f\t mode: %f'
                % (score.emd, score.fid, score.inception, score.knn,
                   score.mmd, score.mode))
            print('\tScore: ' + summary)

            self.results.append({'score': score, 'params': g})

            # Write to results.  Bug fix: the original opened the file with
            # mode "w+" inside the loop, truncating it on every iteration so
            # only the final run's score survived; append instead.
            with open('resources/results.txt', 'a') as f:
                f.write('\n' + summary)
コード例 #4
0
def main():
    """Parse CLI arguments, build the GAN graph, and run the chosen phase."""
    args = parse_args()
    if args is None:
        exit()

    # Soft placement lets TF fall back to CPU for unsupported ops.
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=sess_config) as sess:
        gan = GAN(sess, args)
        gan.build_model()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")
コード例 #5
0
def main():
    """Repeatedly train GANs with randomly sampled generator layer sizes.

    For each of 1000 outer rounds, draws three layer widths (a, b, c)
    and trains 10 GANs with the same configuration on the "portsweep"
    attack data.
    """
    for _ in range(0, 1000):
        # Layer widths: a in [10, 40], b in [a, 69], c in [10, b-1].
        a = np.random.randint(10, 41)
        b = np.random.randint(a, 70)
        # Bug fix: when a == b == 10, np.random.randint(10, 10) raised
        # ValueError (low must be < high); clamp the upper bound.
        c = np.random.randint(10, max(b, 11))

        args = {
            'attack_type': "portsweep",
            'max_epochs': 7000,
            'batch_size': 255,
            'sample_size': 500,
            'optimizer_learning_rate': 0.001,
            'generator_layers': [a, b, c]
        }
        # The inner index was named `iter`, shadowing the builtin; it is
        # unused, so `_` is clearer.
        for _ in range(0, 10):
            gan = GAN(**args)
            gan.train()
            print("GAN finished with layers:")
            print(str([a, b, c]))
コード例 #6
0
def main(_):
    """Train a GAN on KDD data or extract features with a trained model.

    Exactly one of FLAGS.train / FLAGS.extract must be set; in extract
    mode the whole test CSV (minus its label column) is fed as a single
    batch and the resulting features are written back to CSV.
    """
    assert sum([FLAGS.train, FLAGS.extract]) == 1

    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...)` pattern and is a no-op if present.
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    os.makedirs(FLAGS.logs_dir, exist_ok=True)

    with tf.Session() as sess:
        if FLAGS.train:
            gan = GAN(sess, FLAGS)
            gan.train()
        elif FLAGS.extract:
            # Features only: strip the label column before extraction.
            df_X = pd.read_csv('dataset/kddcup.test.data.sub.csv')
            df_X.drop('label', axis=1, inplace=True)
            df_X = np.array(df_X)
            # Feed the entire test set as one batch.
            FLAGS.batch_size = df_X.shape[0]
            gan = GAN(sess, FLAGS)
            features = gan.extract_features(df_X)
            pd.DataFrame(features).to_csv('dataset/test.sub.features.csv',
                                          header=None,
                                          index=False)
コード例 #7
0
def main(argv):
    """Load the run configuration, optionally attach wandb, and train/test.

    Reads a JSON config from FLAGS.config, initializes Weights & Biases
    when requested, builds the GAN on GPU if available, optionally
    restores checkpoints, then runs test() or train().
    """
    # Load configs from file.  Fix: use a context manager so the config
    # file handle is closed (the original leaked it via json.load(open(...))).
    with open(FLAGS.config) as config_file:
        config = json.load(config_file)
    # set_backend()

    # Set name
    #name = '{}_{}_'.format(config['INPUT_NAME'], config['TARGET_NAME'])
    #for l in config['LABELS']:
    #    name += str(l)
    #config['NAME'] += '_' + name

    if FLAGS.use_wandb:
        import wandb
        # Idiom fix: `True if X else False` collapsed to the boolean itself.
        resume_wandb = FLAGS.wandb_resume_id is not None
        wandb.init(config=config,
                   resume=resume_wandb,
                   id=FLAGS.wandb_resume_id,
                   project='EchoGen',
                   name=FLAGS.wandb_run_name)

    # Initialize GAN on GPU when available, otherwise CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    model = GAN(config, FLAGS.use_wandb, device, FLAGS.dataset_path)

    # load trained models if they exist
    if FLAGS.ckpt_load is not None:
        model.load(f'{FLAGS.ckpt_load}/generator_last_checkpoint.bin',
                   model='generator')
        model.load(f'{FLAGS.ckpt_load}/discriminator_last_checkpoint.bin',
                   model='discriminator')

    if FLAGS.test:
        model.test()
    else:
        model.train()
コード例 #8
0
def test_gan():
    """Train a character-level GAN on the animals dataset and show samples."""

    # Hyper-parameters.
    file_name = "animals.txt"
    g_hidden_size, d_hidden_size = 10, 10
    n_epochs = 1000
    g_epochs, d_epochs = 20, 10
    g_initial_lr, d_initial_lr = 1, 1
    g_multiplier, d_multiplier = 0.9, 0.9
    g_batch_size, d_batch_size = 100, 100

    # Load the dataset and its character vocabulary.
    char_list = dataloader.get_char_list(file_name)
    X_actual = dataloader.load_data(file_name)
    seq_len = X_actual.shape[1]

    # Build the GAN and train it, printing progress and 3 sample outputs.
    gan = GAN(g_hidden_size, d_hidden_size, char_list)
    gan.train(X_actual, seq_len, n_epochs, g_epochs, d_epochs,
              g_initial_lr, d_initial_lr, g_multiplier, d_multiplier,
              g_batch_size, d_batch_size,
              print_progress=True,
              num_displayed=3)
コード例 #9
0
def main():
  """Sample setup: one generator trained against three discriminators."""

  # Build the first GAN, then two more that share its generator (gan1.G).
  gan1 = GAN(generator=Generator(), discriminator=Discriminator(),
             nn=load_model("new_nn.h5"))
  gans = [gan1]
  for _ in range(2):
    gans.append(GAN(generator=gan1.G, discriminator=Discriminator(),
                    nn=load_model("new_nn.h5")))

  # Number of training iterations for the whole system.
  training_period = 5

  # Each iteration trains gan1, gan2, gan3 in order.
  for i in range(training_period):
    for net in gans:
      net.train(testid=i)
コード例 #10
0
ファイル: main.py プロジェクト: dronperminov/DCGAN
def main():
    """Train a DCGAN on 64x64 RGB cat images.

    Builds a transposed-convolution generator (latent vector -> 64x64x3,
    tanh output) and a strided-convolution discriminator, then trains
    them with the RaLSGAN loss, periodically saving models and samples.
    """
    dataset_path = "C:/Users/dronp/Desktop/ImageRecognition/cats/"
    images_path = "generated/"   # where generated sample images are written
    models_path = "models/"      # where model checkpoints are written

    image_shape = (64, 64, 3)

    latent_dim = 256      # size of the generator's input noise vector
    batch_size = 16
    iterations = 300000

    save_period = 1500        # iterations between model/example saves
    save_loss_period = 500    # iterations between loss-curve saves
    examples_count = 8        # sample images generated per save

    # DCGAN-style weight initialization: N(0, 0.02).
    init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # Generator: reshape the latent vector to 1x1 spatial, then upsample
    # with stride-2 transposed convolutions up to 64x64.
    generator = keras.Sequential([
        layers.Reshape((1, 1, latent_dim), input_shape=(latent_dim, )),
        layers.Conv2DTranspose(512, (4, 4),
                               strides=(1, 1),
                               padding='valid',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(256, (4, 4),
                               strides=(2, 2),
                               padding='same',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(128, (4, 4),
                               strides=(2, 2),
                               padding='same',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(64, (4, 4),
                               strides=(2, 2),
                               padding='same',
                               kernel_initializer=init),
        layers.BatchNormalization(),
        layers.ReLU(),
        # Final layer maps to 3 channels; tanh puts pixels in [-1, 1].
        layers.Conv2DTranspose(
            3, (4, 4), strides=(2, 2), padding='same', activation='tanh'),
    ],
                                 name="generator")

    # Discriminator: stride-2 convolutions downsample 64x64 to a single
    # raw score (no final activation; the loss works on raw outputs).
    discriminator = keras.Sequential([
        layers.Conv2D(64, (5, 5),
                      strides=(2, 2),
                      padding='same',
                      kernel_initializer=init,
                      input_shape=image_shape),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (5, 5),
                      strides=(2, 2),
                      padding='same',
                      kernel_initializer=init),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(256, (5, 5),
                      strides=(2, 2),
                      padding='same',
                      kernel_initializer=init),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(512, (5, 5),
                      strides=(2, 2),
                      padding='same',
                      kernel_initializer=init),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (4, 4),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer=init),
        layers.Flatten()
    ],
                                     name="discriminator")

    # Separate Adam optimizers; the discriminator gets a slightly lower LR.
    g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=0.00015, beta_1=0.5)

    images = load_data(dataset_path)

    # Create the output directories on first run.
    if models_path and not os.path.exists(models_path):
        os.mkdir(models_path)

    if images_path and not os.path.exists(images_path):
        os.mkdir(images_path)

    # Relativistic least-squares GAN loss (project-defined).
    loss_fn = losses.RaLSGANLoss()
    gan = GAN(latent_dim, discriminator, generator, d_optimizer, g_optimizer,
              loss_fn)
    gan.summary()
    gan.train(images, iterations, batch_size, models_path, images_path,
              examples_count, save_period, save_loss_period)
コード例 #11
0
    # Visualize a random subset of 36 training shapes.
    print('Plotting training samples ...')
    samples = X[np.random.choice(range(X.shape[0]), size=36)]
#    plot_samples(None, samples, scatter=True, symm_axis=symm_axis, s=1.5, alpha=.7, c='k', fname='samples')
    plot_samples(None, samples, scale=1.0, scatter=False, symm_axis=symm_axis, lw=1.2, alpha=.7, c='k', fname='samples')
    
    # Split training and test data
    # NOTE(review): 'train_test_plit' looks like a typo for 'train_test_split'
    # -- confirm the helper's actual name in this project before changing it.
    X_train, X_test = train_test_plit(X, split=0.8)
    
    # Train: outputs go under ./trained_gan/<latent_dim>_<noise_dim>[/<model_id>].
    directory = './trained_gan/{}_{}'.format(latent_dim, noise_dim)
    if args.model_id is not None:
        directory += '/{}'.format(args.model_id)
    model = GAN(latent_dim, noise_dim, X_train.shape[1], bezier_degree, bounds)
    if args.mode == 'train':
        timer = ElapsedTimer()
        model.train(X_train, batch_size=batch_size, train_steps=train_steps, save_interval=args.save_interval, directory=directory)
        elapsed_time = timer.elapsed_time()
        runtime_mesg = 'Wall clock time for training: %s' % elapsed_time
        print(runtime_mesg)
        # Record the wall-clock training time alongside the model.
        runtime_file = open('{}/runtime.txt'.format(directory), 'w')
        runtime_file.write('%s\n' % runtime_mesg)
        runtime_file.close()
    else:
        model.restore(directory=directory)
    
    print('Plotting synthesized shapes ...')
    points_per_axis = 5
    # Plot a grid of shapes synthesized across the latent space.
    plot_grid(points_per_axis, gen_func=model.synthesize, d=latent_dim, bounds=bounds, scale=1.0, scatter=False, symm_axis=symm_axis, 
              alpha=.7, lw=1.2, c='k', fname='{}/synthesized'.format(directory))
    def synthesize_noise(noise):
        # Synthesize at a fixed latent point (all 0.5) with the given noise.
        return model.synthesize(0.5*np.ones((points_per_axis**2, latent_dim)), noise)
コード例 #12
0
        d_iter_count += 1
        if d_iter_count == config.max_d_iters:
            d_iter_count = 0
            mode = 'G'


# Training-loop driver state.
iter_no = 0
max_iters = 100

mode = 'G'           # which network ('G'/'D') is currently being optimized
g_iter_count = 0
d_iter_count = 0

batch_size = config.batch_size

# Separate summary writers for training and validation curves.
train_writer = SummaryWriter(log_dir='../logs/train')
val_writer = SummaryWriter(log_dir='../logs/val')

with tqdm(total=max_iters) as pbar:
    for iter_no in range(max_iters):
        train_batch = train_loader.next_batch()

        gan.train()
        train_step(iter_no, train_batch)

        # Periodically evaluate on a validation batch.
        if iter_no % config.validation_interval == 0:
            val_batch = val_loader.next_batch()
            gan.eval()
            validate(val_batch, iter_no)

        # Bug fix: the progress bar was only advanced inside the validation
        # branch, so it never reached `total=max_iters`; advance it once per
        # training iteration instead.
        pbar.update(1)
コード例 #13
0
ファイル: main.py プロジェクト: JustinCWeiler/python-tidbits
import numpy as np
from matplotlib import pyplot as plt
from keras.datasets import mnist

from gan import GAN

# Load MNIST and scale pixel values into [0, 1] with a channel axis.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape([-1, 28, 28, 1]) / 255
x_test = x_test.reshape([-1, 28, 28, 1]) / 255

# Pool train and test images into one set for GAN training.
x = np.concatenate([x_train, x_test])

gan = GAN()

# Print the combined model's architecture.
gan.gan.summary()

gan.train(x, 100, batch_size=100)
gan.save('number_gan/gan.h5')

# Generate ten digits from random latent vectors and display them.
noise = np.random.normal(size=[10, 8 * 8])
images = gan.generator.predict(noise).reshape([-1, 28, 28])

plt.gray()
for idx, img in enumerate(images):
    plt.subplot(2, 5, idx + 1)
    plt.imshow(img)
plt.show()
コード例 #14
0
def build_gan_celeba(hyperparams, path):
    """Build a GAN for black-and-white CelebA and train it on `path`."""
    model = GAN([128], [128], 100, hyperparams, 'celeba_bw')
    model.train(path)
    return model
コード例 #15
0
ファイル: main.py プロジェクト: AdamFull/GAN_video
from gan import GAN

# Target frame size.  NOTE(review): 'heigth' is a typo for 'height', kept
# as-is because the name is referenced throughout this module.
width, heigth = 128, 128

# Absolute directory of this script and a scratch directory next to it.
self_path = os.path.dirname(os.path.abspath(__file__))
temp_path = os.path.join(self_path, "temp")



    # print("Video saved to %s" % final_path)

def main():
    """Interactive entry point: prompt for an action and run it.

    NOTE(review): only the 'train' branch does anything, the GAN
    construction inside it is commented out, and the __main__ guard in
    this file does not call main() -- looks like work in progress;
    confirm before relying on it.
    """
    comand = input("Select what you want( train, start_bot, clean): ")
    if comand == "train":
        datasets = os.path.join(self_path, "datasets")
        # Offer the sub-directories of ./datasets as selectable datasets.
        dirs = [dI for dI in os.listdir(datasets) if os.path.isdir(os.path.join(datasets,dI))]
        comand = input("Select dataset(%s):" % ','.join(dirs))
        dataset_data = dataset("file_path", temp_path, width, heigth)
        data = dataset_data.prep_imgs(os.path.join(self_path, "datasets/%s" % comand))
        #generative_net = GAN(buff_size=, batch_size=4, epochs=5000, imgs_size=(width, heigth))
        #generative_net.train(data)

if __name__ == "__main__":
    #file_path = input("write path to video:")
    # Trains directly on the "big_anime" dataset rather than calling main().
    dataset_data = dataset("file_path", temp_path, width, heigth)
    data = dataset_data.prep_imgs(os.path.join(self_path, "datasets/big_anime"))
    generative_net = GAN(buff_size=10035, batch_size=16, epochs=5000, imgs_size=(width, heigth))
    generative_net.train(data)
コード例 #16
0
    # Visualize a random subset of 36 training shapes.
    samples = X[np.random.choice(range(X.shape[0]), size=36)]
#    plot_samples(None, samples, scatter=True, symm_axis=symm_axis, s=1.5, alpha=.7, c='k', fname='samples')
    plot_samples(None, samples, scale=1.0, scatter=False, symm_axis=symm_axis, lw=1.2, alpha=.7, c='k', fname='samples')
    
    # Split training and test data
    # NOTE(review): despite the name, test_split is the fraction assigned to
    # the TRAINING set -- X_train gets the first 80% of X.
    test_split = 0.8
    N = X.shape[0]
    split = int(N*test_split)
    X_train = X[:split]
    X_test = X[split:]
    
    # Train a new model, or restore a saved one when mode != 'startover'.
    model = GAN(latent_dim, noise_dim, X_train.shape[1], bezier_degree, bounds)
    if args.mode == 'startover':
        timer = ElapsedTimer()
        model.train(X_train, batch_size=batch_size, train_steps=train_steps, save_interval=args.save_interval)
        elapsed_time = timer.elapsed_time()
        runtime_mesg = 'Wall clock time for training: %s' % elapsed_time
        print(runtime_mesg)
        # Record wall-clock training time next to the saved model.
        runtime_file = open('gan/runtime.txt', 'w')
        runtime_file.write('%s\n' % runtime_mesg)
        runtime_file.close()
    else:
        model.restore()
    
    print('Plotting synthesized shapes ...')
    # Plot a 5x5 grid of shapes synthesized across the latent space.
    plot_grid(5, gen_func=model.synthesize, d=latent_dim, bounds=bounds, scale=1.0, scatter=False, symm_axis=symm_axis, 
              alpha=.7, lw=1.2, c='k', fname='gan/synthesized')
    
    n_runs = 10
    
コード例 #17
0
ファイル: train.py プロジェクト: inidibininging/tiles-gan
# Training options.  NOTE(review): exact semantics of generator_frequency
# and label_softness are defined inside GAN(opts) -- confirm there.
opts.generator_frequency = 1
opts.generator_dropout = 0.3
opts.label_softness = 0.2
opts.batch_size = 128
opts.epoch_number = 300

#Note: if you are using jupyter notebook you might need to disable workers:
opts.workers_nbr = 0

#List of transformations applied to each input image
opts.transforms = [
    transforms.Resize(int(opts.image_size), Image.BICUBIC),
    transforms.ToTensor(),  #do not forget to transform image into a tensor
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))  #maps channels to roughly [-1, 1]
    #RandomNoise(0.01) #adding noise might prevent the discriminator from over-fitting
]

#Build GAN
model = GAN(opts)

#Display GAN architecture (note: only works if cuda is enabled)
summary(model.generator.cuda(), input_size=(opts.latent_dim, 1, 1))
summary(model.discriminator.cuda(),
        input_size=(opts.channels_nbr, opts.image_size, opts.image_size))

#Start training
model.train()

#Save model
model.save(opts.output_path)
コード例 #18
0
 # Prepare save directory: ./trained_gan/<data>_<func>/<lambda0>_<lambda1>
 create_dir('./trained_gan')
 example_dir = './trained_gan/{}_{}'.format(args.data, args.func)
 create_dir(example_dir)
 save_dir = '{}/{}_{}'.format(example_dir, lambda0, lambda1)
 create_dir(save_dir)
 
 # Visualize the raw data (both SVG and PNG copies)
 visualize_2d(data, func_obj, save_path='{}/data.svg'.format(example_dir), xlim=(-0.5,0.5), ylim=(-0.5,0.5))
 visualize_2d(data, func_obj, save_path='{}/data.png'.format(example_dir), xlim=(-0.5,0.5), ylim=(-0.5,0.5))
 
 # Train a new model, or restore a saved one when mode != 'train'
 model = GAN(noise_dim, 2, lambda0, lambda1)
 if args.mode == 'train':
     timer = ElapsedTimer()
     model.train(data_obj, func_obj, batch_size=batch_size, train_steps=train_steps, 
                 disc_lr=disc_lr, gen_lr=gen_lr, save_interval=save_interval, save_dir=save_dir)
     elapsed_time = timer.elapsed_time()
     runtime_mesg = 'Wall clock time for training: %s' % elapsed_time
     print(runtime_mesg)
 else:
     model.restore(save_dir=save_dir)
 
 print('##########################################################################')
 print('Plotting generated samples ...')
     
 # Plot generated samples against the first n real data points
 n = 1000
 gen_data = model.synthesize(n)
 visualize_2d(data[:n], func_obj, gen_data=gen_data, save_path='{}/synthesized.svg'.format(save_dir), 
              xlim=(-0.5,0.5), ylim=(-0.5,0.5), axis_off=False)
 visualize_2d(data[:n], func_obj, gen_data=gen_data, save_path='{}/synthesized.png'.format(save_dir), 
コード例 #19
0
from gan import GAN

# Train a GAN on the "Irish" dataset, checkpointing every 200 epochs.
dataset_name = "Irish"

irish_gan = GAN(dataset_name)
irish_gan.train(dataset_name, epochs=601, batch_size=32, save_interval=200)
コード例 #20
0
def build_gan(hyperparams):
    """Build a 128/128-layer GAN with a 100-dim latent space and train it."""
    model = GAN([128], [128], 100, hyperparams)
    model.train()
    return model
コード例 #21
0
    https://arxiv.org/pdf/1704.00028.pdf
    '''
    if objective == 'max':
        # We are coming from the discriminator.
        # Critic objective (WGAN-GP, per the paper linked above):
        # maximize E[D(x)] - E[D(G(y))] - lamb*(||grad||_2 - 1)^2,
        # returned negated so a minimizing optimizer can be used.
        grad_penalty = lamb * (torch.norm(grad, dim=1) - 1).pow(2).mean()
        loss = Dx.mean() - DGy.mean() - grad_penalty
        loss = -1 * loss
    elif objective == 'min':
        # We are coming from the generator -- the other terms in the loss don't matter.
        loss = -DGy.mean()
    return loss


if __name__ == '__main__':
    # Get the dataloaders for SVHN
    print('Getting dataloaders...')
    batch_size = 64
    train_loader, valid_loader, test_loader = get_loaders(
        batch_size=batch_size)

    # CUDA is required for this run.  Fix: raise a real error instead of
    # `assert`, which is silently stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is required to train this GAN')
    device = torch.device('cuda')
    print("device: ", device)

    # Instantiate and train GAN
    gan = GAN(device=device, batch_size=batch_size)
    print('Training GAN...')
    # Bug fix: the original passed `test_loader` twice, leaving the train
    # and validation loaders unused; train on train_loader and validate
    # on valid_loader.
    gan.train(train_loader, valid_loader, loss_fn=wgan_gp)
    gan.log_learning_curves()
    gan.log_d_crossentropy()
    gan.save_model('gan.pt')
コード例 #22
0
ファイル: usage.py プロジェクト: numb3r3/tf2-sketches
import numpy as np
from keras.datasets import fashion_mnist

from image_helper import ImageHelper
from gan import GAN

# Load Fashion-MNIST, scale images into [-1, 1], and add a channel axis.
(X, _), (_, _) = fashion_mnist.load_data()
X_train = X / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

# Wire up the sample-rendering helper, then build and train the GAN
# (fixes the 'advarsial' typo in the original variable name).
image_helper = ImageHelper()
gan_model = GAN(X_train[0].shape, 100, image_helper)
gan_model.train(30000, X_train, batch_size=32)
コード例 #23
0
def main():
    """Parse CLI arguments, build the requested GAN variant, and run the
    selected mode (train / evaluate / reconstruct).

    Raises:
        Exception: for an unknown gan_type or mode (messages unchanged so
        existing callers/log parsers keep working).
    """
    # parse arguments
    args = parse_args()

    if args is None:
        exit()

    if args.benchmark_mode:
        # Let cuDNN pick the fastest kernels for fixed-size inputs.
        torch.backends.cudnn.benchmark = True

    # Dispatch table replaces the original if/elif chain over gan_type.
    gan_classes = {
        'GAN': GAN,
        'CGAN': CGAN,
        'ACGAN': ACGAN,
        'DSGAN': DSGAN,
        'SNGAN': SNGAN,
    }
    if args.gan_type not in gan_classes:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = gan_classes[args.gan_type](args)

    if args.mode == 'train':
        # launch the graph in a session
        gan.train()
        print(" [*] Training finished!")

        # visualize learned generator
        gan.visualize_results(args.epoch)
        print(" [*] Testing finished!")

    elif args.mode == 'evaluate':
        print(" [*] Compute the Lipschitz parameter")
        gan.get_lipschitz()
        print("")
        # NOTE: a large block of commented-out inception-score evaluation
        # code (MNIST / fashion-MNIST / CIFAR10 + ACGAN scoring) was removed
        # here as dead code; recover it from version control if needed.

    elif args.mode == 'reconstruct':
        print(" [*] Reconstruct "+args.dataset+" dataset using "+args.gan_type)
        gan.reconstruct()

    else:
        raise Exception("[!] There is no option for " + args.mode)
コード例 #24
0

# TODO: dropout
def discriminator(x):
    """Two fully-connected hidden layers, then a sigmoid real/fake score."""
    with tf.variable_scope("discriminator"):
        hidden = fc(x, 200, reuse=tf.AUTO_REUSE, scope="h1")
        hidden = fc(hidden, 150, reuse=tf.AUTO_REUSE, scope="h2")
        # Final layer has no activation; the sigmoid is applied explicitly.
        logits = fc(hidden, 1, activation_fn=None,
                    reuse=tf.AUTO_REUSE, scope="h3")
        return tf.nn.sigmoid(logits)


# Load MNIST via the tf.contrib.learn dataset helper; train on its images.
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
real_data = mnist.train.images

# copy by reference?
# NOTE(review): g_optimizer and d_optimizer are the SAME object below, so
# the generator and discriminator share one Adam optimizer instance --
# confirm this is intentional rather than two independent optimizers.
# TODO: read about Adam
g_optimizer = tf.train.AdamOptimizer(0.0001)
d_optimizer = g_optimizer

config = _config()
# Hook(1, show_result): presumably invokes show_result every 1 step --
# confirm against Hook's definition.
hook1 = Hook(1, show_result)

m = GAN(generator, discriminator, "wasserstein")
# TODO: cleanup code by placing session creation inside .train()
sess = tf.Session()
#sess = tf_debug.LocalCLIDebugWrapperSession(sess)
m.train(sess, g_optimizer, d_optimizer, real_data, config, hooks=[hook1])
コード例 #25
0
ファイル: mnist_gan.py プロジェクト: zadniprovskyy/gan
from gan import GAN
import tensorflow as tf


# Training hyper-parameters.
BATCH_SIZE = 256
BUFFER_SIZE = 60000   # shuffle buffer sized to the full MNIST training set
EPOCHS = 1

if __name__=='__main__':
    # Load MNIST; only the training images are used.
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

    # Add a channel axis, cast to float32, and normalize into [-1, 1].
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    train_images = (train_images - 127.5) / 127.5

    # Shuffle and batch the data as a tf.data pipeline.
    train_dataset = (tf.data.Dataset
                     .from_tensor_slices(train_images)
                     .shuffle(BUFFER_SIZE)
                     .batch(BATCH_SIZE))

    # Build the GAN (generator, discriminator, optimizers) and train it.
    gan = GAN()
    gan.set_batch_size(batch_size=BATCH_SIZE)
    gan.train(train_dataset, EPOCHS)