Code Example #1
def cal_fid_score(G, device, z_dim):
    stats_path = 'tflib/data/fid_stats_lsun_train.npz'
    inception_path = fid.check_or_download_inception('tflib/model')
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    fid.create_inception_graph(inception_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    all_samples = []
    samples = torch.randn(NUM_SAMPLES, z_dim, 1, 1)
    for i in range(0, NUM_SAMPLES, BATCH_SIZE):
        samples_100 = samples[i:i + BATCH_SIZE]
        samples_100 = samples_100.to(device=device)
        all_samples.append(G(samples_100).cpu().data.numpy())

    all_samples = np.concatenate(all_samples, axis=0)
    all_samples = ((all_samples * 0.5 + 0.5) * 255).astype('int32')
    all_samples = all_samples.reshape(
        (-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            all_samples, sess, batch_size=BATCH_SIZE)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    return fid_value
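
Code Example #1 assumes module-level constants (NUM_SAMPLES, BATCH_SIZE, N_CHANNEL, RESOLUTION) and a PyTorch generator G that emits images in [-1, 1]. A minimal, hypothetical driver is sketched below; the constant values and the checkpoint path are placeholders, not values from the original project.

# Hypothetical driver for cal_fid_score -- paths and values are illustrative only.
# Assumes cal_fid_score's module defines NUM_SAMPLES, BATCH_SIZE, N_CHANNEL, RESOLUTION,
# e.g. NUM_SAMPLES = 10000, BATCH_SIZE = 100, N_CHANNEL = 3, RESOLUTION = 64.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = torch.load('checkpoints/netG.pth', map_location=device)  # placeholder checkpoint
G.to(device).eval()

with torch.no_grad():
    fid_value = cal_fid_score(G, device, z_dim=128)
print('FID:', fid_value)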
Code Example #2
File: cegan_trainer.py  Project: AlephZr/CE-GAN
    def __init__(self, args, model_constructor, env_constructor):
        self.args = args
        self.device = args.device
        self.env = env_constructor
        self.D_train_sample = args.batch_size * args.D_iters
        self.batch_sample1 = (self.D_train_sample + args.batch_size
                              if 'rsgan' in args.g_loss_mode else self.D_train_sample)
        self.batch_sample2 = (self.batch_sample1 + args.eval_size
                              if 'fitness' in args.g_loss_mode else self.batch_sample1)

        # PG Learner
        from algos.gan import GAN
        self.learner = GAN(args, model_constructor, env_constructor)

        # Evolution
        self.evolver = SSNE(self.args, self.learner.netG,
                            self.learner.optimizerG, self.learner.netD)

        self.no_FID = True
        self.no_IS = True
        self.sess, self.mu_real, self.sigma_real, self.get_inception_metrics = None, None, None, None
        if self.args.use_pytorch_scores and self.args.test_name:
            parallel = False
            if 'FID' in args.test_name:
                self.no_FID = False
            if 'IS' in args.test_name:
                self.no_IS = False
            self.get_inception_metrics = prepare_inception_metrics(
                args.dataset_name, parallel, self.no_IS, self.no_FID)
        else:
            if 'FID' in args.test_name:
                if self.args.dataset_name == 'CIFAR10':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_cifar10_train.npz'
                elif self.args.dataset_name == 'CelebA':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_celeba.npz'
                elif self.args.dataset_name == 'LSUN':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_lsun_train.npz'
                INCEPTION_PATH = './tflib/IS/imagenet'

                print("load train stats.. ")
                # load precalculated training set statistics
                f = np.load(STAT_FILE)
                self.mu_real, self.sigma_real = f['mu'][:], f['sigma'][:]
                f.close()
                print("ok")

                inception_path = fid.check_or_download_inception(
                    INCEPTION_PATH)  # download inception network
                fid.create_inception_graph(
                    inception_path)  # load the graph into the current TF graph

                config = tf.ConfigProto()
                config.gpu_options.allow_growth = True
                self.sess = tf.Session(config=config)
                self.sess.run(tf.global_variables_initializer())

                self.no_FID = False
            if 'IS' in args.test_name:
                self.no_IS = False
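
The batch_sample1 / batch_sample2 expressions above decide how many latent samples one training step draws, depending on which terms appear in g_loss_mode. A small worked example with illustrative argument values (not the project's defaults):

# Illustrative numbers only.
batch_size, D_iters, eval_size = 32, 3, 64
g_loss_mode = ['rsgan', 'fitness']

D_train_sample = batch_size * D_iters                 # 32 * 3 = 96 samples for the D updates
batch_sample1 = (D_train_sample + batch_size          # +32 = 128, since 'rsgan' is in g_loss_mode
                 if 'rsgan' in g_loss_mode else D_train_sample)
batch_sample2 = (batch_sample1 + eval_size            # +64 = 192, since 'fitness' is in g_loss_mode
                 if 'fitness' in g_loss_mode else batch_sample1)
print(D_train_sample, batch_sample1, batch_sample2)   # 96 128 192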
Code Example #3
File: test.py  Project: AlephZr/CE-GAN
    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.D_train_sample = args.batch_size * args.D_iters
        self.batch_sample1 = (self.D_train_sample + args.batch_size
                              if 'rsgan' in args.g_loss_mode else self.D_train_sample)
        self.batch_sample2 = (self.batch_sample1 + args.eval_size
                              if 'fitness' in args.g_loss_mode else self.batch_sample1)

        self.no_FID = True
        self.no_IS = True
        self.sess, self.mu_real, self.sigma_real, self.get_inception_metrics = None, None, None, None
        if self.args.use_pytorch_scores and self.args.test_name:
            parallel = False
            if 'FID' in args.test_name:
                self.no_FID = False
            if 'IS' in args.test_name:
                self.no_IS = False
            self.get_inception_metrics = prepare_inception_metrics(
                args.dataset_name, parallel, self.no_IS, self.no_FID)
        else:
            if 'FID' in args.test_name:
                if self.args.dataset_name == 'CIFAR10':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_cifar10_train.npz'
                elif self.args.dataset_name == 'CelebA':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_celeba.npz'
                elif self.args.dataset_name == 'LSUN':
                    STAT_FILE = './tflib/TTUR/stats/fid_stats_lsun_train.npz'
                INCEPTION_PATH = './tflib/IS/imagenet'

                print("load train stats.. ")
                # load precalculated training set statistics
                f = np.load(STAT_FILE)
                self.mu_real, self.sigma_real = f['mu'][:], f['sigma'][:]
                f.close()
                print("ok")

                inception_path = fid.check_or_download_inception(
                    INCEPTION_PATH)  # download inception network
                fid.create_inception_graph(
                    inception_path)  # load the graph into the current TF graph

                config = tf.ConfigProto()
                config.gpu_options.allow_growth = True
                self.sess = tf.Session(config=config)
                self.sess.run(tf.global_variables_initializer())

                self.no_FID = False
            if 'IS' in args.test_name:
                self.no_IS = False
        # Define Tracker class to track scores
        self.test_tracker = utils.Tracker(
            self.args.savefolder, ['score_' + self.args.savetag],
            '.csv')  # Tracker class to log progress
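
Examples #2 and #3 repeat the same if/elif chain to pick the precomputed statistics file. One way to avoid the duplication is a small lookup helper; this is only a sketch, with the paths copied verbatim from the code above:

# Sketch of a helper that replaces the repeated dataset -> stats-file branching.
import numpy as np

FID_STAT_FILES = {
    'CIFAR10': './tflib/TTUR/stats/fid_stats_cifar10_train.npz',
    'CelebA': './tflib/TTUR/stats/fid_stats_celeba.npz',
    'LSUN': './tflib/TTUR/stats/fid_stats_lsun_train.npz',
}

def load_real_stats(dataset_name):
    """Return (mu, sigma) of the precalculated training-set statistics."""
    f = np.load(FID_STAT_FILES[dataset_name])  # KeyError signals an unsupported dataset
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()
    return mu_real, sigma_real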
Code Example #4
    def get_inception_score():
        all_samples = []
        samples = torch.randn(N_SAMPLES, N_LATENT)
        for i in range(0, N_SAMPLES, 100):
            batch_samples = samples[i:i + 100].cuda(0)
            all_samples.append(gen(batch_samples).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        return inception_score(torch.from_numpy(all_samples),
                               resize=True,
                               cuda=True)

    import tflib.fid as fid
    import tensorflow as tf
    stats_path = 'tflib/data/fid_stats_cifar10_train.npz'
    inception_path = fid.check_or_download_inception('tflib/model')
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    def get_fid_score():
        all_samples = []
        samples = torch.randn(N_SAMPLES, N_LATENT)
        for i in range(0, N_SAMPLES, BATCH_SIZE):
            samples_100 = samples[i:i + BATCH_SIZE]
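
Example #4 breaks off inside get_fid_score. Judging from Code Example #1, the body would plausibly continue roughly as sketched below; this is a reconstruction, not the original source, and it assumes CIFAR-10 shapes (3x32x32) plus the mu_real, sigma_real and config objects prepared above.

    # Plausible continuation of get_fid_score, modelled on Code Example #1 (not the original code).
    def get_fid_score_sketch():
        all_samples = []
        samples = torch.randn(N_SAMPLES, N_LATENT)
        for i in range(0, N_SAMPLES, BATCH_SIZE):
            batch = samples[i:i + BATCH_SIZE].cuda(0)
            all_samples.append(gen(batch).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        all_samples = ((all_samples * 0.5 + 0.5) * 255).astype('int32')           # [-1, 1] -> [0, 255]
        all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)  # NCHW -> NHWC

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                all_samples, sess, batch_size=BATCH_SIZE)

        return fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)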
Code Example #5
File: measure_gan.py  Project: drajwer/defensegan
def measure_gan(gan,
                rec_data_path=None,
                probe_size=10000,
                calc_real_data_is=True):
    """Based on MNIST tutorial from cleverhans.
    
    Args:
         gan: A `GAN` model.
         rec_data_path: A string to the directory.
    """
    FLAGS = tf.flags.FLAGS

    # Set logging level to see debug information.
    set_log_level(logging.WARNING)
    sess = gan.sess

    # FID init
    stats_path = 'data/fid_stats_celeba.npz'  # training set statistics
    inception_path = fid.check_or_download_inception(
        None)  # download inception network

    train_images, train_labels, test_images, test_labels = \
        get_cached_gan_data(gan, False)

    images = train_images[
        0:probe_size] * 255  # np.concatenate(train_images, test_images)

    # Inception Score for real data
    is_orig_mean, is_orig_stddev = (-1, -1)
    if calc_real_data_is:
        is_orig_mean, is_orig_stddev = get_inception_score(images)
        print(
            '\n[#] Inception Score for original data: mean = %f, stddev = %f\n'
            % (is_orig_mean, is_orig_stddev))

    rng = np.random.RandomState([11, 24, 1990])
    tf.set_random_seed(11241990)

    # Calculate Inception Score for GAN
    gan.batch_size = probe_size
    generated_images_tensor = gan.generator_fn()
    generated_images = sess.run(generated_images_tensor)
    generated_images = 255 * ((generated_images + 1) / 2)
    is_gen_mean, is_gen_stddev = get_inception_score(generated_images)
    print(
        '\n[#] Inception Score for generated data: mean = %f, stddev = %f\n' %
        (is_gen_mean, is_gen_stddev))

    # load precalculated training set statistics
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            generated_images, sess, batch_size=100)

    fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                               sigma_real)
    print("FID: %s" % fid_value)

    return is_gen_mean, is_gen_stddev, is_orig_mean, is_orig_stddev, fid_value
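
For reference, fid.calculate_frechet_distance in the TTUR helper computes the Fréchet distance between two Gaussians fitted to Inception activations. A minimal NumPy/SciPy sketch of that formula, assuming well-behaved covariance matrices (the library version additionally guards against numerically singular covariances):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)  # matrix square root of the product
    covmean = covmean.real                                     # discard small imaginary noise
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)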
Code Example #6
def run(mode="wgan-gp", dim_g=128, dim_d=128, critic_iters=5,
        n_gpus=1, normalization_g=True, normalization_d=False,
        batch_size=64, iters=100000, penalty_weight=2,
        output_dim=3072, lr=2e-4, data_dir='/usr/matematik/henning/code/datasets/cifar-10-batches-py',
        inception_frequency=1000,  log_dir='./logs', run_fid=False,
        penalty_mode="grad"):   # grad or pagan or ot
    # Download CIFAR-10 (Python version) at
    # https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
    # extracted files here!
    lib.plot.logdir = log_dir
    print("log dir set to {}".format(lib.plot.logdir))
    # dump locals() for settings
    with open("{}/settings.dict".format(log_dir), "w") as f:
        loca = locals().copy()
        f.write(str(loca))
        print("saved settings: {}".format(loca))

    DATA_DIR = data_dir
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in gan_cifar_resnet.py!')

    N_GPUS = n_gpus
    if N_GPUS not in [1,2]:
        raise Exception('Only 1 or 2 GPUs supported!')

    BATCH_SIZE = batch_size # Critic batch size
    GEN_BS_MULTIPLE = 2 # Generator batch size, as a multiple of BATCH_SIZE
    ITERS = iters # How many iterations to train for
    DIM_G = dim_g # Generator dimensionality
    DIM_D = dim_d # Critic dimensionality
    NORMALIZATION_G = normalization_g # Use batchnorm in generator?
    NORMALIZATION_D = normalization_d # Use batchnorm (or layernorm) in critic?
    OUTPUT_DIM = output_dim # Number of pixels in CIFAR10 (32*32*3)
    LR = lr # Initial learning rate
    DECAY = True # Whether to decay LR over learning
    N_CRITIC = critic_iters # Critic steps per generator steps
    INCEPTION_FREQUENCY = inception_frequency # How frequently to calculate Inception score


    LAMBDA = penalty_weight


    DEVICES = ['/gpu:{}'.format(i) for i in range(N_GPUS)]
    if len(DEVICES) == 1: # Hack because the code assumes 2 GPUs
        DEVICES = [DEVICES[0], DEVICES[0]]

    lib.print_model_settings(locals().copy())

    if run_fid:
        inception_path = "/tmp/inception"
        print("check for inception model..")
        inception_path = fid.check_or_download_inception(inception_path)  # download inception if necessary
        print("ok")

        print("create inception graph..")
        fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
        print("ok")

        # region cifar FID
        if not os.path.exists("cifar.fid.stats.npz"):  # compute fid stats for CIFAR
            train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, DATA_DIR)

            # loads all images into memory (this might require a lot of RAM!)
            print("load images..")
            images = []
            for imagebatch, _ in dev_gen():
                images.append(imagebatch)

            _use_train_for_stats = True
            if _use_train_for_stats:
                print("using train data for FID stats")
                for imagebatch, _ in train_gen():
                    images.append(imagebatch)

            allimages = np.concatenate(images, axis=0)
            # allimages = ((allimages+1.)*(255.99/2)).astype('int32')
            allimages = allimages.reshape((-1, 3, 32, 32)).transpose(0,2,3,1)
            # images = list(allimages)
            images = allimages
            print("%d images found and loaded: {}" % len(images), images.shape)

            # embed()

            print("calculate FID stats..")
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100, verbose=True)
                np.savez_compressed("cifar.fid.stats", mu=mu, sigma=sigma)
            print("finished")

            # embed()

        f = np.load("cifar.fid.stats.npz")
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    else:
        print("  NOT RUNNING FID  !!! -- ()()()")

    # embed()

    import tflib.inception_score

    # endregion

    # region model

    def nonlinearity(x):
        return tf.nn.relu(x)

    def Normalize(name, inputs, labels=None):
        """This is messy, but basically it chooses between batchnorm, layernorm,
        their conditional variants, or nothing, depending on the value of `name` and
        the global hyperparam flags."""

        labels = None

        if ('Discriminator' in name) and NORMALIZATION_D:
            assert(False)
            return lib.ops.layernorm.Layernorm(name,[1,2,3],inputs,labels=labels,n_labels=10)
        elif ('Generator' in name) and NORMALIZATION_G:
            if labels is not None:
                assert(False)
                return lib.ops.cond_batchnorm.Batchnorm(name,[0,2,3],inputs,labels=labels,n_labels=10)
            else:
                return lib.ops.batchnorm.Batchnorm(name,[0,2,3],inputs,fused=True)
        else:
            return inputs

    def ConvMeanPool(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=he_init, biases=biases)
        output = tf.add_n([output[:,:,::2,::2], output[:,:,1::2,::2], output[:,:,::2,1::2], output[:,:,1::2,1::2]]) / 4.
        return output

    def MeanPoolConv(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
        output = inputs
        output = tf.add_n([output[:,:,::2,::2], output[:,:,1::2,::2], output[:,:,::2,1::2], output[:,:,1::2,1::2]]) / 4.
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
        return output

    def UpsampleConv(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
        output = inputs
        output = tf.concat([output, output, output, output], axis=1)
        output = tf.transpose(output, [0,2,3,1])
        output = tf.depth_to_space(output, 2)
        output = tf.transpose(output, [0,3,1,2])
        output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
        return output

    def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, resample=None, no_dropout=False, labels=None):
        """
        resample: None, 'down', or 'up'
        """
        if resample=='down':
            conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim)
            conv_2        = functools.partial(ConvMeanPool, input_dim=input_dim, output_dim=output_dim)
            conv_shortcut = ConvMeanPool
        elif resample=='up':
            conv_1        = functools.partial(UpsampleConv, input_dim=input_dim, output_dim=output_dim)
            conv_shortcut = UpsampleConv
            conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim)
        elif resample==None:
            conv_shortcut = lib.ops.conv2d.Conv2D
            conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=output_dim)
            conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim)
        else:
            raise Exception('invalid resample value')

        if output_dim==input_dim and resample==None:
            shortcut = inputs # Identity skip-connection
        else:
            shortcut = conv_shortcut(name+'.Shortcut', input_dim=input_dim, output_dim=output_dim, filter_size=1, he_init=False, biases=True, inputs=inputs)

        output = inputs
        output = Normalize(name+'.N1', output, labels=labels)
        output = nonlinearity(output)
        output = conv_1(name+'.Conv1', filter_size=filter_size, inputs=output)
        output = Normalize(name+'.N2', output, labels=labels)
        output = nonlinearity(output)
        output = conv_2(name+'.Conv2', filter_size=filter_size, inputs=output)

        return shortcut + output

    def OptimizedResBlockDisc1(inputs):
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=3, output_dim=DIM_D)
        conv_2        = functools.partial(ConvMeanPool, input_dim=DIM_D, output_dim=DIM_D)
        conv_shortcut = MeanPoolConv
        shortcut = conv_shortcut('Discriminator.1.Shortcut', input_dim=3, output_dim=DIM_D, filter_size=1, he_init=False, biases=True, inputs=inputs)

        output = inputs
        output = conv_1('Discriminator.1.Conv1', filter_size=3, inputs=output)
        output = nonlinearity(output)
        output = conv_2('Discriminator.1.Conv2', filter_size=3, inputs=output)
        return shortcut + output

    def Generator(n_samples, labels, noise=None):
        if noise is None:
            noise = tf.random_normal([n_samples, 128])
        output = lib.ops.linear.Linear('Generator.Input', 128, 4*4*DIM_G, noise)
        output = tf.reshape(output, [-1, DIM_G, 4, 4])
        output = ResidualBlock('Generator.1', DIM_G, DIM_G, 3, output, resample='up', labels=labels)
        output = ResidualBlock('Generator.2', DIM_G, DIM_G, 3, output, resample='up', labels=labels)
        output = ResidualBlock('Generator.3', DIM_G, DIM_G, 3, output, resample='up', labels=labels)
        output = Normalize('Generator.OutputN', output)
        output = nonlinearity(output)
        output = lib.ops.conv2d.Conv2D('Generator.Output', DIM_G, 3, 3, output, he_init=False)
        output = tf.tanh(output)
        return tf.reshape(output, [-1, OUTPUT_DIM])

    def Features(inputs):
        output=inputs
        return output


    def Discriminator(inputs, labels=None):
        output = tf.reshape(inputs, [-1, 3, 32, 32])
        output = OptimizedResBlockDisc1(output)
        output = ResidualBlock('Discriminator.2', DIM_D, DIM_D, 3, output, resample='down', labels=labels)
        output = ResidualBlock('Discriminator.3', DIM_D, DIM_D, 3, output, resample=None, labels=labels)
        output = ResidualBlock('Discriminator.4', DIM_D, DIM_D, 3, output, resample=None, labels=labels)
        output = nonlinearity(output)
        output = tf.reduce_mean(output, axis=[2,3])
        output_wgan = lib.ops.linear.Linear('Discriminator.Output', DIM_D, 1, output)
        output_wgan = tf.reshape(output_wgan, [-1])
        return output_wgan

    # endregion

    with tf.Session() as session:

        _iteration_gan = tf.placeholder(tf.int32, shape=None)
        all_real_data_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM])
        all_real_labels = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

        labels_splits = tf.split(all_real_labels, len(DEVICES), axis=0)

        fake_data_splits = []
        for i, device in enumerate(DEVICES):
            with tf.device(device):
                fake_data_splits.append(Generator(BATCH_SIZE/len(DEVICES), labels_splits[i]))

        all_real_data = tf.reshape(2*((tf.cast(all_real_data_int, tf.float32)/256.)-.5), [BATCH_SIZE, OUTPUT_DIM])
        all_real_data += tf.random_uniform(shape=[BATCH_SIZE,OUTPUT_DIM],minval=0.,maxval=1./128) # dequantize
        all_real_data_splits = tf.split(all_real_data, len(DEVICES), axis=0)

        DEVICES_B = DEVICES[:len(DEVICES)/2]
        DEVICES_A = DEVICES[len(DEVICES)/2:]

        disc_costs = []     # total disc cost
        # for separate logging
        scorediff_acc = []
        gps_all_acc = []
        gps_pos_acc = []
        gps_neg_acc = []


        for i, device in enumerate(DEVICES_A):
            with tf.device(device):
                real_and_fake_data = tf.concat([
                    all_real_data_splits[i],
                    all_real_data_splits[len(DEVICES_A)+i],
                    fake_data_splits[i],
                    fake_data_splits[len(DEVICES_A)+i]
                ], axis=0)
                real_and_fake_labels = tf.concat([
                    labels_splits[i],
                    labels_splits[len(DEVICES_A)+i],
                    labels_splits[i],
                    labels_splits[len(DEVICES_A)+i]
                ], axis=0)

                features_all = Features(real_and_fake_data)
                disc_all = Discriminator(features_all, real_and_fake_labels)

                features_real = features_all[:BATCH_SIZE/len(DEVICES_A)]
                features_fake = features_all[BATCH_SIZE/len(DEVICES_A):]

                disc_real = disc_all[:BATCH_SIZE/len(DEVICES_A)]
                disc_fake = disc_all[BATCH_SIZE/len(DEVICES_A):]

                disc_costs.append(tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))
                scorediff_acc.append(tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))


        for i, device in enumerate(DEVICES_B):
            with tf.device(device):
                real_data = tf.concat([all_real_data_splits[i], all_real_data_splits[len(DEVICES_A)+i]], axis=0)
                fake_data = tf.concat([fake_data_splits[i], fake_data_splits[len(DEVICES_A)+i]], axis=0)

                features_real = tf.reshape(Features(real_data),(batch_size,-1))
                features_fake = tf.reshape(Features(fake_data),(batch_size,-1))

                labels = tf.concat([
                    labels_splits[i],
                    labels_splits[len(DEVICES_A)+i],
                ], axis=0)




                no_of_interpolation_points = 3

                interpolating_coefficients = np.random.uniform(
                    low=0.0,
                    high=1.0,
                    size=[BATCH_SIZE/len(DEVICES_A), no_of_interpolation_points]).astype('float32')

                feature_differences = features_fake - features_real

                norm_differences_features = tf.sqrt(
                    tf.reduce_sum(tf.square(feature_differences), axis=-1))

                penalty = tf.constant(0.0)
                gp_pos = 0.0


                for j in range(0, no_of_interpolation_points):

                    iteration_interpolating_coefficients = interpolating_coefficients[:, j].reshape(batch_size, 1)

                    #alpha = tf.random_uniform(
                    #    shape=[BATCH_SIZE/len(DEVICES_A),1],
                    #    minval=0.,
                    #    maxval=1.
                    #)


                    # differences = fake_data - real_data
                    interpolating_feature_points = features_real + tf.multiply(feature_differences, iteration_interpolating_coefficients)

                    # interpolates = real_data + (alpha*differences)

                    gradient_vectors = tf.gradients(Discriminator(interpolating_feature_points), [interpolating_feature_points])[0]
                    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradient_vectors), reduction_indices=[1]))

                    # gradients = tf.gradients(Discriminator(interpolates, labels)[0], [interpolates])[0]
                    # slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

                    penalty = penalty + (1.0 / no_of_interpolation_points) * LAMBDA * tf.reduce_mean(
                        norm_differences_features * tf.clip_by_value(slopes - 1., 0., np.infty))

                    gp_pos = penalty * 1.  #LAMBDA *tf.reduce_mean(tf.clip_by_value(slopes - 1., 0, np.infty)**2)
                    gp_neg = tf.constant(0.)


                gradient_penalty = penalty



                disc_costs.append(gradient_penalty)
                gps_all_acc.append(gradient_penalty)
                gps_pos_acc.append(gp_pos)
                gps_neg_acc.append(gp_neg)



        disc_wgan = tf.add_n(disc_costs) / len(DEVICES_A)
        # for more logging
        scorediff = tf.add_n(scorediff_acc) / len(DEVICES_A)
        gps_all = tf.add_n(gps_all_acc) / len(DEVICES_B)
        gps_pos = tf.add_n(gps_pos_acc) / len(DEVICES_B)
        gps_neg = tf.add_n(gps_neg_acc) / len(DEVICES_B)


        disc_cost = disc_wgan

        disc_params = lib.params_with_name('Discriminator.')

        if DECAY:
            decay = tf.maximum(0., 1.-(tf.cast(_iteration_gan, tf.float32)/ITERS))

        else:
            decay = 1.

        gen_costs = []

        for device in DEVICES:
            with tf.device(device):
                n_samples = GEN_BS_MULTIPLE * BATCH_SIZE / len(DEVICES)
                fake_labels = tf.cast(tf.random_uniform([n_samples])*10, tf.int32)
                gen_costs.append(-tf.reduce_mean(Discriminator( Features(Generator(n_samples, fake_labels)), fake_labels)))




        gen_cost = (tf.add_n(gen_costs) / len(DEVICES))

        gen_opt = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=0., beta2=0.9)
        disc_opt = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=0., beta2=0.9)
        gen_gv = gen_opt.compute_gradients(gen_cost, var_list=lib.params_with_name('Generator'))
        disc_gv = disc_opt.compute_gradients(disc_cost, var_list=disc_params)
        gen_train_op = gen_opt.apply_gradients(gen_gv)
        disc_train_op = disc_opt.apply_gradients(disc_gv)

        # Function for generating samples
        frame_i = [0]
        fixed_noise = tf.constant(np.random.normal(size=(100, 128)).astype('float32'))
        fixed_labels = tf.constant(np.array([0,1,2,3,4,5,6,7,8,9]*10,dtype='int32'))
        fixed_noise_samples = Generator(100, fixed_labels, noise=fixed_noise)

        def generate_image(frame, true_dist):
            samples = session.run(fixed_noise_samples)
            samples = ((samples+1.)*(255./2)).astype('int32')
            lib.save_images.save_images(samples.reshape((100, 3, 32, 32)), '{}/samples_{}.png'.format(log_dir, frame))

        # Function for calculating inception score
        fake_labels_100 = tf.cast(tf.random_uniform([100])*10, tf.int32)
        samples_100 = Generator(100, fake_labels_100)

        def get_IS_and_FID(n):
            all_samples = []
            for i in xrange(n/100):
                all_samples.append(session.run(samples_100))
            all_samples = np.concatenate(all_samples, axis=0)
            all_samples = ((all_samples+1.)*(255.99/2)).astype('int32')
            all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0,2,3,1)
            # print("getting IS and FID")
            # print(_iteration_gan)
            # embed()
            # with tf.Session() as _fid_session:
            #     _inception_score, _fid_score = fid.calc_IS_and_FID(all_samples, (mu_real, sigma_real), 100, verbose=True, session=_fid_session)
            # _inception, _inception_std = _inception_score
            _fid_score = 0
            _inception_score = lib.inception_score.get_inception_score(list(all_samples))
            print("calculated IS")
            # print(_inception_score, _inception_score_check)
            # embed()
            # assert(_inception_score[0] == _inception_score_check[0])
            # print("IS calculation same as old")
            if run_fid:
                mu_gen, sigma_gen = fid.calculate_activation_statistics(all_samples, session, 100, verbose=True)
                try:
                    _fid_score = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
                except Exception as e:
                    print(e)
                    _fid_score = 10e4
                print("calculated IS and FID")
                return _inception_score, _fid_score
            else:
                return _inception_score, 0

        train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, DATA_DIR)

        def inf_train_gen():
            while True:
                for images,_labels in train_gen():
                    yield images,_labels


        for name,grads_and_vars in [('G', gen_gv), ('D', disc_gv)]:
            print ("{} Params:".format(name))
            total_param_count = 0
            for g, v in grads_and_vars:
                shape = v.get_shape()
                shape_str = ",".join([str(x) for x in v.get_shape()])

                param_count = 1
                for dim in shape:
                    param_count *= int(dim)
                total_param_count += param_count

                if g == None:
                    print ("\t{} ({}) [no grad!]".format(v.name, shape_str))
                else:
                    print ("\t{} ({})".format(v.name, shape_str))

            print ("Total param count: {}".format(
                locale.format("%d", total_param_count, grouping=True)
            ))



        session.run(tf.initialize_all_variables())


        gen = inf_train_gen()


        for iteration in range(ITERS):
            # TRAINING

            start_time = time.time()

            # embed()

            _gen_cost = 0
            if iteration > 0:
                _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={_iteration_gan: iteration})

            for i in range(N_CRITIC):

                _data,_labels = gen.next()

                _disc_cost, _scorediff, _gps_all, _gps_pos, _gps_neg, _ = session.run(
                    [disc_cost,  scorediff,  gps_all,  gps_pos,  gps_neg,  disc_train_op],
                    feed_dict={all_real_data_int: _data, all_real_labels:_labels, _iteration_gan:iteration})


            lib.plot.plot('cost', _disc_cost)
            # extra plot
            lib.plot.plot('core', _scorediff)
            lib.plot.plot('gp', _gps_all)
            lib.plot.plot('gp_pos', _gps_pos)
            lib.plot.plot('gp_neg', _gps_neg)
            lib.plot.plot('gen_cost', _gen_cost)

            lib.plot.plot('time', time.time() - start_time)

            if iteration % INCEPTION_FREQUENCY == INCEPTION_FREQUENCY-1:
                inception_score, fid_score = get_IS_and_FID(50000)
                lib.plot.plot('inception', inception_score[0])
                lib.plot.plot('inception_std', inception_score[1])      # std of inception over 10 splits
                lib.plot.plot('fid', fid_score)

            # Calculate dev loss and generate samples every 100 iters
            # VALIDATION
            if iteration % 100 == 99:
                dev_disc_costs = []
                dev_scorediff = []
                dev_gp_all = []
                dev_gp_pos = []
                dev_gp_neg = []
                dev_gen_costs = []
                for images,_labels in dev_gen():
                    _dev_disc_cost, _dev_scorediff, _dev_gps_all, _dev_gps_pos, _dev_gps_neg, _dev_gen_cost \
                        = session.run(
                        [disc_cost,          scorediff,      gps_all,      gps_pos,      gps_neg,      gen_cost],
                        feed_dict={all_real_data_int: images,all_real_labels:_labels})

                    dev_disc_costs.append(_dev_disc_cost)
                    dev_scorediff.append(_dev_scorediff)
                    dev_gp_all.append(_dev_gps_all)
                    dev_gp_pos.append(_dev_gps_pos)
                    dev_gp_neg.append(_dev_gps_neg)
                    dev_gen_costs.append(_dev_gen_cost)

                lib.plot.plot('dev_cost', np.mean(dev_disc_costs))
                lib.plot.plot('dev_core', np.mean(dev_scorediff))
                lib.plot.plot('dev_gp', np.mean(dev_gp_all))
                lib.plot.plot('dev_gp_pos', np.mean(dev_gp_pos))
                lib.plot.plot('dev_gp_neg', np.mean(dev_gp_neg))
                lib.plot.plot('dev_gen_cost', np.mean(dev_gen_costs))

            if iteration % 500 == 0:
                generate_image(iteration, _data)

            if (iteration < 500) or (iteration % 1000 == 999):
                lib.plot.flush()

            lib.plot.tick()
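
The critic penalty built in the DEVICES_B loop of Example #6 averages, over the interpolation points, the amount by which the gradient norm exceeds 1, weighted by the real-fake feature distance and scaled by LAMBDA. A small NumPy recomputation of that arithmetic with made-up values:

# Illustrative recomputation of the penalty term from Example #6 (values are made up).
import numpy as np

LAMBDA = 2.0
no_of_interpolation_points = 3

norm_differences_features = np.array([1.0, 2.0, 0.5, 3.0])   # ||f_fake - f_real|| per sample
slopes = np.array([[0.8, 1.5, 1.0, 2.0],                     # gradient norms of D at each
                   [1.2, 0.9, 1.1, 1.5],                     # interpolated feature point,
                   [1.0, 1.3, 0.7, 1.8]])                    # shape (points, batch)

penalty = 0.0
for j in range(no_of_interpolation_points):
    excess = np.clip(slopes[j] - 1.0, 0.0, None)             # one-sided: only slopes > 1 are penalised
    penalty += (1.0 / no_of_interpolation_points) * LAMBDA * np.mean(
        norm_differences_features * excess)
print(penalty)                                               # ~1.458 for these numbers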
Code Example #7
def run(mode="wgan-gp",
        dim_g=128,
        dim_d=128,
        critic_iters=5,
        n_gpus=1,
        normalization_g=True,
        normalization_d=False,
        batch_size=64,
        iters=100000,
        penalty_weight=10,
        one_sided=False,
        output_dim=3072,
        lr=2e-4,
        data_dir='/srv/denis/tfvision-datasets/cifar-10-batches-py',
        inception_frequency=1000,
        conditional=False,
        acgan=False,
        log_dir='default_log',
        run_fid=False,
        penalty_mode="grad"):  # grad or pagan or ot

    if mode == "wgan-gp" and penalty_mode == "grad":
        print("WGAN-LP" if one_sided else "WGAN-GP")

    # region start
    # Download CIFAR-10 (Python version) at
    # https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
    # extracted files here!
    lib.plot.logdir = log_dir
    print("log dir set to {}".format(lib.plot.logdir))
    # dump locals() for settings
    with open("{}/settings.dict".format(log_dir), "w") as f:
        loca = locals().copy()
        if penalty_mode != "grad":
            del loca["one_sided"]
        f.write(str(loca))
        print("saved settings: {}".format(loca))

    DATA_DIR = data_dir
    if len(DATA_DIR) == 0:
        raise Exception(
            'Please specify path to data directory in gan_cifar_resnet.py!')

    N_GPUS = n_gpus
    if N_GPUS not in [1, 2]:
        raise Exception('Only 1 or 2 GPUs supported!')

    BATCH_SIZE = batch_size  # Critic batch size
    GEN_BS_MULTIPLE = 2  # Generator batch size, as a multiple of BATCH_SIZE
    ITERS = iters  # How many iterations to train for
    DIM_G = dim_g  # Generator dimensionality
    DIM_D = dim_d  # Critic dimensionality
    NORMALIZATION_G = normalization_g  # Use batchnorm in generator?
    NORMALIZATION_D = normalization_d  # Use batchnorm (or layernorm) in critic?
    OUTPUT_DIM = output_dim  # Number of pixels in CIFAR10 (32*32*3)
    LR = lr  # Initial learning rate
    DECAY = True  # Whether to decay LR over learning
    N_CRITIC = critic_iters  # Critic steps per generator steps
    INCEPTION_FREQUENCY = inception_frequency  # How frequently to calculate Inception score

    CONDITIONAL = conditional  # Whether to train a conditional or unconditional model
    ACGAN = acgan  # If CONDITIONAL, whether to use ACGAN or "vanilla" conditioning
    ACGAN_SCALE = 1.  # How to scale the critic's ACGAN loss relative to WGAN loss
    ACGAN_SCALE_G = 0.1  # How to scale generator's ACGAN loss relative to WGAN loss

    ONE_SIDED = one_sided
    LAMBDA = penalty_weight

    if CONDITIONAL and (not ACGAN) and (not NORMALIZATION_D):
        print(
            "WARNING! Conditional model without normalization in D might be effectively unconditional!"
        )

    DEVICES = ['/gpu:{}'.format(i) for i in range(N_GPUS)]
    if len(DEVICES) == 1:  # Hack because the code assumes 2 GPUs
        DEVICES = [DEVICES[0], DEVICES[0]]

    lib.print_model_settings(locals().copy())

    if run_fid:
        inception_path = "/tmp/inception"
        print("check for inception model..")
        inception_path = fid.check_or_download_inception(
            inception_path)  # download inception if necessary
        print("ok")

        print("create inception graph..")
        fid.create_inception_graph(
            inception_path)  # load the graph into the current TF graph
        print("ok")

        # region cifar FID
        if not os.path.exists(
                "cifar.fid.stats.npz"):  # compute fid stats for CIFAR
            train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, DATA_DIR)

            # loads all images into memory (this might require a lot of RAM!)
            print("load images..")
            images = []
            for imagebatch, _ in dev_gen():
                images.append(imagebatch)

            _use_train_for_stats = True
            if _use_train_for_stats:
                print("using train data for FID stats")
                for imagebatch, _ in train_gen():
                    images.append(imagebatch)

            allimages = np.concatenate(images, axis=0)
            # allimages = ((allimages+1.)*(255.99/2)).astype('int32')
            allimages = allimages.reshape(
                (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
            # images = list(allimages)
            images = allimages
            print("%d images found and loaded: {}" % len(images), images.shape)

            # embed()

            print("calculate FID stats..")
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                mu, sigma = fid.calculate_activation_statistics(images,
                                                                sess,
                                                                batch_size=100,
                                                                verbose=True)
                np.savez_compressed("cifar.fid.stats", mu=mu, sigma=sigma)
            print("finished")

            # embed()

        f = np.load("cifar.fid.stats.npz")
        mu_real, sigma_real = f['mu'][:], f['sigma'][:]

    else:
        print("  NOT RUNNING FID  !!! -- ()()()")

    # embed()

    import tflib.inception_score

    # endregion

    # region model

    def nonlinearity(x):
        return tf.nn.relu(x)

    def Normalize(name, inputs, labels=None):
        """This is messy, but basically it chooses between batchnorm, layernorm,
        their conditional variants, or nothing, depending on the value of `name` and
        the global hyperparam flags."""
        if not CONDITIONAL:
            labels = None
        if CONDITIONAL and ACGAN and ('Discriminator' in name):
            labels = None

        if ('Discriminator' in name) and NORMALIZATION_D:
            assert (False)
            return lib.ops.layernorm.Layernorm(name, [1, 2, 3],
                                               inputs,
                                               labels=labels,
                                               n_labels=10)
        elif ('Generator' in name) and NORMALIZATION_G:
            if labels is not None:
                assert (False)
                return lib.ops.cond_batchnorm.Batchnorm(name, [0, 2, 3],
                                                        inputs,
                                                        labels=labels,
                                                        n_labels=10)
            else:
                return lib.ops.batchnorm.Batchnorm(name, [0, 2, 3],
                                                   inputs,
                                                   fused=True)
        else:
            return inputs

    def ConvMeanPool(name,
                     input_dim,
                     output_dim,
                     filter_size,
                     inputs,
                     he_init=True,
                     biases=True):
        output = lib.ops.conv2d.Conv2D(name,
                                       input_dim,
                                       output_dim,
                                       filter_size,
                                       inputs,
                                       he_init=he_init,
                                       biases=biases)
        output = tf.add_n([
            output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
            output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]
        ]) / 4.
        return output

    def MeanPoolConv(name,
                     input_dim,
                     output_dim,
                     filter_size,
                     inputs,
                     he_init=True,
                     biases=True):
        output = inputs
        output = tf.add_n([
            output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
            output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]
        ]) / 4.
        output = lib.ops.conv2d.Conv2D(name,
                                       input_dim,
                                       output_dim,
                                       filter_size,
                                       output,
                                       he_init=he_init,
                                       biases=biases)
        return output

    def UpsampleConv(name,
                     input_dim,
                     output_dim,
                     filter_size,
                     inputs,
                     he_init=True,
                     biases=True):
        output = inputs
        output = tf.concat([output, output, output, output], axis=1)
        output = tf.transpose(output, [0, 2, 3, 1])
        output = tf.depth_to_space(output, 2)
        output = tf.transpose(output, [0, 3, 1, 2])
        output = lib.ops.conv2d.Conv2D(name,
                                       input_dim,
                                       output_dim,
                                       filter_size,
                                       output,
                                       he_init=he_init,
                                       biases=biases)
        return output

    def ResidualBlock(name,
                      input_dim,
                      output_dim,
                      filter_size,
                      inputs,
                      resample=None,
                      no_dropout=False,
                      labels=None):
        """
        resample: None, 'down', or 'up'
        """
        if resample == 'down':
            conv_1 = functools.partial(lib.ops.conv2d.Conv2D,
                                       input_dim=input_dim,
                                       output_dim=input_dim)
            conv_2 = functools.partial(ConvMeanPool,
                                       input_dim=input_dim,
                                       output_dim=output_dim)
            conv_shortcut = ConvMeanPool
        elif resample == 'up':
            conv_1 = functools.partial(UpsampleConv,
                                       input_dim=input_dim,
                                       output_dim=output_dim)
            conv_shortcut = UpsampleConv
            conv_2 = functools.partial(lib.ops.conv2d.Conv2D,
                                       input_dim=output_dim,
                                       output_dim=output_dim)
        elif resample == None:
            conv_shortcut = lib.ops.conv2d.Conv2D
            conv_1 = functools.partial(lib.ops.conv2d.Conv2D,
                                       input_dim=input_dim,
                                       output_dim=output_dim)
            conv_2 = functools.partial(lib.ops.conv2d.Conv2D,
                                       input_dim=output_dim,
                                       output_dim=output_dim)
        else:
            raise Exception('invalid resample value')

        if output_dim == input_dim and resample == None:
            shortcut = inputs  # Identity skip-connection
        else:
            shortcut = conv_shortcut(name + '.Shortcut',
                                     input_dim=input_dim,
                                     output_dim=output_dim,
                                     filter_size=1,
                                     he_init=False,
                                     biases=True,
                                     inputs=inputs)

        output = inputs
        output = Normalize(name + '.N1', output, labels=labels)
        output = nonlinearity(output)
        output = conv_1(name + '.Conv1',
                        filter_size=filter_size,
                        inputs=output)
        output = Normalize(name + '.N2', output, labels=labels)
        output = nonlinearity(output)
        output = conv_2(name + '.Conv2',
                        filter_size=filter_size,
                        inputs=output)

        return shortcut + output

    def OptimizedResBlockDisc1(inputs):
        conv_1 = functools.partial(lib.ops.conv2d.Conv2D,
                                   input_dim=3,
                                   output_dim=DIM_D)
        conv_2 = functools.partial(ConvMeanPool,
                                   input_dim=DIM_D,
                                   output_dim=DIM_D)
        conv_shortcut = MeanPoolConv
        shortcut = conv_shortcut('Discriminator.1.Shortcut',
                                 input_dim=3,
                                 output_dim=DIM_D,
                                 filter_size=1,
                                 he_init=False,
                                 biases=True,
                                 inputs=inputs)

        output = inputs
        output = conv_1('Discriminator.1.Conv1', filter_size=3, inputs=output)
        output = nonlinearity(output)
        output = conv_2('Discriminator.1.Conv2', filter_size=3, inputs=output)
        return shortcut + output

    def Generator(n_samples, labels, noise=None):
        if noise is None:
            noise = tf.random_normal([n_samples, 128])
        output = lib.ops.linear.Linear('Generator.Input', 128, 4 * 4 * DIM_G,
                                       noise)
        output = tf.reshape(output, [-1, DIM_G, 4, 4])
        output = ResidualBlock('Generator.1',
                               DIM_G,
                               DIM_G,
                               3,
                               output,
                               resample='up',
                               labels=labels)
        output = ResidualBlock('Generator.2',
                               DIM_G,
                               DIM_G,
                               3,
                               output,
                               resample='up',
                               labels=labels)
        output = ResidualBlock('Generator.3',
                               DIM_G,
                               DIM_G,
                               3,
                               output,
                               resample='up',
                               labels=labels)
        output = Normalize('Generator.OutputN', output)
        output = nonlinearity(output)
        output = lib.ops.conv2d.Conv2D('Generator.Output',
                                       DIM_G,
                                       3,
                                       3,
                                       output,
                                       he_init=False)
        output = tf.tanh(output)
        return tf.reshape(output, [-1, OUTPUT_DIM])

    def Discriminator(inputs, labels):
        output = tf.reshape(inputs, [-1, 3, 32, 32])
        output = OptimizedResBlockDisc1(output)
        output = ResidualBlock('Discriminator.2',
                               DIM_D,
                               DIM_D,
                               3,
                               output,
                               resample='down',
                               labels=labels)
        output = ResidualBlock('Discriminator.3',
                               DIM_D,
                               DIM_D,
                               3,
                               output,
                               resample=None,
                               labels=labels)
        output = ResidualBlock('Discriminator.4',
                               DIM_D,
                               DIM_D,
                               3,
                               output,
                               resample=None,
                               labels=labels)
        output = nonlinearity(output)
        output = tf.reduce_mean(output, axis=[2, 3])
        output_wgan = lib.ops.linear.Linear('Discriminator.Output', DIM_D, 1,
                                            output)
        output_wgan = tf.reshape(output_wgan, [-1])
        if CONDITIONAL and ACGAN:
            output_acgan = lib.ops.linear.Linear('Discriminator.ACGANOutput',
                                                 DIM_D, 10, output)
            return output_wgan, output_acgan
        else:
            return output_wgan, None

    # endregion

    with tf.Session() as session:

        # discriminator = Disc(DIM_D).get_model(K.layers.Input(shape=(OUTPUT_DIM,)))
        # generator = Gen(DIM_G).get_model()

        K.backend.set_learning_phase(True)

        discriminator = Disc(DIM_D, normalize=normalization_d)
        generator = Gen(DIM_G, outdim=OUTPUT_DIM, normalize=normalization_g)
        #
        # discriminator = Discriminator
        # generator = Generator

        _iteration_gan = tf.placeholder(tf.int32, shape=None)
        all_real_data_int = tf.placeholder(tf.int32,
                                           shape=[BATCH_SIZE, OUTPUT_DIM])
        all_real_labels = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

        labels_splits = tf.split(all_real_labels, len(DEVICES), axis=0)

        fake_data_splits = []
        for i, device in enumerate(DEVICES):
            with tf.device(device):
                fake_data_splits.append(
                    generator(BATCH_SIZE / len(DEVICES), labels_splits[i]))

        all_real_data = tf.reshape(
            2 * ((tf.cast(all_real_data_int, tf.float32) / 256.) - .5),
            [BATCH_SIZE, OUTPUT_DIM])
        all_real_data += tf.random_uniform(shape=[BATCH_SIZE, OUTPUT_DIM],
                                           minval=0.,
                                           maxval=1. / 128)  # dequantize
        all_real_data_splits = tf.split(all_real_data, len(DEVICES), axis=0)

        DEVICES_B = DEVICES[:len(DEVICES) / 2]
        DEVICES_A = DEVICES[len(DEVICES) / 2:]

        disc_costs = []  # total disc cost
        # for separate logging
        scorediff_acc = []
        gps_all_acc = []
        gps_pos_acc = []
        gps_neg_acc = []

        disc_acgan_costs = []
        disc_acgan_accs = []
        disc_acgan_fake_accs = []
        for i, device in enumerate(DEVICES_A):
            with tf.device(device):
                real_and_fake_data = tf.concat([
                    all_real_data_splits[i],
                    all_real_data_splits[len(DEVICES_A) + i],
                    fake_data_splits[i], fake_data_splits[len(DEVICES_A) + i]
                ],
                                               axis=0)
                real_and_fake_labels = tf.concat([
                    labels_splits[i], labels_splits[len(DEVICES_A) + i],
                    labels_splits[i], labels_splits[len(DEVICES_A) + i]
                ],
                                                 axis=0)
                disc_all, disc_all_acgan = discriminator(
                    real_and_fake_data, real_and_fake_labels)
                disc_real = disc_all[:BATCH_SIZE / len(DEVICES_A)]
                disc_fake = disc_all[BATCH_SIZE / len(DEVICES_A):]
                disc_costs.append(
                    tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))
                scorediff_acc.append(
                    tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))

                if CONDITIONAL and ACGAN:
                    disc_acgan_costs.append(
                        tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=disc_all_acgan[:BATCH_SIZE /
                                                      len(DEVICES_A)],
                                labels=real_and_fake_labels[:BATCH_SIZE /
                                                            len(DEVICES_A)])))
                    disc_acgan_accs.append(
                        tf.reduce_mean(
                            tf.cast(
                                tf.equal(
                                    tf.to_int32(
                                        tf.argmax(
                                            disc_all_acgan[:BATCH_SIZE /
                                                           len(DEVICES_A)],
                                            dimension=1)),
                                    real_and_fake_labels[:BATCH_SIZE /
                                                         len(DEVICES_A)]),
                                tf.float32)))
                    disc_acgan_fake_accs.append(
                        tf.reduce_mean(
                            tf.cast(
                                tf.equal(
                                    tf.to_int32(
                                        tf.argmax(
                                            disc_all_acgan[BATCH_SIZE /
                                                           len(DEVICES_A):],
                                            dimension=1)),
                                    real_and_fake_labels[BATCH_SIZE /
                                                         len(DEVICES_A):]),
                                tf.float32)))

        for i, device in enumerate(DEVICES_B):
            with tf.device(device):
                real_data = tf.concat([
                    all_real_data_splits[i],
                    all_real_data_splits[len(DEVICES_A) + i]
                ],
                                      axis=0)
                fake_data = tf.concat([
                    fake_data_splits[i], fake_data_splits[len(DEVICES_A) + i]
                ],
                                      axis=0)
                labels = tf.concat([
                    labels_splits[i],
                    labels_splits[len(DEVICES_A) + i],
                ],
                                   axis=0)
                if penalty_mode == "grad":
                    alpha = tf.random_uniform(
                        shape=[BATCH_SIZE / len(DEVICES_A), 1],
                        minval=0.,
                        maxval=1.)
                    differences = fake_data - real_data
                    interpolates = real_data + (alpha * differences)
                    gradients = tf.gradients(
                        discriminator(interpolates, labels)[0],
                        [interpolates])[0]
                    slopes = tf.sqrt(
                        tf.reduce_sum(tf.square(gradients),
                                      reduction_indices=[1]))
                    if ONE_SIDED is True:
                        gradient_penalty = LAMBDA * tf.reduce_mean(
                            tf.clip_by_value(slopes - 1., 0, np.infty)**2)
                        gp_pos = gradient_penalty * 1.  #LAMBDA *tf.reduce_mean(tf.clip_by_value(slopes - 1., 0, np.infty)**2)
                        gp_neg = tf.constant(0.)
                    else:
                        gradient_penalty = LAMBDA * tf.reduce_mean(
                            (slopes - 1.)**2)
                        gp_pos = LAMBDA * tf.reduce_mean(
                            tf.clip_by_value(slopes - 1., 0, np.infty)**2)
                        gp_neg = LAMBDA * tf.reduce_mean(
                            tf.clip_by_value(slopes - 1., -np.infty, 0)**2)
                elif penalty_mode == "pagan" or penalty_mode == "ot":
                    _EPS = 1e-6
                    _INTERP = False
                    print("SHAPES OF REAL AND FAKE DATA FOR PAGAN AND OT: ",
                          real_data.get_shape(), fake_data.get_shape())
                    print("TYPES: ", type(real_data), type(fake_data))
                    if _INTERP:  # ONLY EXP_PAGANY_1 is _INTERP !!! ???
                        _alpha = tf.random_uniform(
                            shape=[BATCH_SIZE / len(DEVICES_A), 1],
                            minval=0.,
                            maxval=1.)
                        _alpha2 = tf.random_uniform(
                            shape=[BATCH_SIZE / len(DEVICES_A), 1],
                            minval=0.,
                            maxval=1.)
                        alpha = tf.minimum(_alpha, _alpha2)
                        alpha2 = tf.maximum(_alpha, _alpha2)
                        differences = fake_data - real_data
                        interp_real = real_data + (alpha * differences)  # points closer to real
                        interp_fake = real_data + (alpha2 * differences)  # points closer to fake
                    else:
                        interp_real = real_data
                        interp_fake = fake_data
                    _D_real, _ = discriminator(
                        interp_real,
                        labels)  # TODO: check that the labels have no influence here
                    _D_fake, _ = discriminator(interp_fake, labels)
                    real_fake_dist = tf.norm(
                        interp_real - interp_fake, ord=2, axis=1
                    )  # data are vectors: (64, 3072) from get_shape()
                    print("SCORES AND DIST SHAPES: ", _D_real.get_shape(),
                          _D_fake.get_shape(), real_fake_dist.get_shape())
                    print("TYPES: ", type(_D_real), type(_D_fake),
                          type(real_fake_dist))

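                    # PAGAN penalizes pairs whose ratio |D(x) - D(y)| / ||x - y|| exceeds 1;
                    # OT penalizes D(x_real) - D(x_fake) exceeding ||x_real - x_fake||.
                    # Both are clipped at zero, squared, and averaged into the penalty below.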
                    if penalty_mode == "pagan":
                        penalty_vecs = tf.clip_by_value(
                            (tf.abs(_D_real - _D_fake) / tf.clip_by_value(
                                real_fake_dist, _EPS, np.infty)) - 1, 0,
                            np.infty)**2
                    elif penalty_mode == "ot":
                        penalty_vecs = tf.clip_by_value(
                            _D_real - _D_fake - real_fake_dist, 0, np.infty)**2
                    print("PENALTY VEC SHAPE: ", penalty_vecs.get_shape())
                    gradient_penalty = LAMBDA * tf.reduce_mean(penalty_vecs)
                    gp_pos = tf.constant(0.)
                    gp_neg = tf.constant(0.)
                else:
                    raise Exception(
                        "unknown penalty mode {}".format(penalty_mode))
                disc_costs.append(gradient_penalty)
                gps_all_acc.append(gradient_penalty)
                gps_pos_acc.append(gp_pos)
                gps_neg_acc.append(gp_neg)

        disc_wgan = tf.add_n(disc_costs) / len(DEVICES_A)
        # for more logging
        scorediff = tf.add_n(scorediff_acc) / len(DEVICES_A)
        gps_all = tf.add_n(gps_all_acc) / len(DEVICES_B)
        gps_pos = tf.add_n(gps_pos_acc) / len(DEVICES_B)
        gps_neg = tf.add_n(gps_neg_acc) / len(DEVICES_B)

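        # Total critic cost: averaged WGAN terms and penalties, plus (when conditional)
        # the ACGAN classification loss scaled by ACGAN_SCALE.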
        if CONDITIONAL and ACGAN:
            disc_acgan = tf.add_n(disc_acgan_costs) / len(DEVICES_A)
            disc_acgan_acc = tf.add_n(disc_acgan_accs) / len(DEVICES_A)
            disc_acgan_fake_acc = tf.add_n(disc_acgan_fake_accs) / len(
                DEVICES_A)
            disc_cost = disc_wgan + (ACGAN_SCALE * disc_acgan)
        else:
            disc_acgan = tf.constant(0.)
            disc_acgan_acc = tf.constant(0.)
            disc_acgan_fake_acc = tf.constant(0.)
            disc_cost = disc_wgan

        if hasattr(discriminator, "trainable_weights"):
            disc_params = discriminator.trainable_weights  # keras mode
        else:
            disc_params = lib.params_with_name('Discriminator.')  # original

        disc_update_ops = []
        if hasattr(discriminator, "updates"):
            disc_update_ops = discriminator.updates

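        # When DECAY is set, both Adam optimizers below use a linearly decayed
        # learning rate: LR * max(0, 1 - iteration / ITERS).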
        if DECAY:
            decay = tf.maximum(
                0., 1. - (tf.cast(_iteration_gan, tf.float32) / ITERS))
        else:
            decay = 1.

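        # Generator objective: -E[D(G(z))] per device, plus (when CONDITIONAL and ACGAN)
        # a sparse softmax cross-entropy term on the fake labels, scaled by ACGAN_SCALE_G.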
        gen_costs = []
        gen_acgan_costs = []
        for device in DEVICES:
            with tf.device(device):
                n_samples = GEN_BS_MULTIPLE * BATCH_SIZE / len(DEVICES)
                fake_labels = tf.cast(
                    tf.random_uniform([n_samples]) * 10, tf.int32)
                if CONDITIONAL and ACGAN:
                    disc_fake, disc_fake_acgan = discriminator(
                        generator(n_samples, fake_labels), fake_labels)
                    gen_costs.append(-tf.reduce_mean(disc_fake))
                    gen_acgan_costs.append(
                        tf.reduce_mean(
                            tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=disc_fake_acgan, labels=fake_labels)))
                else:
                    gen_costs.append(-tf.reduce_mean(
                        discriminator(generator(n_samples, fake_labels),
                                      fake_labels)[0]))
        gen_cost = (tf.add_n(gen_costs) / len(DEVICES))
        if CONDITIONAL and ACGAN:
            gen_cost += (ACGAN_SCALE_G *
                         (tf.add_n(gen_acgan_costs) / len(DEVICES)))

        if hasattr(generator, "trainable_weights"):
            gen_params = generator.trainable_weights  # keras mode
        else:
            gen_params = lib.params_with_name('Generator')  # original

        gen_update_ops = []
        if hasattr(generator, "updates"):
            gen_update_ops = generator.updates

        with tf.control_dependencies(gen_update_ops):
            gen_opt = tf.train.AdamOptimizer(learning_rate=LR * decay,
                                             beta1=0.,
                                             beta2=0.9)
            gen_gv = gen_opt.compute_gradients(gen_cost, var_list=gen_params)
            gen_train_op = gen_opt.apply_gradients(gen_gv)

        with tf.control_dependencies(disc_update_ops):
            disc_opt = tf.train.AdamOptimizer(learning_rate=LR * decay,
                                              beta1=0.,
                                              beta2=0.9)
            disc_gv = disc_opt.compute_gradients(disc_cost,
                                                 var_list=disc_params)
            disc_train_op = disc_opt.apply_gradients(disc_gv)

        gen_train_ops = [gen_train_op]  # + gen_update_ops
        disc_train_ops = [disc_train_op]  # + disc_update_ops

        # K.backend.set_learning_phase(False)

        # Function for generating samples
        frame_i = [0]
        fixed_noise = tf.constant(
            np.random.normal(size=(100, 128)).astype('float32'))
        fixed_labels = tf.constant(
            np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 10, dtype='int32'))
        fixed_noise_samples = generator(100, fixed_labels, noise=fixed_noise)

        def generate_image(frame, true_dist):
            samples = session.run(fixed_noise_samples)
            samples = ((samples + 1.) * (255. / 2)).astype('int32')
            lib.save_images.save_images(
                samples.reshape((100, 3, 32, 32)),
                '{}/samples_{}.png'.format(log_dir, frame))

        # Function for calculating inception score
        fake_labels_100 = tf.cast(tf.random_uniform([100]) * 10, tf.int32)
        samples_100 = generator(100, fake_labels_100)

        def get_IS_and_FID(n):
            all_samples = []
            for i in range(n // 100):
                all_samples.append(session.run(samples_100))
            all_samples = np.concatenate(all_samples, axis=0)
            all_samples = ((all_samples + 1.) * (255.99 / 2)).astype('int32')
            all_samples = all_samples.reshape(
                (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
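            # Samples are rescaled from [-1, 1] to [0, 255] int32 and transposed from
            # NCHW to NHWC before being fed to the Inception-based IS/FID routines.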
            # print("getting IS and FID")
            # print(_iteration_gan)
            # embed()
            # with tf.Session() as _fid_session:
            #     _inception_score, _fid_score = fid.calc_IS_and_FID(all_samples, (mu_real, sigma_real), 100, verbose=True, session=_fid_session)
            # _inception, _inception_std = _inception_score
            _fid_score = 0
            _inception_score = lib.inception_score.get_inception_score(
                list(all_samples))
            print("calculated IS")
            # print(_inception_score, _inception_score_check)
            # embed()
            # assert(_inception_score[0] == _inception_score_check[0])
            # print("IS calculation same as old")
            if run_fid:
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    all_samples, session, 100, verbose=True)
                try:
                    _fid_score = fid.calculate_frechet_distance(
                        mu_gen, sigma_gen, mu_real, sigma_real)
                except Exception as e:
                    print(e)
                    _fid_score = 10e4
                print("calculated IS and FID")
                return _inception_score, _fid_score
            else:
                return _inception_score, 0

        train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, DATA_DIR)

        def inf_train_gen():
            while True:
                for images, _labels in train_gen():
                    yield images, _labels

        for name, grads_and_vars in [('G', gen_gv), ('D', disc_gv)]:
            print("{} Params:".format(name))
            total_param_count = 0
            for g, v in grads_and_vars:
                shape = v.get_shape()
                shape_str = ",".join([str(x) for x in v.get_shape()])

                param_count = 1
                for dim in shape:
                    param_count *= int(dim)
                total_param_count += param_count

                if g is None:
                    print("\t{} ({}) [no grad!]".format(v.name, shape_str))
                else:
                    print("\t{} ({})".format(v.name, shape_str))
            print("Total param count: {}".format(
                locale.format("%d", total_param_count, grouping=True)))

        session.run(tf.global_variables_initializer())

        gen = inf_train_gen()

        for iteration in range(ITERS):
            # TRAINING
            start_time = time.time()

            # embed()

            _gen_cost = 0
            if iteration > 0:
                _gen_cost, _ = session.run(
                    [gen_cost, gen_train_ops],
                    feed_dict={_iteration_gan: iteration})

            for i in range(N_CRITIC):
                _data, _labels = next(gen)
                if CONDITIONAL and ACGAN:
                    assert False  # the conditional ACGAN training path is disabled in this script
                    _disc_cost, _disc_wgan, _disc_acgan, _disc_acgan_acc, _disc_acgan_fake_acc, _ = session.run(
                        [disc_cost, disc_wgan, disc_acgan, disc_acgan_acc, disc_acgan_fake_acc, disc_train_op],
                        feed_dict={all_real_data_int: _data, all_real_labels: _labels, _iteration_gan: iteration})
                else:
                    _disc_cost, _scorediff, _gps_all, _gps_pos, _gps_neg, _ = session.run(
                        [disc_cost, scorediff, gps_all, gps_pos, gps_neg, disc_train_ops],
                        feed_dict={all_real_data_int: _data, all_real_labels: _labels, _iteration_gan: iteration})

            lib.plot.plot('cost', _disc_cost)
            # extra plot
            lib.plot.plot('core', _scorediff)
            lib.plot.plot('gp', _gps_all)
            lib.plot.plot('gp_pos', _gps_pos)
            lib.plot.plot('gp_neg', _gps_neg)
            lib.plot.plot('gen_cost', _gen_cost)
            if CONDITIONAL and ACGAN:
                lib.plot.plot('wgan', _disc_wgan)
                lib.plot.plot('acgan', _disc_acgan)
                lib.plot.plot('acc_real', _disc_acgan_acc)
                lib.plot.plot('acc_fake', _disc_acgan_fake_acc)
            lib.plot.plot('time', time.time() - start_time)

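            # Every INCEPTION_FREQUENCY iterations, score 50,000 generated samples
            # (Inception Score always; FID only when run_fid is enabled).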
            if iteration % INCEPTION_FREQUENCY == INCEPTION_FREQUENCY - 1:
                inception_score, fid_score = get_IS_and_FID(50000)
                lib.plot.plot('inception', inception_score[0])
                lib.plot.plot(
                    'inception_std',
                    inception_score[1])  # std of inception over 10 splits
                lib.plot.plot('fid', fid_score)

            # Calculate dev loss and generate samples every 100 iters
            # VALIDATION
            if iteration % 100 == 99:
                dev_disc_costs = []
                dev_scorediff = []
                dev_gp_all = []
                dev_gp_pos = []
                dev_gp_neg = []
                dev_gen_costs = []
                for images, _labels in dev_gen():
                    _dev_disc_cost, _dev_scorediff, _dev_gps_all, _dev_gps_pos, _dev_gps_neg, _dev_gen_cost = session.run(
                        [disc_cost, scorediff, gps_all, gps_pos, gps_neg, gen_cost],
                        feed_dict={all_real_data_int: images, all_real_labels: _labels})

                    dev_disc_costs.append(_dev_disc_cost)
                    dev_scorediff.append(_dev_scorediff)
                    dev_gp_all.append(_dev_gps_all)
                    dev_gp_pos.append(_dev_gps_pos)
                    dev_gp_neg.append(_dev_gps_neg)
                    dev_gen_costs.append(_dev_gen_cost)

                lib.plot.plot('dev_cost', np.mean(dev_disc_costs))
                lib.plot.plot('dev_core', np.mean(dev_scorediff))
                lib.plot.plot('dev_gp', np.mean(dev_gp_all))
                lib.plot.plot('dev_gp_pos', np.mean(dev_gp_pos))
                lib.plot.plot('dev_gp_neg', np.mean(dev_gp_neg))
                lib.plot.plot('dev_gen_cost', np.mean(dev_gen_costs))

            if iteration % 500 == 0:
                generate_image(iteration, _data)

            if (iteration < 500) or (iteration % 1000 == 999):
                lib.plot.flush()

            lib.plot.tick()
Code example #8
File: gan_eval.py Project: arthurmensch/dseg
def score_checkpoint_gidel(input_path):
    CUDA = True
    BATCH_SIZE = 32
    N_CHANNEL = 3
    RESOLUTION = 32
    NUM_SAMPLES = 50000
    DEVICE = 'cpu'

    checkpoint = torch.load(input_path, map_location=DEVICE)

    N_LATENT = 128
    N_FILTERS_G = 128
    BATCH_NORM_G = True

    print "Init..."
    import tflib.fid as fid
    import tflib.inception_score as inception_score
    import tensorflow as tf
    stats_path = 'tflib/data/fid_stats_cifar10_train.npz'
    inception_path = fid.check_or_download_inception('tflib/model')
    f = np.load(stats_path)
    mu_real, sigma_real = f['mu'][:], f['sigma'][:]
    f.close()

    fid.create_inception_graph(
        inception_path)  # load the graph into the current TF graph

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    def get_fid_score():
        all_samples = []
        samples = torch.randn(NUM_SAMPLES, N_LATENT)
        for i in range(0, NUM_SAMPLES, BATCH_SIZE):
            samples_100 = samples[i:i + BATCH_SIZE]
            if CUDA:
                samples_100 = samples_100.cuda(0)
            all_samples.append(gen(samples_100).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        all_samples = np.multiply(np.add(np.multiply(all_samples, 0.5), 0.5),
                                  255).astype('int32')
        all_samples = all_samples.reshape(
            (-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            mu_gen, sigma_gen = fid.calculate_activation_statistics(
                all_samples, sess, batch_size=BATCH_SIZE)

        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real,
                                                   sigma_real)
        return fid_value

    def get_inception_score():
        all_samples = []
        samples = torch.randn(NUM_SAMPLES, 128)
        for i in range(0, NUM_SAMPLES, BATCH_SIZE):
            samples_100 = samples[i:i + BATCH_SIZE]
            if CUDA:
                samples_100 = samples_100.cuda(0)
            all_samples.append(gen(samples_100).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        all_samples = np.multiply(np.add(np.multiply(all_samples, 0.5), 0.5),
                                  255).astype('int32')
        all_samples = all_samples.reshape(
            (-1, N_CHANNEL, RESOLUTION, RESOLUTION)).transpose(0, 2, 3, 1)
        return inception_score.get_inception_score(list(all_samples))

    gen = ResNet32Generator(N_LATENT, N_CHANNEL, N_FILTERS_G, BATCH_NORM_G)

    print "Eval..."
    gen.load_state_dict(checkpoint['avg_net_G'], strict=False)
    if CUDA:
        gen.cuda(0)
    is_, std = get_inception_score()
    fid_value = get_fid_score()  # avoid shadowing the tflib.fid module imported above
    return is_, std, fid_value
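
For reference, a minimal driver for score_checkpoint_gidel might look like the sketch below. It is only an illustration: the checkpoint path is a placeholder, and it assumes the module-level imports of the original gan_eval.py (torch, numpy, ResNet32Generator, tflib) are available.

# Hypothetical usage sketch; 'checkpoints/avg_net_G.pth' is a placeholder path.
if __name__ == '__main__':
    is_mean, is_std, fid_value = score_checkpoint_gidel('checkpoints/avg_net_G.pth')
    print('IS: {:.3f} +/- {:.3f}, FID: {:.3f}'.format(is_mean, is_std, fid_value))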