Example #1
0
    def __init__(
            self,
            classes,
            target_class_id,
            tag,
            # Set dratio_mode, and gratio_mode to 'rebalance' to bias the sampling toward the minority class
            # No relevant difference noted
            dratio_mode="uniform",
            gratio_mode="uniform",
            adam_lr=0.00005,
            latent_size=100,
            res_dir="./res-tmp",
            image_shape=[3, 32, 32],
            min_latent_res=8,
            N=10,
            batch_size=100,
            sigma=5,
            beta=5):
        """Build and compile every sub-model of the GAN.

        Constructs the Gaussian latent mapper (``self.final_latent``), the
        generator and the classifier, then chains
        latent -> generator -> classifier into ``self.combined`` for
        end-to-end generator training, with an extra ``R_loss``
        regularization term added on the latent (mu, sigma) weights.

        Args:
            classes: iterable of class identifiers; only ``len(classes)``
                is used here.
            target_class_id: id of the class of interest (stored; used by
                other methods for checkpoint naming — TODO confirm).
            tag: experiment tag forwarded to ``ResultLogger``.
            dratio_mode, gratio_mode: sampling-ratio modes; stored only.
            adam_lr: learning rate shared by every Adam optimizer below.
            latent_size: dimensionality of the latent z vectors.
            res_dir: output directory for results and backups.
            image_shape: [channels, height, width] (channels-first); only
                square images are supported.
                NOTE(review): mutable list default — harmless here since it
                is only read, never mutated.
            min_latent_res: initial generator resolution / final encoder
                resolution.
            N: class count forwarded to ``build_generator`` — note the other
                builders use ``self.nclasses`` instead; confirm callers keep
                these consistent.
            batch_size: fixed batch size baked into the ``Input`` tensors.
            sigma: spread parameter forwarded to the latent layer.
            beta: target value the R regularization tends to.
        """
        self.beta = beta
        self.gratio_mode = gratio_mode
        self.dratio_mode = dratio_mode
        self.classes = classes
        self.target_class_id = target_class_id
        self.nclasses = len(classes)
        self.latent_size = latent_size
        self.res_dir = res_dir
        self.channels = image_shape[0]
        self.resolution = image_shape[1]
        if self.resolution != image_shape[2]:
            print(
                "Error: only squared images currently supported by balancingGAN"
            )
            exit(1)

        # self.min_latent_res = min_latent_res

        # Initialize learning variables
        self.adam_lr = adam_lr
        self.adam_beta_1 = 0.5

        # Initialize stats
        self.train_history = defaultdict(list)
        self.test_history = defaultdict(list)
        self.trained = False

        # Build final_latent
        self.build_latent_use_Gaussian(latent_size=latent_size,
                                       N=self.nclasses,
                                       batch_size=batch_size,
                                       sigma=sigma)
        self.final_latent.compile(optimizer=Adam(lr=self.adam_lr,
                                                 beta_1=self.adam_beta_1),
                                  loss='sparse_categorical_crossentropy')

        #Build generator
        self.build_generator(latent_size, N, init_resolution=min_latent_res)
        self.generator.compile(optimizer=Adam(lr=self.adam_lr,
                                              beta_1=self.adam_beta_1),
                               loss='sparse_categorical_crossentropy')

        # Build classifier
        self.build_classifier(min_latent_res=min_latent_res)
        self.classifier.compile(optimizer=Adam(lr=self.adam_lr,
                                               beta_1=self.adam_beta_1),
                                loss='sparse_categorical_crossentropy')
        # Define combined for training generator with final_latent.
        # The classifier is frozen here so only the latent mapper and the
        # generator receive gradients through `combined`.
        latent_gen = Input(batch_shape=(batch_size, latent_size))
        # class_label = Input(batch_shape=(batch_size,N))
        z_withclass = self.final_latent(latent_gen)
        fake = self.generator(z_withclass)
        aux = self.classifier(fake)

        self.final_latent.trainable = True
        self.classifier.trainable = False
        self.generator.trainable = True

        self.combined = Model(inputs=latent_gen, outputs=aux)
        # loss of regularization
        # NOTE(review): layer index 1 is assumed to be the final_latent
        # sub-model whose weights are (mu, sigma) — confirm layer ordering.
        weights = self.combined.get_layer(index=1).weights
        mu_t = weights[0]
        sigma_t = weights[1]
        # Number of unordered class pairs, n*(n-1)/2.
        number = (self.nclasses * (self.nclasses - 1) / 2.0)
        # the value that R tends to
        beta_value = tf.constant(float(beta))
        loss_R = R_loss(mu_t, sigma_t, number, beta_value)
        self.combined.add_loss(loss_R)
        self.combined.compile(optimizer=Adam(lr=self.adam_lr,
                                             beta_1=self.adam_beta_1),
                              loss='sparse_categorical_crossentropy')
        self.tag = tag
        self.result_logger = ResultLogger(tag, self.res_dir, verbose=True)
Example #2
0
class BalancingGAN:
    """Class-conditional GAN for generating extra samples on imbalanced data.

    Pipeline: a Gaussian latent model maps noise ``z`` to a class-conditioned
    code ``mu + sigma * z``, the generator decodes that code into an image,
    and a two-headed classifier (``auxiliary`` with 2*nclasses outputs and
    ``classifier`` with nclasses outputs) scores real vs. generated images.
    """
    def __init__(
            self,
            classes,
            target_class_id,
            tag,
            # Set dratio_mode, and gratio_mode to 'rebalance' to bias the sampling toward the minority class
            # No relevant difference noted
            dratio_mode="uniform",
            gratio_mode="uniform",
            adam_lr=0.00005,
            latent_size=100,
            res_dir="./res-tmp",
            image_shape=[3, 32, 32],
            min_latent_res=8,
            N=10,
            batch_size=100,
            sigma=5,
            beta=5):
        """Build and compile all sub-models, then wire them into
        ``self.combined`` (latent -> generator -> frozen classifier) with an
        extra ``R_loss`` regularization on the latent (mu, sigma) weights.

        ``image_shape`` is [channels, height, width]; only square images are
        supported.  NOTE(review): ``image_shape`` has a mutable list default —
        harmless here because it is only read.
        """
        self.beta = beta
        self.gratio_mode = gratio_mode
        self.dratio_mode = dratio_mode
        self.classes = classes
        self.target_class_id = target_class_id
        self.nclasses = len(classes)
        self.latent_size = latent_size
        self.res_dir = res_dir
        self.channels = image_shape[0]
        self.resolution = image_shape[1]
        if self.resolution != image_shape[2]:
            print(
                "Error: only squared images currently supported by balancingGAN"
            )
            exit(1)

        # self.min_latent_res = min_latent_res

        # Initialize learning variables
        self.adam_lr = adam_lr
        self.adam_beta_1 = 0.5

        # Initialize stats
        self.train_history = defaultdict(list)
        self.test_history = defaultdict(list)
        self.trained = False

        # Build final_latent
        self.build_latent_use_Gaussian(latent_size=latent_size,
                                       N=self.nclasses,
                                       batch_size=batch_size,
                                       sigma=sigma)
        self.final_latent.compile(optimizer=Adam(lr=self.adam_lr,
                                                 beta_1=self.adam_beta_1),
                                  loss='sparse_categorical_crossentropy')

        #Build generator
        # NOTE(review): passes the `N` parameter here while the other
        # builders use self.nclasses — confirm they are kept consistent.
        self.build_generator(latent_size, N, init_resolution=min_latent_res)
        self.generator.compile(optimizer=Adam(lr=self.adam_lr,
                                              beta_1=self.adam_beta_1),
                               loss='sparse_categorical_crossentropy')

        # Build classifier
        self.build_classifier(min_latent_res=min_latent_res)
        self.classifier.compile(optimizer=Adam(lr=self.adam_lr,
                                               beta_1=self.adam_beta_1),
                                loss='sparse_categorical_crossentropy')
        # Define combined for training generator with final_latent.
        # The classifier is frozen so only the latent mapper and generator
        # receive gradients through `combined`.
        latent_gen = Input(batch_shape=(batch_size, latent_size))
        # class_label = Input(batch_shape=(batch_size,N))
        z_withclass = self.final_latent(latent_gen)
        fake = self.generator(z_withclass)
        aux = self.classifier(fake)

        self.final_latent.trainable = True
        self.classifier.trainable = False
        self.generator.trainable = True

        self.combined = Model(inputs=latent_gen, outputs=aux)
        # loss of regularization
        # NOTE(review): layer index 1 is assumed to be the final_latent
        # sub-model whose weights are (mu, sigma) — confirm layer ordering.
        weights = self.combined.get_layer(index=1).weights
        mu_t = weights[0]
        sigma_t = weights[1]
        # Number of unordered class pairs, n*(n-1)/2.
        number = (self.nclasses * (self.nclasses - 1) / 2.0)
        # the value that R tends to
        beta_value = tf.constant(float(beta))
        loss_R = R_loss(mu_t, sigma_t, number, beta_value)
        self.combined.add_loss(loss_R)
        self.combined.compile(optimizer=Adam(lr=self.adam_lr,
                                             beta_1=self.adam_beta_1),
                              loss='sparse_categorical_crossentropy')
        self.tag = tag
        self.result_logger = ResultLogger(tag, self.res_dir, verbose=True)

    def generate_latent(self, num, latent_size=100, mode_z='uniform'):
        """Sample `num` latent vectors, shape (num, latent_size), float32.

        'uniform' draws from U(-1, 1); any other mode draws from N(0, 1).
        """
        if mode_z == 'uniform':
            gen_z = np.random.uniform(-1.0, 1.0,
                                      [num, latent_size]).astype(np.float32)
        else:
            gen_z = np.random.normal(0.0, 1.0,
                                     [num, latent_size]).astype(np.float32)
        return gen_z

    def generate_image_labels(self, class_num=10, sample_num=[]):
        """Return a flat list of class ids, repeating class i
        ``sample_num[i]`` times (e.g. [2, 1] -> [0, 0, 1]).

        NOTE(review): mutable list default — benign (never mutated), but the
        default is also useless since it yields an empty result.
        """
        generated_images_labels = []
        for i in range(class_num):
            for j in range(sample_num[i]):
                generated_images_labels.append(i)
        return generated_images_labels

    def build_latent_use_Gaussian(self,
                                  latent_size=100,
                                  N=10,
                                  batch_size=100,
                                  sigma=5):
        """Build ``self.final_latent``: a model wrapping the custom
        ``MyLayer1`` reparameterization layer that maps raw noise to
        class-conditioned latent codes."""
        latent = Input(batch_shape=(batch_size, latent_size))
        reparamter = MyLayer1(batch_size=batch_size,
                              output_dim=latent_size,
                              class_num=N,
                              sigma=sigma)
        reparamter_res = reparamter(latent)
        self.final_latent = Model(inputs=latent, outputs=reparamter_res)

    def build_generator(self, latent_size=100, N=10, init_resolution=8):
        """Build ``self.generator``: Dense head, then upsampling conv stack
        from init_resolution up to self.resolution (tanh output in [-1, 1]).

        Requires resolution == init_resolution * 2**n for some natural n.
        NOTE(review): parameter ``N`` is unused in this body.
        """
        resolution = self.resolution
        channels = self.channels

        cnn = Sequential()
        cnn.add(Dense(1024, input_dim=latent_size, use_bias=False))
        cnn.add(BatchNormalization(momentum=0.9))
        cnn.add(Activation('relu'))

        cnn.add(Dense(128 * init_resolution * init_resolution, use_bias=False))
        cnn.add(BatchNormalization(momentum=0.9))
        cnn.add(Activation('relu'))

        # Channels-first feature map: (128, r, r).
        cnn.add(Reshape((128, init_resolution, init_resolution)))
        crt_res = init_resolution

        # upsample
        while crt_res != resolution:
            cnn.add(UpSampling2D(size=(2, 2)))
            if crt_res < resolution / 2:
                cnn.add(
                    Conv2D(256, (5, 5),
                           padding='same',
                           kernel_initializer='glorot_normal',
                           use_bias=False))
                cnn.add(BatchNormalization(momentum=0.9))
                cnn.add(Activation('relu'))
            else:
                cnn.add(
                    Conv2D(128, (5, 5),
                           padding='same',
                           kernel_initializer='glorot_normal',
                           use_bias=False))
                cnn.add(BatchNormalization(momentum=0.9))
                cnn.add(Activation('relu'))

            crt_res = crt_res * 2
            assert crt_res <= resolution, \
                "Error: final resolution [{}] must equal i*2^n. Initial resolution i is [{}]. n must be a natural number.".format(
                    resolution, init_resolution)
        # FIXME: sigmoid here
        cnn.add(
            Conv2D(channels, (2, 2),
                   padding='same',
                   activation='tanh',
                   kernel_initializer='glorot_normal',
                   use_bias=False))
        # This is the latent z space
        latent = Input(shape=(latent_size, ))

        fake_image_from_latent = cnn(latent)

        # The input-output interface
        self.generator = Model(inputs=latent, outputs=fake_image_from_latent)

    def _build_common_encoder(self, image, min_latent_res=8):
        """Return encoder features for `image`: strided conv stack
        (LeakyReLU + Dropout) downsampling to min_latent_res, then Flatten."""
        resolution = self.resolution
        channels = self.channels

        # build a relatively standard conv net, with LeakyReLUs as suggested in ACGAN
        cnn = Sequential()

        cnn.add(
            Conv2D(32, (3, 3),
                   padding='same',
                   strides=(2, 2),
                   input_shape=(channels, resolution, resolution),
                   use_bias=True))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))

        cnn.add(
            Conv2D(64, (3, 3), padding='same', strides=(1, 1), use_bias=True))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))

        cnn.add(
            Conv2D(128, (3, 3), padding='same', strides=(2, 2), use_bias=True))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))

        cnn.add(
            Conv2D(256, (3, 3), padding='same', strides=(1, 1), use_bias=True))
        cnn.add(LeakyReLU())
        cnn.add(Dropout(0.3))

        # Keep halving spatial size until the feature map reaches
        # min_latent_res (assumes channels-first, so [-1] is width).
        while cnn.output_shape[-1] > min_latent_res:
            cnn.add(
                Conv2D(256, (3, 3),
                       padding='same',
                       strides=(2, 2),
                       use_bias=True))
            cnn.add(LeakyReLU())
            cnn.add(Dropout(0.3))

            cnn.add(
                Conv2D(256, (3, 3),
                       padding='same',
                       strides=(1, 1),
                       use_bias=True))
            cnn.add(LeakyReLU())
            cnn.add(Dropout(0.3))

        cnn.add(Flatten())

        features = cnn(image)
        return features

    def build_classifier(self, min_latent_res=8):
        """Build ``self.classifier`` with two softmax heads:

        - 'auxiliary' (aux): 2*nclasses outputs — real classes plus the
          shifted fake-class ids used in _train_one_epoch.
        - 'classifier' (aux1): nclasses outputs, stacked on top of aux.
        """
        resolution = self.resolution
        channels = self.channels
        image = Input(shape=(channels, resolution, resolution))
        features = self._build_common_encoder(image, min_latent_res)
        aux = Dense(self.nclasses * 2, activation='softmax',
                    name='auxiliary')(features)
        aux1 = Dense(self.nclasses, activation='softmax',
                     name='classifier')(aux)
        self.classifier = Model(inputs=image, outputs=[aux, aux1])

    def get_batch_count(self, class_num, batch_size, gen_class_ration):
        """Draw `batch_size` class ids with probabilities `gen_class_ration`
        and return per-class counts as an int array of length class_num."""
        sample = np.random.choice([i for i in range(class_num)],
                                  size=batch_size,
                                  replace=True,
                                  p=gen_class_ration)
        batch_num = [0] * class_num
        for x in sample:
            batch_num[x] += 1
        batch_num = np.array(batch_num)
        return batch_num

    def _train_one_epoch(self,
                         bg_train,
                         class_num,
                         batch_size=100,
                         latent_size=100,
                         mode_z='uniform',
                         gen_class_ration=[]):
        """Run one epoch: per batch, train the classifier on real+generated
        images, then train the generator via ``self.combined``.

        Returns (mean classifier losses, mean generator losses) over batches.
        NOTE(review): mutable list default on gen_class_ration (never
        mutated here).
        """
        epoch_classifier_loss = []
        epoch_gen_loss = []

        for image_batch, label_batch in bg_train.next_batch():
            ################## Train Classifier ##################
            X = image_batch
            aux_y = label_batch
            noise_gen = self.generate_latent(batch_size, latent_size, mode_z)
            gen_counts = self.get_batch_count(class_num, batch_size,
                                              gen_class_ration)
            # Reparameterize manually with the current (mu, sigma) of the
            # latent layer: z = mu_c + sigma_c * noise for each sampled class.
            weights = self.final_latent.get_layer(index=1).get_weights()
            mu = weights[0]
            sigma = weights[1]
            final_mu = np.repeat(mu, gen_counts, axis=0)
            final_sigma = np.repeat(sigma, gen_counts, axis=0)
            final_z = final_mu + final_sigma * noise_gen
            fake_label = self.generate_image_labels(class_num, gen_counts)
            generated_images = self.generator.predict(final_z,
                                                      verbose=0,
                                                      batch_size=batch_size)
            X = np.concatenate((X, generated_images), axis=0)
            # Fake samples get labels shifted by nclasses for the
            # 2*nclasses 'auxiliary' head; the 'classifier' head keeps the
            # unshifted class ids.
            class_np = np.array([self.nclasses] * batch_size)
            fake_label_np = np.array(fake_label)
            fake_label_train = class_np + fake_label_np

            aux_y = np.concatenate((aux_y, fake_label_train), axis=0)
            aux1_y = np.concatenate((label_batch, fake_label_np), axis=0)
            epoch_classifier_loss.append(
                self.classifier.train_on_batch(X, [aux_y, aux1_y]))

            ################## Train Generator ##################
            # Uniform class targets: batch_size // class_num per class.
            noise_gen = self.generate_latent(batch_size, latent_size, mode_z)
            gen_label = self.generate_image_labels(
                class_num, [batch_size // class_num] * class_num)
            loss_gen = self.combined.train_on_batch(noise_gen,
                                                    [gen_label, gen_label])
            # loss_R = loss_gen[0] - loss_gen[1] - loss_gen[2]
            epoch_gen_loss.append(loss_gen)

        # return statistics: generator loss,
        return (np.mean(np.array(epoch_classifier_loss),
                        axis=0), np.mean(np.array(epoch_gen_loss), axis=0))

    def _get_lst_bck_name(self, element):
        # Find last bck name
        """Return (epoch, filename) of a backup file for `element`
        ('classifier', 'generator', ...) in res_dir, or (0, None).

        NOTE(review): takes files[0], the first regex match in directory
        order — not necessarily the most recent epoch; confirm intent.
        """
        files = [
            f for f in os.listdir(self.res_dir) if re.match(
                r'bck_c_{}'.format(self.target_class_id) + "_" + element, f)
        ]
        if len(files) > 0:
            fname = files[0]
            # Epoch number is the last "_"-separated token of the stem.
            e_str = os.path.splitext(fname)[0].split("_")[-1]
            epoch = int(e_str)
            return epoch, fname
        else:
            return 0, None

    def backup_point(self, epoch, epochs=100):
        """Save model checkpoints to res_dir.

        Classifier and generator are saved every 50 epochs and on the last
        three epochs; the latent model is saved on every call.
        """
        if epoch % 50 == 0 or epoch == (epochs - 1) or epoch == (
                epochs - 2) or epoch == (epochs - 3):
            classifier_fname = "{}/bck_c_{}_classifier_e_{}.h5".format(
                self.res_dir, self.target_class_id, epoch)
            self.classifier.save(classifier_fname)
            generator_fname = "{}/bck_c_{}_generator_e_{}.h5".format(
                self.res_dir, self.target_class_id, epoch)
            self.generator.save(generator_fname)
        final_latent_fname = "{}/bck_c_{}_latent_e_{}.h5".format(
            self.res_dir, self.target_class_id, epoch)
        self.final_latent.save(final_latent_fname)

    def train(self,
              bg_train,
              bg_test,
              epochs=100,
              class_num=10,
              latent_size=100,
              mode_z='uniform',
              batch_size=100,
              gen_class_ration=[]):
        """Full training loop: per epoch, run _train_one_epoch, evaluate on
        bg_test, log metrics/predictions, save sample images and backups.

        No-op if self.trained is already True.
        NOTE(review): mutable list default on gen_class_ration (only
        forwarded, never mutated).
        """
        if not self.trained:
            # Class actual ratio
            self.class_aratio = bg_train.get_class_probability()
            # Fixed noise reused every epoch so sample images are comparable.
            fixed_latent = self.generate_latent(batch_size, latent_size,
                                                mode_z)

            # Train
            start_e = 0
            for e in range(start_e, epochs):
                start_time = time()
                # Train
                print('GAN train epoch: {}/{}'.format(e, epochs))
                train_classifier_loss, train_gen_loss = self._train_one_epoch(
                    bg_train,
                    class_num,
                    batch_size=batch_size,
                    mode_z=mode_z,
                    gen_class_ration=gen_class_ration)

                # Regularization term = total loss minus the two head losses
                # (matches the commented formula in _train_one_epoch).
                loss_R = train_gen_loss[0] - train_gen_loss[
                    1] - train_gen_loss[2]
                self.result_logger.add_training_metrics1(
                    float(train_gen_loss[0]), float(train_gen_loss[1]),
                    float(train_gen_loss[2]), float(loss_R),
                    float(train_classifier_loss[0]),
                    float(train_classifier_loss[1]),
                    float(train_classifier_loss[2]),
                    time() - start_time)
                # Test #
                test_loss = self.classifier.evaluate(
                    bg_test.dataset_x, [bg_test.dataset_y, bg_test.dataset_y],
                    verbose=False)
                self.result_logger.add_testing_metrics(test_loss[0],
                                                       test_loss[1],
                                                       test_loss[2])
                # Final predictions come from the nclasses-way head (probs_1).
                probs_0, probs_1 = self.classifier.predict(
                    bg_test.dataset_x, batch_size=batch_size, verbose=True)
                final_probs = probs_1
                predicts = np.argmax(final_probs, axis=-1)
                self.result_logger.save_prediction(e,
                                                   bg_test.dataset_y,
                                                   predicts,
                                                   probs_0,
                                                   probs_1,
                                                   epochs=epochs)
                self.result_logger.save_metrics()
                print("train_classifier_loss {},\ttrain_gen_loss {},\t".format(
                    train_classifier_loss, train_gen_loss))

                # Save sample images
                if e % 1 == 0:
                    final_latent = self.final_latent.predict(
                        fixed_latent, batch_size=batch_size)
                    generated_images = self.generator.predict(
                        final_latent, verbose=0, batch_size=batch_size)
                    img_samples = generated_images / 2. + 0.5  # rescale from [-1, 1] back to [0, 1]
                    save_image_array(img_samples,
                                     '{}/plot_epoch_{}.png'.format(
                                         self.res_dir, e),
                                     batch_size=batch_size,
                                     class_num=10)
                if e % 1 == 0:
                    self.backup_point(e, epochs)
            self.trained = True
Example #3
0
File: dgc.py  Project: anonymous184/DGCMM
    def __init__(self, opts, tag):
        """Build the full DGC computation graph (TensorFlow 1.x graph mode).

        Wires encoder/decoder into a VAE-style model with classification,
        MMD, reconstruction and Gaussian-mixture losses, and additionally
        builds a per-class importance-sampling evaluation graph
        (``self.eval_probs``).

        Args:
            opts: dict of options ('dataset', 'zdim', 'n_classes',
                'sampling_size', 'lambda', 'lambda_cls', 'lambda_mixture',
                'work_dir', 'e_pretrain', ...).
            tag: experiment tag for the ResultLogger.
        """
        tf.reset_default_graph()
        # Logged at ERROR level — presumably so it shows regardless of the
        # configured log verbosity; TODO confirm.
        logging.error('Building the Tensorflow Graph')
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        self.opts = opts

        assert opts['dataset'] in datashapes, 'Unknown dataset.'
        self.data_shape = datashapes[opts['dataset']]

        self.add_inputs_placeholders()

        self.add_training_placeholders()
        # Dynamic batch size (tensor, not a Python int).
        sample_size = tf.shape(self.sample_points)[0]

        enc_mean, enc_sigmas = encoder(opts,
                                       inputs=self.sample_points,
                                       is_training=self.is_training,
                                       y=self.labels)

        # Clip for numerical stability before exponentiating below.
        enc_sigmas = tf.clip_by_value(enc_sigmas, -50, 50)
        self.enc_mean, self.enc_sigmas = enc_mean, enc_sigmas

        # Reparameterization trick; the exp() below implies enc_sigmas are
        # log-variances — TODO confirm against the encoder definition.
        eps = tf.random_normal((sample_size, opts['zdim']),
                               0.,
                               1.,
                               dtype=tf.float32)
        self.encoded = self.enc_mean + tf.multiply(
            eps, tf.sqrt(1e-8 + tf.exp(self.enc_sigmas)))
        # self.encoded = self.enc_mean + tf.multiply(
        #     eps, tf.exp(self.enc_sigmas / 2.))

        (self.reconstructed, self.reconstructed_logits), self.probs1 = \
            decoder(opts, noise=self.encoded,
                    is_training=self.is_training)
        # Count of correct class predictions in the batch.
        self.correct_sum = tf.reduce_sum(
            tf.cast(tf.equal(tf.argmax(self.probs1, axis=1), self.labels),
                    tf.float32))
        # Decode the content of sample_noise
        (self.decoded,
         self.decoded_logits), _ = decoder(opts,
                                           reuse=True,
                                           noise=self.sample_noise,
                                           is_training=self.is_training)
        # -- Objectives, losses, penalties
        self.loss_cls = self.cls_loss(self.labels, self.probs1)
        self.loss_mmd = self.mmd_penalty(self.sample_noise, self.encoded)
        self.loss_recon = self.reconstruction_loss(self.opts,
                                                   self.sample_points,
                                                   self.reconstructed)
        self.mixup_loss = self.MIXUP_loss(opts, self.encoded, self.labels)
        self.gmmpara_init()
        self.loss_mixture = self.mixture_loss(self.encoded)

        # Main objective; pretraining objective uses MMD instead of the
        # mixture term.
        self.objective = self.loss_recon + opts[
            'lambda_cls'] * self.loss_cls + opts['lambda_mixture'] * tf.cast(
                self.loss_mixture, dtype=tf.float32)
        self.objective_pre = self.loss_recon + opts[
            'lambda'] * self.loss_mmd + self.loss_cls

        self.result_logger = ResultLogger(tag, opts['work_dir'], verbose=True)
        self.tag = tag

        # Evaluation graph: for each candidate class i, estimate a bound on
        # log p(x, y=i) with S Monte-Carlo samples, then softmax over classes.
        logpxy = []
        dimY = opts['n_classes']
        N = sample_size
        S = opts['sampling_size']
        x_rep = tf.tile(self.sample_points, [S, 1, 1, 1])
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            for i in range(dimY):
                # Hypothesize every sample belongs to class i.
                y = tf.fill((N, ), i)
                mu, log_sig = encoder(opts,
                                      inputs=self.sample_points,
                                      reuse=True,
                                      is_training=False,
                                      y=y)
                mu = tf.tile(mu, [S, 1])
                log_sig = tf.tile(log_sig, [S, 1])
                y = tf.tile(y, [S])
                eps2 = tf.random_normal((N * S, opts['zdim']),
                                        0.,
                                        1.,
                                        dtype=tf.float32)
                z = mu + tf.multiply(eps2, tf.sqrt(1e-8 + tf.exp(log_sig)))
                (mu_x, _), logit_y = decoder(opts,
                                             reuse=True,
                                             noise=z,
                                             is_training=False)
                # Negative squared reconstruction error per sample.
                logp = -tf.reduce_sum((x_rep - mu_x)**2, axis=[1, 2, 3])
                log_pyz = -tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=y, logits=logit_y)
                # Gaussian-mixture log-posterior using mixture params
                # theta_p / u_p / lambda_p set up by gmmpara_init().
                posterior = tf.log(
                    self.theta_p) - 0.5 * tf.log(2 * math.pi * self.lambda_p)
                self.u_p_1 = tf.expand_dims(self.u_p, 2)
                z_m = tf.expand_dims(tf.transpose(z), 1)
                aa = tf.square(z_m - self.u_p_1)
                self.lambda_p_1 = tf.expand_dims(self.lambda_p, 2)
                # NOTE(review): parses as (aa / 2) * lambda_p_1. A Gaussian
                # log-density would use aa / (2 * lambda_p_1) — confirm
                # whether lambda_p is a precision (then this is correct) or
                # a variance (then the precedence is a bug).
                bb = aa / 2 * self.lambda_p_1
                posterior = tf.expand_dims(posterior, 2) - bb
                posterior_sum = tf.reduce_sum(tf.reduce_sum(posterior, axis=0),
                                              axis=0)
                bound = 0.5 * logp + opts['lambda_cls'] * log_pyz + opts[
                    'lambda_mixture'] * posterior_sum
                bound = tf.reshape(bound, [S, N])
                # Average the S samples in log space.
                bound = self.logsumexp(bound) - tf.log(float(S))
                logpxy.append(tf.expand_dims(bound, 1))
            logpxy = tf.concat(logpxy, 1)
        y_pred = tf.nn.softmax(logpxy)

        self.eval_probs = y_pred
        # Debug tensors: these reference the loop variables from the LAST
        # class iteration only (i == dimY - 1).
        self.test_a = 0.5 * logp
        self.test_b = log_pyz
        self.test_c = posterior_sum

        if opts['e_pretrain']:
            self.loss_pretrain = self.pretrain_loss()
        else:
            self.loss_pretrain = None

        self.add_optimizers()
        self.add_savers()
Example #4
0
File: dgc.py  Project: anonymous184/DGCMM
class DGC(object):
    def __init__(self, opts, tag):
        """Build the complete DGC TensorFlow (TF1) graph and its session.

        Constructs, in order: input/training placeholders, a label-conditioned
        VAE-style encoder/decoder pair, the training objectives (reconstruction
        + classification + GMM-mixture terms), an importance-sampled per-class
        evaluation head, and finally the optimizers and checkpoint saver.

        Args:
            opts: dict of hyper-parameters (keys used here include 'dataset',
                'zdim', 'n_classes', 'sampling_size', 'lambda', 'lambda_cls',
                'lambda_mixture', 'e_pretrain', 'work_dir').
            tag: run identifier, used for the ResultLogger and output files.
        """
        tf.reset_default_graph()
        # logging.error is used for "always visible" progress messages
        # throughout this file, not for actual errors.
        logging.error('Building the Tensorflow Graph')
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        self.opts = opts

        assert opts['dataset'] in datashapes, 'Unknown dataset.'
        self.data_shape = datashapes[opts['dataset']]

        self.add_inputs_placeholders()

        self.add_training_placeholders()
        sample_size = tf.shape(self.sample_points)[0]

        # Label-conditioned encoder producing posterior mean / log-variance.
        enc_mean, enc_sigmas = encoder(opts,
                                       inputs=self.sample_points,
                                       is_training=self.is_training,
                                       y=self.labels)

        # Clamp log-variances so exp() below stays finite.
        enc_sigmas = tf.clip_by_value(enc_sigmas, -50, 50)
        self.enc_mean, self.enc_sigmas = enc_mean, enc_sigmas

        # Reparameterisation trick: z = mu + eps * sigma.
        eps = tf.random_normal((sample_size, opts['zdim']),
                               0.,
                               1.,
                               dtype=tf.float32)
        self.encoded = self.enc_mean + tf.multiply(
            eps, tf.sqrt(1e-8 + tf.exp(self.enc_sigmas)))
        # self.encoded = self.enc_mean + tf.multiply(
        #     eps, tf.exp(self.enc_sigmas / 2.))

        # Decoder returns (reconstruction, reconstruction logits) plus a
        # classifier head over labels (probs1).
        (self.reconstructed, self.reconstructed_logits), self.probs1 = \
            decoder(opts, noise=self.encoded,
                    is_training=self.is_training)
        # Count of correctly classified samples in the batch (for accuracy).
        self.correct_sum = tf.reduce_sum(
            tf.cast(tf.equal(tf.argmax(self.probs1, axis=1), self.labels),
                    tf.float32))
        # Decode the content of sample_noise
        (self.decoded,
         self.decoded_logits), _ = decoder(opts,
                                           reuse=True,
                                           noise=self.sample_noise,
                                           is_training=self.is_training)
        # -- Objectives, losses, penalties
        self.loss_cls = self.cls_loss(self.labels, self.probs1)
        self.loss_mmd = self.mmd_penalty(self.sample_noise, self.encoded)
        self.loss_recon = self.reconstruction_loss(self.opts,
                                                   self.sample_points,
                                                   self.reconstructed)
        self.mixup_loss = self.MIXUP_loss(opts, self.encoded, self.labels)
        self.gmmpara_init()
        self.loss_mixture = self.mixture_loss(self.encoded)

        # Main objective: reconstruction + classification + GMM mixture terms.
        # NOTE(review): mixup_loss and loss_mmd are built but not part of this
        # objective; loss_mmd only enters objective_pre below.
        self.objective = self.loss_recon + opts[
            'lambda_cls'] * self.loss_cls + opts['lambda_mixture'] * tf.cast(
                self.loss_mixture, dtype=tf.float32)
        # Pretraining objective: reconstruction + MMD prior matching + cls.
        self.objective_pre = self.loss_recon + opts[
            'lambda'] * self.loss_mmd + self.loss_cls

        self.result_logger = ResultLogger(tag, opts['work_dir'], verbose=True)
        self.tag = tag

        # --- Evaluation head ---
        # For each candidate label i: encode x conditioned on y=i, draw S
        # latent samples per input, decode, and combine reconstruction, label
        # and GMM-prior terms into a bound on log p(x, y=i). The per-class
        # bounds are softmaxed into class probabilities.
        logpxy = []
        dimY = opts['n_classes']
        N = sample_size
        S = opts['sampling_size']
        x_rep = tf.tile(self.sample_points, [S, 1, 1, 1])
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            for i in range(dimY):
                y = tf.fill((N, ), i)
                mu, log_sig = encoder(opts,
                                      inputs=self.sample_points,
                                      reuse=True,
                                      is_training=False,
                                      y=y)
                # Replicate posterior parameters S times and sample.
                mu = tf.tile(mu, [S, 1])
                log_sig = tf.tile(log_sig, [S, 1])
                y = tf.tile(y, [S])
                eps2 = tf.random_normal((N * S, opts['zdim']),
                                        0.,
                                        1.,
                                        dtype=tf.float32)
                z = mu + tf.multiply(eps2, tf.sqrt(1e-8 + tf.exp(log_sig)))
                (mu_x, _), logit_y = decoder(opts,
                                             reuse=True,
                                             noise=z,
                                             is_training=False)
                # Squared-error reconstruction log-likelihood (up to consts).
                logp = -tf.reduce_sum((x_rep - mu_x)**2, axis=[1, 2, 3])
                # Label log-likelihood from the decoder's classifier head.
                log_pyz = -tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=y, logits=logit_y)
                # GMM prior term, evaluated per mixture component.
                posterior = tf.log(
                    self.theta_p) - 0.5 * tf.log(2 * math.pi * self.lambda_p)
                self.u_p_1 = tf.expand_dims(self.u_p, 2)
                z_m = tf.expand_dims(tf.transpose(z), 1)
                aa = tf.square(z_m - self.u_p_1)
                self.lambda_p_1 = tf.expand_dims(self.lambda_p, 2)
                # NOTE(review): `aa / 2 * lambda` parses as (aa / 2) * lambda;
                # mixture_loss divides by (2 * lambda) instead — confirm which
                # form is intended here.
                bb = aa / 2 * self.lambda_p_1
                posterior = tf.expand_dims(posterior, 2) - bb
                posterior_sum = tf.reduce_sum(tf.reduce_sum(posterior, axis=0),
                                              axis=0)
                # Weighted combination of the three terms.
                bound = 0.5 * logp + opts['lambda_cls'] * log_pyz + opts[
                    'lambda_mixture'] * posterior_sum
                # Average the S samples in log space (importance sampling).
                bound = tf.reshape(bound, [S, N])
                bound = self.logsumexp(bound) - tf.log(float(S))
                logpxy.append(tf.expand_dims(bound, 1))
            logpxy = tf.concat(logpxy, 1)
        y_pred = tf.nn.softmax(logpxy)

        self.eval_probs = y_pred
        # Debug tensors; these hold values from the LAST class iteration only.
        self.test_a = 0.5 * logp
        self.test_b = log_pyz
        self.test_c = posterior_sum

        if opts['e_pretrain']:
            self.loss_pretrain = self.pretrain_loss()
        else:
            self.loss_pretrain = None

        self.add_optimizers()
        self.add_savers()

    def log_gaussian_prob(self, x, mu=0.0, log_sig=0.0):
        """Log-density of `x` under a diagonal Gaussian, summed over all
        non-batch axes.

        `mu` and `log_sig` are the mean and log standard deviation, both
        broadcastable against `x`.
        """
        normalizer = 0.5 * np.log(2 * np.pi) + log_sig
        standardized = (x - mu) / tf.exp(log_sig)
        log_density = -normalizer - 0.5 * standardized ** 2
        # Reduce over every axis except the leading batch axis.
        reduce_axes = list(range(1, len(x.get_shape().as_list())))
        return tf.reduce_sum(log_density, reduce_axes)

    def logsumexp(self, x):
        """Numerically stable log-sum-exp over axis 0."""
        peak = tf.reduce_max(x, 0)
        shifted = x - peak
        total = tf.reduce_sum(tf.exp(shifted), 0)
        # Clip away exact zeros before the log; the upper bound is a no-op.
        return peak + tf.log(tf.clip_by_value(total, 1e-20, np.inf))

    def add_inputs_placeholders(self):
        """Create the graph inputs: real images, labels and prior noise."""
        opts = self.opts
        shape = self.data_shape
        # Batch dimension is left dynamic (None) everywhere.
        self.sample_points = tf.placeholder(tf.float32, [None] + shape,
                                            name='real_points_ph')
        self.labels = tf.placeholder(tf.int64, shape=[None], name='label_ph')
        self.sample_noise = tf.placeholder(tf.float32,
                                           [None] + [opts['zdim']],
                                           name='noise_ph')

    def add_training_placeholders(self):
        """Create placeholders that control the optimisation schedule."""
        # Multiplicative learning-rate decay factor, fed each step.
        self.lr_decay = tf.placeholder(tf.float32, name='rate_decay_ph')
        # Train/eval switch consumed by encoder/decoder (e.g. batch norm).
        self.is_training = tf.placeholder(tf.bool, name='is_training_ph')

    def pretrain_loss(self):
        """Moment-matching loss between the encoded batch and prior samples.

        Penalises the squared difference of the first two empirical moments
        (mean and covariance) of q(z) versus the prior batch.
        """
        opts = self.opts
        denom = opts['e_pretrain_sample_size'] - 1.

        def empirical_cov(batch, mean):
            # Unbiased sample covariance of a (batch, zdim) tensor.
            centered = batch - mean
            return tf.matmul(centered, centered, transpose_a=True) / denom

        mean_pz = tf.reduce_mean(self.sample_noise, axis=0, keepdims=True)
        mean_qz = tf.reduce_mean(self.encoded, axis=0, keepdims=True)
        mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
        cov_loss = tf.reduce_mean(
            tf.square(empirical_cov(self.sample_noise, mean_pz) -
                      empirical_cov(self.encoded, mean_qz)))
        return mean_loss + cov_loss

    def add_savers(self):
        """Create the checkpoint saver and export key tensors to collections
        so a restored graph can retrieve them by name."""
        tf.add_to_collection('real_points_ph', self.sample_points)
        tf.add_to_collection('noise_ph', self.sample_noise)
        tf.add_to_collection('is_training_ph', self.is_training)
        if self.enc_mean is not None:
            tf.add_to_collection('encoder_mean', self.enc_mean)
            tf.add_to_collection('encoder_var', self.enc_sigmas)
        tf.add_to_collection('encoder', self.encoded)
        tf.add_to_collection('decoder', self.decoded)
        self.saver = tf.train.Saver(max_to_keep=11)

    def cls_loss(self, labels, logits):
        """Classification loss for the decoder's label head.

        NOTE(review): reduce_sum collapses the per-example cross-entropies to
        a scalar, so the outer reduce_mean is a no-op — this returns the batch
        *sum*, not the mean. Confirm this scaling is intended before
        "simplifying" (it interacts with lambda_cls).
        """
        return tf.reduce_mean(
            tf.reduce_sum(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                               logits=logits)))

    def MIXUP_loss(self, opts, z_tilde, y):
        """Mixup regulariser in latent space.

        Convexly combines latent codes within the batch (in numpy, via
        tf.py_func), decodes the mixture, and penalises the classifier head
        with the correspondingly mixed cross-entropy targets.
        """
        alpha = 1.0  # Beta(alpha, alpha) mixing distribution
        batch_size, z_dim = z_tilde.get_shape().as_list()

        def loss_func(z_tilde):
            # Runs as numpy inside tf.py_func: one mixing coefficient per
            # batch plus a random pairing permutation.
            lam = np.random.beta(alpha, alpha)
            index = np.random.permutation(len(z_tilde))
            mixed_z = lam * z_tilde + (1.0 - lam) * z_tilde[index]
            return mixed_z, index, lam

        mixed_z, index, lam = tf.py_func(loss_func, [z_tilde],
                                         [tf.float32, tf.int64, tf.float64])
        # py_func erases static shapes; restore them for downstream ops.
        mixed_z.set_shape(z_tilde.get_shape())
        index.set_shape([
            batch_size,
        ])
        lam.set_shape(None)
        lam = tf.cast(lam, dtype=tf.float32)
        (_, _), pred_y = \
            decoder(opts, noise=mixed_z, is_training=self.is_training, reuse=True)

        # Cross-entropy against both mixture endpoints, weighted by lam.
        y_a, y_b = y, tf.gather(y, index)
        soft1 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_a,
                                                               logits=pred_y)
        soft2 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_b,
                                                               logits=pred_y)
        loss = tf.reduce_sum(lam * soft1 + (1 - lam) * soft2, axis=-1)
        loss = tf.reduce_mean(loss)
        return loss

    def save_aug_data(self, x, y):
        """Persist the augmented dataset to `<work_dir>/<tag>_aug.hdf5`."""
        path = self.opts['work_dir'] + os.sep + self.tag + "_aug.hdf5"
        with h5py.File(path, "w") as f:
            f.create_dataset("x", data=x)
            f.create_dataset("y", data=y)

    def augment_data(self, data, restore=False):
        """Offline augmentation: reconstruct random minority-class samples
        until every class matches the majority count, then save the combined
        (x, y) to HDF5 via save_aug_data.

        Requires `self.class_cnt` (set in train()). When `restore` is True,
        the latest checkpoint is loaded into the session first.

        NOTE(review): if no class needs augmentation, the final
        np.concatenate calls receive empty lists and raise — confirm callers
        only invoke this on imbalanced data.
        """
        if restore:
            self.saver.restore(
                self.sess,
                tf.train.latest_checkpoint(
                    os.path.join(self.opts['work_dir'], 'checkpoints')))
        x, y = data
        class_cnt = self.class_cnt
        x_aug_list = []
        y_aug_list = []
        batch_size = self.opts['batch_size']
        # Synthetic samples needed to level each class with the largest one.
        aug_num = [
            max(class_cnt) - class_cnt[i] for i in range(len(class_cnt))
        ]
        for i, num in enumerate(aug_num):
            if num <= 0:
                continue
            x_c = x[y == i]
            y_c = y[y == i]
            # Choose (with replacement) which class members to reconstruct.
            rand_idx = np.random.choice(len(x_c), num)
            x_raw = x_c[rand_idx]
            y_aug = y_c[rand_idx]
            x_aug_batches = []
            batches_num = math.ceil(len(y_aug) / batch_size)
            for it in tqdm(range(batches_num)):
                start_idx = it * batch_size
                end_idx = start_idx + batch_size
                # The autoencoder reconstruction is the augmented image.
                x_aug_batch = self.sess.run(self.reconstructed,
                                            feed_dict={
                                                self.sample_points:
                                                x_raw[start_idx:end_idx],
                                                self.labels:
                                                y_aug[start_idx:end_idx],
                                                self.is_training:
                                                False
                                            })
                x_aug_batches.append(x_aug_batch)
            x_aug = np.concatenate(x_aug_batches, axis=0)
            x_aug_list.append(x_aug)
            y_aug_list.append(y_aug)
        x_augs = np.concatenate(x_aug_list, axis=0)
        y_augs = np.concatenate(y_aug_list, axis=0)
        x = np.concatenate((x, x_augs), axis=0)
        y = np.concatenate((y, y_augs), axis=0)
        self.save_aug_data(x, y)

    def cal_dis(self, opts, z_tilde, max_iter=20):
        """Apparent attempt at a DeepFool-style minimal-perturbation search in
        latent space (iteratively perturb `z_tilde` until the decoder's
        predicted class flips).

        NOTE(review): this method looks broken / dead as written and should
        not be relied on without rework:
        * `body1` returns a single scalar while tf.while_loop requires the
          body to return the full loop-variable tuple;
        * `body2` mixes graph tensors with eager/numpy-only calls
          (`.item()`, `.numpy()`, Python `if` on a tensor), which cannot
          execute in TF1 graph mode;
        * `py.assign(tf.argmax(out, 1))` assigns int64 values to a float
          variable, and `(eta * eta).sum()` is not a TF tensor method;
        * the function never returns the computed `r_i`.
        Confirm whether any caller exercises this path.
        """
        nx = z_tilde
        out = self.probs1
        n_class = opts['n_classes']
        py = tf.get_variable(name='py',
                             shape=[out.shape[0], opts['n_classes']],
                             initializer=tf.zeros_initializer())
        py.assign(tf.argmax(out, 1))
        ny = tf.argmax(out, 1)
        i_iter = tf.Variable(0, name='i', dtype=tf.int64)
        eta = tf.Variable(tf.zeros([
            opts['zdim'],
        ]))
        value_l = tf.Variable(np.inf, name='value_l')

        def cond1(out, nx, ny, py, eta, i_iter, max_iter):
            # Keep iterating while the prediction has not flipped.
            return tf.equal(py, ny) and tf.less(i_iter, max_iter)

        def body1(out, nx, ny, py, eta, i_iter, max_iter):
            grad_np = tf.gradients(out[py], nx)[0]
            ri = None
            j_iter = tf.Variable(0, name='j', dtype=tf.int64)
            r_i = tf.while_loop(cond2, body2,
                                [grad_np, ri, value_l, py, j_iter, n_class])
            eta.assign_add(r_i)
            (_, _), out = \
                decoder(opts, noise=nx + eta, is_training=self.is_training, reuse=True)
            py = tf.argmax(out, 1)
            i_iter.assign_add(1)
            return (eta * eta).sum()

        def cond2(grad_np, ri, value_l, py, i, n_class):
            # Scan every candidate class.
            return i < n_class

        def body2(grad_np, ri, value_l, py, i, n_class):
            # Pick the class whose decision boundary is closest (smallest
            # |f_i| / ||w_i||), DeepFool-style.
            if tf.not_equal(i, py):
                grad_i = tf.gradients(out[0, i], nx)[0]
                wi = grad_i - grad_np
                fi = out[0, i] - out[0, py]
                value_i = np.abs(fi.item()) / np.linalg.norm(
                    wi.numpy().flatten())
                if value_i < value_l:
                    ri = value_i / np.linalg.norm(wi.numpy().flatten()) * wi
            i = i + 1
            return ri

        r_i = tf.while_loop(cond1, body1,
                            [out, nx, ny, py, eta, i_iter, max_iter])

    def mmd_penalty(self, sample_pz, sample_qz):
        """Maximum Mean Discrepancy between prior samples and encoded samples
        (WAE-style regulariser).

        Supports an RBF kernel with a median-heuristic bandwidth and an IMQ
        kernel summed over a fixed set of scales. Returns the (U-statistic)
        MMD estimate as a scalar tensor.
        """
        opts = self.opts
        sigma2_p = opts['pz_scale']**2
        kernel = opts['mmd_kernel']
        n = utils.get_batch_size(sample_qz)
        n = tf.cast(n, tf.int32)
        nf = tf.cast(n, tf.float32)
        # Number of distinct off-diagonal pairs; used as k for top_k below.
        # NOTE(review): this division produces a float tensor in TF1, while
        # top_k expects an integer k — confirm the RBF branch actually runs
        # on this TF version (the IMQ branch does not use half_size).
        half_size = (n * n - n) / 2

        # Pairwise squared distances within prior samples ...
        norms_pz = tf.reduce_sum(tf.square(sample_pz), axis=1, keepdims=True)
        dotprods_pz = tf.matmul(sample_pz, sample_pz, transpose_b=True)
        distances_pz = norms_pz + tf.transpose(norms_pz) - 2. * dotprods_pz

        # ... within encoded samples ...
        norms_qz = tf.reduce_sum(tf.square(sample_qz), axis=1, keepdims=True)
        dotprods_qz = tf.matmul(sample_qz, sample_qz, transpose_b=True)
        distances_qz = norms_qz + tf.transpose(norms_qz) - 2. * dotprods_qz

        # ... and across the two sets.
        dotprods = tf.matmul(sample_qz, sample_pz, transpose_b=True)
        distances = norms_qz + tf.transpose(norms_pz) - 2. * dotprods

        if kernel == 'RBF':
            # Median heuristic for the sigma^2 of Gaussian kernel
            sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]),
                                   half_size).values[half_size - 1]
            sigma2_k += tf.nn.top_k(tf.reshape(distances_qz, [-1]),
                                    half_size).values[half_size - 1]

            if opts['verbose']:
                sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:')
            res1 = tf.exp(-distances_qz / 2. / sigma2_k)
            res1 += tf.exp(-distances_pz / 2. / sigma2_k)
            # Zero the diagonal: U-statistic excludes self-pairs.
            res1 = tf.multiply(res1, 1. - tf.eye(n))
            res1 = tf.reduce_sum(res1) / (nf * nf - nf)
            res2 = tf.exp(-distances / 2. / sigma2_k)
            res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
            stat = res1 - res2
        elif kernel == 'IMQ':
            # k(x, y) = C / (C + ||x - y||^2)
            # C = tf.nn.top_k(tf.reshape(distances, [-1]), half_size).values[half_size - 1]
            # C += tf.nn.top_k(tf.reshape(distances_qz, [-1]), half_size).values[half_size - 1]
            if opts['pz'] == 'normal':
                Cbase = 2. * opts['zdim'] * sigma2_p
            elif opts['pz'] == 'sphere':
                Cbase = 2.
            elif opts['pz'] == 'uniform':
                # E ||x - y||^2 = E[sum (xi - yi)^2]
                #               = zdim E[(xi - yi)^2]
                #               = const * zdim
                Cbase = opts['zdim']
            stat = 0.
            # Sum the IMQ kernel over several bandwidth scales.
            for scale in [.1, .2, .5, 1., 2., 5., 10.]:
                C = Cbase * scale
                res1 = C / (C + distances_qz)
                res1 += C / (C + distances_pz)
                res1 = tf.multiply(res1, 1. - tf.eye(n))
                res1 = tf.reduce_sum(res1) / (nf * nf - nf)
                res2 = C / (C + distances)
                res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
                stat += res1 - res2
        return stat

    def gmmpara_init(self):
        """Create the trainable GMM-prior parameters.

        theta_p:  mixture weights (n_classes,), initialised uniform.
        u_p:      component means (zdim, n_classes), initialised to zero.
        lambda_p: component variances (zdim, n_classes), initialised to one.
        """
        n_components = self.opts['n_classes']
        zdim = self.opts['zdim']
        self.theta_p = tf.get_variable(
            "theta_p", [n_components], tf.float32,
            tf.constant_initializer(1.0 / n_components))
        self.u_p = tf.get_variable(
            "u_p", [zdim, n_components], tf.float32,
            initializer=tf.constant_initializer(0.0))
        self.lambda_p = tf.get_variable(
            "lambda_p", [zdim, n_components], tf.float32,
            initializer=tf.constant_initializer(1.0))

    def mixture_loss(self, z_tilde):
        """ELBO-style loss tying the encoder posterior to the learned GMM
        prior (theta_p, u_p, lambda_p), as in VaDE-like models.

        Posterior statistics and the sample z are tiled to shape
        (batch, zdim, n_classes) so every term is evaluated against every
        mixture component; `gamma` holds the per-sample component
        responsibilities p(c|z).
        """
        # Broadcast posterior mean / log-var and the sample over components.
        z_mean_t = tf.transpose(
            tf.tile(tf.expand_dims(self.enc_mean, dim=1),
                    [1, self.opts['n_classes'], 1]), [0, 2, 1])
        z_log_var_t = tf.transpose(
            tf.tile(tf.expand_dims(self.enc_sigmas, dim=1),
                    [1, self.opts['n_classes'], 1]), [0, 2, 1])
        Z = tf.transpose(
            tf.tile(tf.expand_dims(z_tilde, dim=1),
                    [1, self.opts['n_classes'], 1]), [0, 2, 1])
        u_tensor3 = self.u_p
        lambda_tensor3 = self.lambda_p
        theta_tensor3 = self.theta_p
        # Unnormalised per-component log-density of z (1e-10 avoids log(0)).
        a = tf.log(theta_tensor3) - 0.5 * tf.log(2 * math.pi * lambda_tensor3)
        b = tf.square(Z - u_tensor3)
        c = (2 * lambda_tensor3)
        p_c_z = tf.exp(tf.reduce_sum((a - b / c), axis=1)) + 1e-10
        # Responsibilities gamma = p(c|z), normalised over components.
        gamma = p_c_z / tf.reduce_sum(p_c_z, axis=-1, keepdims=True)
        gamma_t = tf.tile(tf.expand_dims(gamma, dim=1),
                          [1, self.opts['zdim'], 1])
        # Expected negative log-likelihood of q(z|x) under each component,
        # weighted by the responsibilities.
        loss = tf.reduce_sum(
            0.5 * gamma_t *
            (self.opts['zdim'] * tf.log(math.pi * 2) + tf.log(lambda_tensor3) +
             tf.exp(z_log_var_t) / lambda_tensor3 +
             tf.square(z_mean_t - u_tensor3) / lambda_tensor3),
            axis=(1, 2))
        # Entropy-like term of the Gaussian posterior.
        loss = loss - 0.5 * tf.reduce_sum(self.enc_sigmas + 1, axis=-1)
        # Cross-entropy between responsibilities and mixture weights, plus
        # the (negative) entropy of the responsibilities.
        loss = loss - tf.reduce_sum(tf.log(self.theta_p) * gamma, axis=-1) \
               + tf.reduce_sum(tf.log(gamma) * gamma, axis=-1)
        loss = tf.reduce_mean(loss)
        return loss

    def reconstruction_loss(self, opts, real, reconstr):
        """Pixel-space reconstruction cost averaged over the batch.

        Supported opts['cost'] values: 'l2' (0.2 * mean L2 norm), 'l2sq'
        (0.5 * mean squared L2 norm), 'l1' (0.02 * mean L1 norm). Any other
        value triggers an assertion failure.
        """
        cost = opts['cost']
        if cost == 'l2':
            # c(x,y) = ||x - y||_2
            per_sample = tf.reduce_sum(tf.square(real - reconstr),
                                       axis=[1, 2, 3])
            return 0.2 * tf.reduce_mean(tf.sqrt(1e-08 + per_sample))
        if cost == 'l2sq':
            # c(x,y) = ||x - y||_2^2
            per_sample = tf.reduce_sum(tf.square(real - reconstr),
                                       axis=[1, 2, 3])
            return 0.5 * tf.reduce_mean(per_sample)
        if cost == 'l1':
            # c(x,y) = ||x - y||_1
            per_sample = tf.reduce_sum(tf.abs(real - reconstr),
                                       axis=[1, 2, 3])
            return 0.02 * tf.reduce_mean(per_sample)
        assert False, 'Unknown cost function %s' % cost

    def optimizer(self, lr, decay=1.):
        """Return an Adam optimizer with the configured beta1 and a
        learning rate scaled by `decay`."""
        return tf.train.AdamOptimizer(lr * decay,
                                      beta1=self.opts["adam_beta1"])

    def add_optimizers(self):
        """Create the training ops: the joint auto-encoder step, the optional
        encoder-pretraining step, and the optional encoder-only LVO step."""
        opts = self.opts
        lr = opts['lr']
        encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='encoder')
        decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='generator')

        # Joint step over encoder + decoder with the decayed learning rate.
        opt = self.optimizer(lr, self.lr_decay)
        self.ae_opt = opt.minimize(loss=self.objective,
                                   var_list=encoder_vars + decoder_vars)

        # Encoder-only pretraining with an undecayed learning rate.
        if opts['e_pretrain']:
            opt = self.optimizer(lr)
            self.pretrain_opt = opt.minimize(loss=self.loss_pretrain,
                                             var_list=encoder_vars)
        else:
            self.pretrain_opt = None
        # NOTE: when e_pretrain is on, `opt` has been rebound above, so the
        # LVO step then uses the undecayed optimizer (original behaviour).
        if opts['LVO']:
            self.lvo_opt = opt.minimize(loss=self.objective,
                                        var_list=encoder_vars)

    def get_z_dist(self, data):
        """Fit a per-class Gaussian over the latent space.

        For every class, encodes all of its samples in mini-batches and
        computes the empirical mean and covariance of the latent codes.
        The arrays are also saved to `<work_dir>covariances.npy` /
        `<work_dir>means.npy`.

        Fix: the previous batching loop (`range(1, num_c // batch_size)`)
        silently dropped the final partial batch of every class, biasing the
        fitted statistics; this version covers all samples.

        Args:
            data: dataset object exposing `.data` and `.labels` arrays.
        Returns:
            (means, covariances) as numpy arrays of shape
            (n_classes, zdim) and (n_classes, zdim, zdim).
        """
        opts = self.opts
        batch_size = 128
        covariances = []
        means = []
        for c in range(opts['n_classes']):
            imgs = data.data[data.labels == c]
            labels = data.labels[data.labels == c]
            num_c = imgs.shape[0]
            latents = []
            # Encode the whole class, including the final partial batch.
            for start in range(0, num_c, batch_size):
                end = start + batch_size
                latent_ele = self.sess.run(self.encoded,
                                           feed_dict={
                                               self.sample_points:
                                               imgs[start:end],
                                               self.labels:
                                               labels[start:end],
                                               self.is_training: False
                                           })
                latents.append(latent_ele)
            latent_np = np.concatenate(latents, axis=0)
            covariances.append(np.cov(np.transpose(latent_np)))
            means.append(np.mean(latent_np, axis=0))

        covariances = np.array(covariances)
        means = np.array(means)

        cfname = "{}covariances.npy".format(opts['work_dir'])
        mfname = "{}means.npy".format(opts['work_dir'])
        # print("saving multivariate: ", cfname, mfname)
        np.save(cfname, covariances)
        np.save(mfname, means)

        return means, covariances

    def sample_pz(self, num=100, z_dist=None, labels=None):
        opts = self.opts
        noise = None
        distr = opts['pz']
        if z_dist is None:
            if distr == 'uniform':
                noise = np.random.uniform(-1, 1, [num, opts["zdim"]]).astype(
                    np.float32)
            elif distr in ('normal', 'sphere'):
                mean = np.zeros(opts["zdim"])
                cov = np.identity(opts["zdim"])
                noise = np.random.multivariate_normal(mean, cov,
                                                      num).astype(np.float32)
                if distr == 'sphere':
                    noise = noise / np.sqrt(np.sum(noise * noise,
                                                   axis=1))[:, np.newaxis]
            return opts['pz_scale'] * noise
        else:
            assert labels is not None
            means, covariances = z_dist
            noise = np.array([
                np.random.multivariate_normal(means[e], covariances[e])
                for e in labels
            ])
            return noise

    def pre_train(self, data):
        """Run a short (2-epoch) warm-up training pass over the data, then fit
        and return the per-class latent Gaussians via get_z_dist.

        NOTE(review): this runs `self.ae_opt` (which minimises
        `self.objective`) while fetching `self.objective_pre` as the reported
        loss — confirm the intended warm-up objective.
        """
        opts = self.opts
        batches_num = data.num_points // opts['batch_size']
        self.num_pics = opts['plot_num_pics']
        decay = 1.
        for epoch in range(2):
            # Update learning rate if necessary
            for it in range(batches_num):
                start_idx = it * opts['batch_size']
                end_idx = start_idx + opts['batch_size']
                batch_images = data.data[start_idx:end_idx].astype(np.float)
                batch_labels = data.labels[start_idx:end_idx]
                # No z_dist yet: sample from the configured prior.
                batch_noise = self.sample_pz(num=opts['batch_size'],
                                             labels=batch_labels)
                feed_d = {
                    self.sample_points: batch_images,
                    self.sample_noise: batch_noise,
                    self.labels: batch_labels,
                    self.lr_decay: decay,
                    self.is_training: True
                }
                [_, loss, loss_rec, loss_cls,
                 train_prob] = self.sess.run([
                     self.ae_opt, self.objective_pre, self.loss_recon,
                     self.loss_cls, self.probs1
                 ],
                                             feed_dict=feed_d)
        # Fit per-class latent Gaussians on the warmed-up encoder.
        z_dist = self.get_z_dist(data)
        return z_dist

    def pretrain_encoder(self, data):
        """Warm up the encoder alone by moment-matching q(z) to the prior.

        Runs up to 200 steps of `pretrain_opt` on random mini-batches
        (sampled without replacement within each step) and stops early once
        the pretrain loss drops below 0.1.
        """
        opts = self.opts
        steps_max = 200
        batch_size = opts['e_pretrain_sample_size']
        for step in range(steps_max):
            train_size = data.num_points
            data_ids = np.random.choice(train_size,
                                        min(train_size, batch_size),
                                        replace=False)
            batch_images = data.data[data_ids].astype(np.float)
            batch_labels = data.labels[data_ids].astype(np.int64)
            batch_noise = self.sample_pz(batch_size)

            [_, loss_pretrain] = self.sess.run(
                [self.pretrain_opt, self.loss_pretrain],
                feed_dict={
                    self.sample_points: batch_images,
                    self.labels: batch_labels,
                    self.sample_noise: batch_noise,
                    self.is_training: True
                })

            if opts['verbose']:
                logging.error('Step %d/%d, loss=%f' %
                              (step, steps_max, loss_pretrain))

            # Early exit once the encoder moments roughly match the prior.
            if loss_pretrain < 0.1:
                break

    def augment_batch(self, x, y):
        class_cnt = self.class_cnt

        max_class_cnt = max(class_cnt)
        n_classes = len(class_cnt)
        x_aug_list = []
        y_aug_list = []
        aug_rate = self.opts['aug_rate']
        if aug_rate <= 0:
            return x, y
        aug_nums = [
            aug_rate * (max_class_cnt - class_cnt[i]) for i in range(n_classes)
        ]
        rep_nums = [
            aug_num / class_cnt[i] for i, aug_num in enumerate(aug_nums)
        ]
        for i in range(n_classes):
            idx = (y == i)
            if rep_nums[i] <= 0.:
                x_aug_list.append(x[idx])
                y_aug_list.append(y[idx])
                continue
            n_c = np.count_nonzero(idx)
            if n_c == 0:
                continue
            x_aug_list.append(
                np.repeat(x[idx], repeats=math.ceil(1 + rep_nums[i]),
                          axis=0)[:math.floor(n_c * (1 + rep_nums[i]))])
            y_aug_list.append(
                np.repeat(y[idx], repeats=math.ceil(1 + rep_nums[i]),
                          axis=0)[:math.floor(n_c * (1 + rep_nums[i]))])
        if len(x_aug_list) == 0:
            return x, y
        x_aug = np.concatenate(x_aug_list, axis=0)
        y_aug = np.concatenate(y_aug_list, axis=0)
        return x_aug, y_aug

    def train(self, data):
        opts = self.opts
        self.class_cnt = [
            np.count_nonzero(data.labels == n)
            for n in range(opts['n_classes'])
        ]
        if opts['verbose']:
            logging.error(opts)
        losses = []
        losses_rec = []
        losses_match = []
        losses_cls = []

        batches_num = math.ceil(data.num_points / opts['batch_size'])
        self.num_pics = opts['plot_num_pics']
        self.sess.run(tf.global_variables_initializer())

        if opts['e_pretrain']:
            logging.error('Pretraining the encoder')
            self.pretrain_encoder(data)
            logging.error('Pretraining the encoder done.')

        self.start_time = time.time()
        counter = 0
        decay = 1.
        wait = 0
        z_dist = self.pre_train(data)
        for epoch in range(opts["epoch_num"]):
            # Update learning rate if necessary
            start_time = time.time()
            if opts['lr_schedule'] == "manual":
                if epoch == 30:
                    decay = decay / 2.
                if epoch == 50:
                    decay = decay / 5.
                if epoch == 100:
                    decay = decay / 10.
            elif opts['lr_schedule'] == "manual_smooth":
                enum = opts['epoch_num']
                decay_t = np.exp(np.log(100.) / enum)
                decay = decay / decay_t

            elif opts['lr_schedule'] != "plateau":
                assert type(opts['lr_schedule']) == float
                decay = 1.0 * 10**(-epoch / float(opts['lr_schedule']))

            # Save the model
            if epoch > 0 and epoch % opts['save_every_epoch'] == 0:
                self.saver.save(self.sess,
                                os.path.join(opts['work_dir'], 'checkpoints',
                                             'trained'),
                                global_step=counter)

            acc_total = 0.
            loss_total = 0.

            z_list = []
            y_list = []
            mu_list = []
            logsigma_list = []

            for it in tqdm(range(batches_num)):
                start_idx = it * opts['batch_size']
                end_idx = start_idx + opts['batch_size']
                batch_images = data.data[start_idx:end_idx]
                batch_labels = data.labels[start_idx:end_idx]
                orig_batch_labels = batch_labels
                orig_batch_images = batch_images
                if opts['augment_z'] is True:
                    batch_images, batch_labels = self.augment_batch(
                        batch_images, batch_labels)
                train_size = len(batch_labels)
                # print(train_size, len(orig_batch_labels))
                batch_noise = self.sample_pz(len(batch_images),
                                             z_dist=z_dist,
                                             labels=batch_labels)
                if opts['LVO'] is True:
                    _ = self.sess.run(self.lvo_opt,
                                      feed_dict={
                                          self.sample_points: batch_images,
                                          self.sample_noise: batch_noise,
                                          self.labels: batch_labels,
                                          self.lr_decay: decay,
                                          self.is_training: True
                                      })

                feed_d = {
                    self.sample_points: batch_images,
                    self.sample_noise: batch_noise,
                    self.labels: batch_labels,
                    self.lr_decay: decay,
                    self.is_training: True
                }

                (_, loss, loss_rec, loss_cls, loss_match, correct,
                 theta_p_final, u_p_final, lambda_p_final, mu,
                 logsigma) = self.sess.run([
                     self.ae_opt, self.objective, self.loss_recon,
                     self.loss_cls, self.loss_mixture, self.correct_sum,
                     self.theta_p, self.u_p, self.lambda_p, self.enc_mean,
                     self.enc_sigmas
                 ],
                                           feed_dict=feed_d)
                acc_total += correct / train_size

                loss_total += loss

                if opts['lr_schedule'] == "plateau":
                    if epoch >= 30:
                        if loss < min(losses[-20 * batches_num:]):
                            wait = 0
                        else:
                            wait += 1
                        if wait > 10 * batches_num:
                            decay = max(decay / 1.4, 1e-6)
                            logging.error('Reduction in lr: %f' % decay)
                            wait = 0

                feed_d = {
                    self.sample_points:
                    orig_batch_images,
                    # self.sample_noise: batch_noise,
                    # self.labels: batch_labels,
                    self.is_training:
                    False
                }

                z_final = self.sess.run(self.encoded, feed_dict=feed_d)

                # print('z_final',z_final.shape)

                losses.append(loss)
                losses_rec.append(loss_rec)
                losses_match.append(loss_match)
                losses_cls.append(loss_cls)

                counter += 1

                if epoch >= 0 and epoch % opts['save_every_epoch'] == 0:
                    z_list.append(z_final)
                    y_list.append(orig_batch_labels)
                    mu_list.append(mu)
                    logsigma_list.append(logsigma)
                    # train_prob_list.append(train_prob)

            if epoch >= 0 and epoch % opts['save_every_epoch'] == 0:
                mus = np.concatenate(mu_list, axis=0)
                logsigmas = np.concatenate(logsigma_list, axis=0)
                # print('epoch-calculating zs ys', epoch)
                zs = np.concatenate(z_list, axis=0)
                ys = np.concatenate(y_list, axis=0)
                self.result_logger.save_latent_code_new(
                    epoch, zs, ys, mus, logsigmas, theta_p_final, u_p_final,
                    lambda_p_final)

            # Print debug info
            now = time.time()
            # Auto-encoding test images
            [loss_rec_test, loss_cls_test] = self.sess.run(
                [self.loss_recon, self.loss_cls],
                feed_dict={
                    self.sample_points: data.test_data[:self.num_pics],
                    self.labels: data.test_labels[:self.num_pics],
                    self.is_training: False
                })

            debug_str = 'EPOCH: %d/%d, BATCH/SEC:%.2f' % (
                epoch + 1, opts['epoch_num'], float(counter) /
                (now - self.start_time))
            debug_str += ' (TOTAL_LOSS=%.5f, RECON_LOSS=%.5f, ' \
                         'MATCH_LOSS=%.5f, ' \
                         'CLS_LOSS=%.5f, ' \
                         'RECON_LOSS_TEST=%.5f, ' \
                         'CLS_LOSS_TEST=%.5f, ' % (
                             losses[-1], losses_rec[-1],
                             losses_match[-1], losses_cls[-1], loss_rec_test, loss_cls_test)
            logging.error(debug_str)

            training_acc = acc_total / batches_num
            avg_loss = loss_total / batches_num
            self.result_logger.add_training_metrics(avg_loss, training_acc,
                                                    time.time() - start_time)

            # if (self.opts['eval_strategy'] == 1 and (epoch + 1) % 5 == 0) or self.opts['eval_strategy'] == 2 and (
            #         (0 < epoch <= 20) or (epoch > 20 and epoch % 3 == 0)):
            self.evaluate(data, epoch)

            if epoch > 0:  # and epoch % 10 == 0:
                self.saver.save(self.sess,
                                os.path.join(opts['work_dir'], 'checkpoints',
                                             'trained-final'),
                                global_step=epoch)
            self.viz_img(data, epoch)

        self.result_logger.save_metrics()
        # For FID
        # self.augment_data((data.data, data.labels))

    def evaluate(self, data, epoch):
        """Run the eval head over the whole test split in small batches and
        log the per-sample predictions plus timing via the result logger."""
        eval_bs = self.opts['batch_size'] // 10
        n_batches = math.ceil(len(data.test_data) / eval_bs)
        prob_chunks = []
        t0 = time.time()
        for b in tqdm(range(n_batches)):
            lo = b * eval_bs
            hi = lo + eval_bs
            # tst_a/b/c are auxiliary debug tensors; fetched but unused here.
            prob, tst_a, tst_b, tst_c = self.sess.run(
                [self.eval_probs, self.test_a, self.test_b, self.test_c],
                feed_dict={
                    self.sample_points: data.test_data[lo:hi],
                    self.is_training: False
                })
            prob_chunks.append(prob)
        all_probs = np.concatenate(prob_chunks, axis=0)
        preds = np.argmax(all_probs, axis=-1)
        self.result_logger.save_prediction(epoch, data.test_labels, preds,
                                           all_probs,
                                           time.time() - t0)
        self.result_logger.save_metrics()

    def viz_img(self, data, epoch):
        """Reconstruct 50 randomly chosen samples per class and save the
        results as a single image grid (one class per grid row block)."""
        images = data.data
        labels = data.labels
        num_classes = self.opts['n_classes']
        bs = self.opts['batch_size']
        per_class_recons = []
        for cls in range(num_classes):
            cls_images = images[labels == cls]
            cls_labels = labels[labels == cls]
            # sampled with replacement, matching np.random.choice defaults
            picks = np.random.choice(len(cls_images), 50)
            sel_x = cls_images[picks]
            sel_y = cls_labels[picks]
            recon_parts = []
            n_batches = math.ceil(len(sel_x) / bs)
            for b in range(n_batches):
                lo = b * bs
                hi = lo + bs
                recon = self.sess.run(
                    self.reconstructed,
                    feed_dict={
                        self.sample_points: sel_x[lo:hi],
                        self.labels: sel_y[lo:hi],
                        self.is_training: False,
                    })
                recon_parts.append(recon)
            per_class_recons.append(np.concatenate(recon_parts, axis=0))
        grid = np.concatenate(per_class_recons, axis=0)

        import torch
        from torchvision.utils import save_image
        out_path = os.path.join(self.opts['work_dir'], "epoch%d.png" % epoch)
        # NHWC -> NCHW for torchvision
        save_image(torch.from_numpy(grid).permute(0, 3, 1, 2),
                   out_path,
                   nrow=num_classes,
                   padding=0)
# Example #5 (score: 0)
# File: caleg.py  Project: anonymous184/CaLeG
    def __init__(self, opts, tag):
        """Build the complete CaLeG graph: encoder + decoder/classifier
        auto-encoder, a WGAN-GP style discriminator, the per-component
        objectives and optimizers, plus a tf.Session and result logger.

        Args:
            opts: hyperparameter dict ('dataset', 'zdim', 'lambda', 'lr',
                'optimizer', 'e_pretrain', 'work_dir', ...).
            tag: experiment tag passed through to ResultLogger.
        """
        tf.reset_default_graph()
        # Let the GPU allocator grow on demand instead of grabbing all memory.
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        self.opts = opts
        self.tag = tag
        assert opts['dataset'] in datashapes, 'Unknown dataset.'
        shape = datashapes[opts['dataset']]

        # Placeholders
        self.sample_points = tf.placeholder(tf.float32, [None] + shape, name='real_points_ph')
        self.labels = tf.placeholder(tf.int32, shape=[None], name='label_ph')
        self.sample_noise = tf.placeholder(tf.float32, [None] + [opts['zdim']], name='noise_ph')
        # Target labels assigned to generated (noise-decoded) samples.
        self.fixed_sample_labels = tf.placeholder(tf.int32, shape=[None], name='fixed_sample_label_ph')
        self.lr_decay = tf.placeholder(tf.float32, name='rate_decay_ph')
        self.is_training = tf.placeholder(tf.bool, name='is_training_ph')

        # Ops
        self.encoded = encoder(opts, inputs=self.sample_points, is_training=self.is_training)
        # Decoder doubles as classifier: returns reconstructions and class logits.
        self.reconstructed, self.probs1 = decoder(opts, noise=self.encoded, is_training=self.is_training)
        self.prob1_softmaxed = tf.nn.softmax(self.probs1, axis=-1)
        # Number of correctly classified samples in the current batch.
        self.correct_sum = tf.reduce_sum(
            tf.cast(tf.equal(tf.argmax(self.prob1_softmaxed, axis=1, output_type=tf.int32), self.labels), tf.float32))
        # Decoding prior noise with shared weights yields generated samples.
        self.decoded, self.probs2 = decoder(opts, reuse=True, noise=self.sample_noise, is_training=self.is_training)

        # Discriminator heads: (class logits, Wasserstein distance estimate).
        self.De_pro_tilde_logits, self.De_pro_tilde_wdistance = self.discriminate(self.reconstructed)
        self.D_pro_logits, self.D_pro_logits_wdistance = self.discriminate(self.sample_points)
        self.G_pro_logits, self.G_pro_logits_wdistance = self.discriminate(self.decoded)
        # True where the discriminator assigns a generated sample its intended
        # class; train() uses this mask to keep "good" generated codes.
        self.predict_as_real_mask = tf.equal(tf.argmax(self.G_pro_logits, axis=1, output_type=tf.int32),
                                             self.fixed_sample_labels)

        # Objectives, losses, penalties
        self.loss_cls = self.cls_loss(self.labels, self.probs1)
        self.penalty = self.mmd_penalty(self.encoded)
        self.loss_reconstruct = self.reconstruction_loss(self.opts, self.sample_points, self.reconstructed)
        # Wasserstein critic loss over real vs. reconstructed/generated.
        self.wgan_d_loss = tf.reduce_mean(self.De_pro_tilde_wdistance) + tf.reduce_mean(
            self.G_pro_logits_wdistance) - 2 * tf.reduce_mean(self.D_pro_logits_wdistance)
        self.wgan_g_loss = -(tf.reduce_mean(self.De_pro_tilde_wdistance) + tf.reduce_mean(self.G_pro_logits_wdistance))
        # Gradient penalty averaged over both fake sources.
        self.wgan_d_penalty1 = self.gradient_penalty(self.sample_points, self.reconstructed)
        self.wgan_d_penalty2 = self.gradient_penalty(self.sample_points, self.decoded)
        self.wgan_d_penalty = 0.5 * (self.wgan_d_penalty1 + self.wgan_d_penalty2)
        #  G_additional loss
        self.G_fake_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.fixed_sample_labels,
                                                           logits=self.G_pro_logits))
        self.G_tilde_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels,
                                                           logits=self.De_pro_tilde_logits))
        #  D loss
        self.D_fake_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.fixed_sample_labels,
                                                           logits=self.G_pro_logits))
        self.D_real_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels,
                                                           logits=self.D_pro_logits))
        self.D_tilde_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels, logits=self.De_pro_tilde_logits))

        # Per-component objectives (note: encoder_objective == total_loss).
        self.encoder_objective = self.loss_reconstruct + opts['lambda'] * self.penalty + self.loss_cls
        self.decoder_objective = self.loss_reconstruct + self.G_fake_loss + self.G_tilde_loss + self.wgan_g_loss
        self.disc_objective = self.D_real_loss + self.D_fake_loss + \
                              self.D_tilde_loss + self.wgan_d_loss + self.wgan_d_penalty

        self.total_loss = self.loss_reconstruct + opts['lambda'] * self.penalty + self.loss_cls
        self.loss_pretrain = self.pretrain_loss() if opts['e_pretrain'] else None

        # Optimizers, savers, etc
        opts = self.opts
        lr = opts['lr']
        encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder')
        decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        discriminator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        optim = self.optimizer(lr, self.lr_decay)
        self.encoder_opt = optim.minimize(loss=self.encoder_objective, var_list=encoder_vars)
        self.decoder_opt = optim.minimize(loss=self.decoder_objective, var_list=decoder_vars)
        self.disc_opt = optim.minimize(loss=self.disc_objective, var_list=discriminator_vars)
        self.ae_opt = optim.minimize(loss=self.total_loss, var_list=encoder_vars + decoder_vars)
        self.pretrain_opt = self.optimizer(lr).minimize(loss=self.loss_pretrain,
                                                        var_list=encoder_vars) if opts['e_pretrain'] else None

        self.saver = tf.train.Saver(max_to_keep=10)
        # Expose key tensors by name for checkpoint/metagraph consumers.
        tf.add_to_collection('real_points_ph', self.sample_points)
        tf.add_to_collection('noise_ph', self.sample_noise)
        tf.add_to_collection('is_training_ph', self.is_training)
        tf.add_to_collection('encoder', self.encoded)
        tf.add_to_collection('decoder', self.decoded)

        self.init = tf.global_variables_initializer()
        self.result_logger = ResultLogger(tag, opts['work_dir'], verbose=True)
# Example #6 (score: 0)
# File: caleg.py  Project: anonymous184/CaLeG
class CaLeG(object):

    def __init__(self, opts, tag):
        """Build the complete CaLeG graph: encoder + decoder/classifier
        auto-encoder, a WGAN-GP style discriminator, the per-component
        objectives and optimizers, plus a tf.Session and result logger.

        Args:
            opts: hyperparameter dict ('dataset', 'zdim', 'lambda', 'lr',
                'optimizer', 'e_pretrain', 'work_dir', ...).
            tag: experiment tag passed through to ResultLogger.
        """
        tf.reset_default_graph()
        # Let the GPU allocator grow on demand instead of grabbing all memory.
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=config)
        self.opts = opts
        self.tag = tag
        assert opts['dataset'] in datashapes, 'Unknown dataset.'
        shape = datashapes[opts['dataset']]

        # Placeholders
        self.sample_points = tf.placeholder(tf.float32, [None] + shape, name='real_points_ph')
        self.labels = tf.placeholder(tf.int32, shape=[None], name='label_ph')
        self.sample_noise = tf.placeholder(tf.float32, [None] + [opts['zdim']], name='noise_ph')
        # Target labels assigned to generated (noise-decoded) samples.
        self.fixed_sample_labels = tf.placeholder(tf.int32, shape=[None], name='fixed_sample_label_ph')
        self.lr_decay = tf.placeholder(tf.float32, name='rate_decay_ph')
        self.is_training = tf.placeholder(tf.bool, name='is_training_ph')

        # Ops
        self.encoded = encoder(opts, inputs=self.sample_points, is_training=self.is_training)
        # Decoder doubles as classifier: returns reconstructions and class logits.
        self.reconstructed, self.probs1 = decoder(opts, noise=self.encoded, is_training=self.is_training)
        self.prob1_softmaxed = tf.nn.softmax(self.probs1, axis=-1)
        # Number of correctly classified samples in the current batch.
        self.correct_sum = tf.reduce_sum(
            tf.cast(tf.equal(tf.argmax(self.prob1_softmaxed, axis=1, output_type=tf.int32), self.labels), tf.float32))
        # Decoding prior noise with shared weights yields generated samples.
        self.decoded, self.probs2 = decoder(opts, reuse=True, noise=self.sample_noise, is_training=self.is_training)

        # Discriminator heads: (class logits, Wasserstein distance estimate).
        self.De_pro_tilde_logits, self.De_pro_tilde_wdistance = self.discriminate(self.reconstructed)
        self.D_pro_logits, self.D_pro_logits_wdistance = self.discriminate(self.sample_points)
        self.G_pro_logits, self.G_pro_logits_wdistance = self.discriminate(self.decoded)
        # True where the discriminator assigns a generated sample its intended
        # class; train() uses this mask to keep "good" generated codes.
        self.predict_as_real_mask = tf.equal(tf.argmax(self.G_pro_logits, axis=1, output_type=tf.int32),
                                             self.fixed_sample_labels)

        # Objectives, losses, penalties
        self.loss_cls = self.cls_loss(self.labels, self.probs1)
        self.penalty = self.mmd_penalty(self.encoded)
        self.loss_reconstruct = self.reconstruction_loss(self.opts, self.sample_points, self.reconstructed)
        # Wasserstein critic loss over real vs. reconstructed/generated.
        self.wgan_d_loss = tf.reduce_mean(self.De_pro_tilde_wdistance) + tf.reduce_mean(
            self.G_pro_logits_wdistance) - 2 * tf.reduce_mean(self.D_pro_logits_wdistance)
        self.wgan_g_loss = -(tf.reduce_mean(self.De_pro_tilde_wdistance) + tf.reduce_mean(self.G_pro_logits_wdistance))
        # Gradient penalty averaged over both fake sources.
        self.wgan_d_penalty1 = self.gradient_penalty(self.sample_points, self.reconstructed)
        self.wgan_d_penalty2 = self.gradient_penalty(self.sample_points, self.decoded)
        self.wgan_d_penalty = 0.5 * (self.wgan_d_penalty1 + self.wgan_d_penalty2)
        #  G_additional loss
        self.G_fake_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.fixed_sample_labels,
                                                           logits=self.G_pro_logits))
        self.G_tilde_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels,
                                                           logits=self.De_pro_tilde_logits))
        #  D loss
        self.D_fake_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.fixed_sample_labels,
                                                           logits=self.G_pro_logits))
        self.D_real_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels,
                                                           logits=self.D_pro_logits))
        self.D_tilde_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels, logits=self.De_pro_tilde_logits))

        # Per-component objectives (note: encoder_objective == total_loss).
        self.encoder_objective = self.loss_reconstruct + opts['lambda'] * self.penalty + self.loss_cls
        self.decoder_objective = self.loss_reconstruct + self.G_fake_loss + self.G_tilde_loss + self.wgan_g_loss
        self.disc_objective = self.D_real_loss + self.D_fake_loss + \
                              self.D_tilde_loss + self.wgan_d_loss + self.wgan_d_penalty

        self.total_loss = self.loss_reconstruct + opts['lambda'] * self.penalty + self.loss_cls
        self.loss_pretrain = self.pretrain_loss() if opts['e_pretrain'] else None

        # Optimizers, savers, etc
        opts = self.opts
        lr = opts['lr']
        encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder')
        decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        discriminator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        optim = self.optimizer(lr, self.lr_decay)
        self.encoder_opt = optim.minimize(loss=self.encoder_objective, var_list=encoder_vars)
        self.decoder_opt = optim.minimize(loss=self.decoder_objective, var_list=decoder_vars)
        self.disc_opt = optim.minimize(loss=self.disc_objective, var_list=discriminator_vars)
        self.ae_opt = optim.minimize(loss=self.total_loss, var_list=encoder_vars + decoder_vars)
        self.pretrain_opt = self.optimizer(lr).minimize(loss=self.loss_pretrain,
                                                        var_list=encoder_vars) if opts['e_pretrain'] else None

        self.saver = tf.train.Saver(max_to_keep=10)
        # Expose key tensors by name for checkpoint/metagraph consumers.
        tf.add_to_collection('real_points_ph', self.sample_points)
        tf.add_to_collection('noise_ph', self.sample_noise)
        tf.add_to_collection('is_training_ph', self.is_training)
        tf.add_to_collection('encoder', self.encoded)
        tf.add_to_collection('decoder', self.decoded)

        self.init = tf.global_variables_initializer()
        self.result_logger = ResultLogger(tag, opts['work_dir'], verbose=True)

    def discriminate(self, image):
        """Run the shared discriminator on `image`.

        Returns a (class logits, Wasserstein score) pair.
        """
        logits, wdistance = discriminator(self.opts, inputs=image,
                                          is_training=self.is_training)
        return logits, wdistance

    def pretrain_loss(self):
        """Moment-matching loss used to pre-train the encoder: squared error
        between the mean and the (unbiased) covariance of q(z) and the prior
        sample fed through `sample_noise`."""
        pz = self.sample_noise
        qz = self.encoded
        pz_mean = tf.reduce_mean(pz, axis=0, keepdims=True)
        qz_mean = tf.reduce_mean(qz, axis=0, keepdims=True)
        loss_of_means = tf.reduce_mean(tf.square(pz_mean - qz_mean))
        pz_centered = pz - pz_mean
        qz_centered = qz - qz_mean
        pz_cov = tf.matmul(pz_centered, pz_centered, transpose_a=True) / (
            tf.cast(tf.shape(pz)[0], tf.float32) - 1.)
        qz_cov = tf.matmul(qz_centered, qz_centered, transpose_a=True) / (
            tf.cast(tf.shape(qz)[0], tf.float32) - 1.)
        loss_of_covs = tf.reduce_mean(tf.square(pz_cov - qz_cov))
        return loss_of_means + loss_of_covs

    def cls_loss(self, labels, logits):
        # Sparse softmax cross-entropy over the classification logits.
        # NOTE(review): reduce_sum collapses the per-example losses to a scalar
        # first, so the outer reduce_mean is a no-op — this is the *sum* of the
        # batch losses (scales with batch size), not the mean. Confirm whether
        # that scaling is intended before changing it.
        return tf.reduce_mean(tf.reduce_sum(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)))

    def mmd_penalty(self, sample_qz, mean=0., std=1.):
        """Maximum Mean Discrepancy between the encoded batch `sample_qz` and a
        Gaussian prior N(mean, std^2), using a multi-scale inverse
        multiquadratic kernel C / (C + ||x - y||^2)."""
        opts = self.opts
        # Fresh prior sample with the same shape as the encoded batch.
        sample_pz = tf.random_normal(tf.shape(sample_qz), mean, std)
        sigma2_p = 1.
        n = utils.get_batch_size(sample_qz)
        n = tf.cast(n, tf.int32)
        nf = tf.cast(n, tf.float32)

        # Pairwise squared distances via ||a||^2 + ||b||^2 - 2<a, b>.
        norms_pz = tf.reduce_sum(tf.square(sample_pz), axis=1, keepdims=True)
        dotprods_pz = tf.matmul(sample_pz, sample_pz, transpose_b=True)
        distances_pz = norms_pz + tf.transpose(norms_pz) - 2. * dotprods_pz

        norms_qz = tf.reduce_sum(tf.square(sample_qz), axis=1, keepdims=True)
        dotprods_qz = tf.matmul(sample_qz, sample_qz, transpose_b=True)
        distances_qz = norms_qz + tf.transpose(norms_qz) - 2. * dotprods_qz

        # Cross distances between q(z) and p(z).
        dotprods = tf.matmul(sample_qz, sample_pz, transpose_b=True)
        distances = norms_qz + tf.transpose(norms_pz) - 2. * dotprods
        # Base kernel bandwidth, swept over several magnitudes below.
        Cbase = 2. * opts['zdim'] * sigma2_p
        loss_match = 0.
        for scale in [.1, .2, .5, 1., 2., 5., 10.]:
            C = Cbase * scale
            res1 = C / (C + distances_qz)
            res1 += C / (C + distances_pz)
            # Zero the diagonal: within-sample terms use distinct pairs only,
            # hence the (n^2 - n) normalizer.
            res1 = tf.multiply(res1, 1. - tf.eye(n))
            res1 = tf.reduce_sum(res1) / (nf * nf - nf)
            res2 = C / (C + distances)
            res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
            loss_match += res1 - res2
        return loss_match

    def gradient_penalty(self, real, generated):
        """WGAN-GP penalty: pushes the discriminator's gradient norm toward 1
        on random interpolations between real and generated samples."""
        shape = tf.shape(generated)[0]
        # When aug_rate != 1 the real and generated batch sizes can differ;
        # shuffle (and, for aug_rate < 1, subsample) the real batch so both
        # sides of the interpolation have `shape` rows — TODO confirm the
        # aug_rate > 1 branch assumes len(real) >= len(generated).
        if self.opts['aug_rate'] > 1.0:
            idxs = tf.range(shape)
            ridxs = tf.random_shuffle(idxs)
            real = tf.gather(real, ridxs)
        elif self.opts['aug_rate'] < 1.0:
            real_shape = tf.shape(real)[0]
            idxs = tf.range(real_shape)
            ridxs = tf.random_shuffle(idxs)[:shape]
            real = tf.gather(real, ridxs)
        # Per-sample uniform interpolation coefficient, broadcast over H, W, C.
        alpha = tf.random_uniform(shape=[shape, 1, 1, 1], minval=0., maxval=1.)
        # alpha = tf.random_uniform(shape=[self.opts['batch_size'], 1, 1, 1], minval=0., maxval=1.)
        differences = generated - real
        interpolates = real + (alpha * differences)
        # Gradient of the discriminator's Wasserstein head w.r.t. the inputs.
        gradients = \
            tf.gradients(discriminator(self.opts, interpolates, is_training=self.is_training)[1], [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3]))
        gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
        return gradient_penalty

    def reconstruction_loss(self, opts, real, reconstr):
        """Batch-mean reconstruction cost between `real` and `reconstr`;
        the norm and its scaling factor are selected by opts['cost']."""
        diff = real - reconstr
        cost = opts['cost']
        if cost == 'l2':
            # c(x,y) = ||x - y||_2
            per_sample = tf.reduce_sum(tf.square(diff), axis=[1, 2, 3])
            return 0.2 * tf.reduce_mean(tf.sqrt(1e-08 + per_sample))
        if cost == 'l2sq':
            # c(x,y) = ||x - y||_2^2
            per_sample = tf.reduce_sum(tf.square(diff), axis=[1, 2, 3])
            return 0.05 * tf.reduce_mean(per_sample)
        if cost == 'l1':
            # c(x,y) = ||x - y||_1
            per_sample = tf.reduce_sum(tf.abs(diff), axis=[1, 2, 3])
            return 0.02 * tf.reduce_mean(per_sample)
        assert False, 'Unknown cost function %s' % cost

    def optimizer(self, lr, decay=1.):
        """Build the configured optimizer with learning rate `lr * decay`."""
        effective_lr = lr * decay
        choice = self.opts["optimizer"]
        if choice == "sgd":
            return tf.train.GradientDescentOptimizer(effective_lr)
        if choice == "adam":
            return tf.train.AdamOptimizer(effective_lr,
                                          beta1=self.opts["adam_beta1"])
        assert False, 'Unknown optimizer.'

    def sample_pz(self, num=100, z_dist=None, labels=None, label_nums=None):
        opts = self.opts
        if z_dist is None:
            mean = np.zeros(opts["zdim"])
            cov = np.identity(opts["zdim"])
            noise = np.random.multivariate_normal(
                mean, cov, num).astype(np.float32)
            return noise
        assert labels is not None or label_nums is not None
        means, covariances = z_dist
        if labels is not None:
            return np.array([np.random.multivariate_normal(means[e], covariances[e]) for e in labels])
        noises = []
        for i, cnt in enumerate(label_nums):
            if cnt > 0:
                noises.append(np.random.multivariate_normal(means[i], covariances[i], cnt))
        return np.concatenate(noises, axis=0)

    def pretrain_encoder(self, data):
        """Pre-train the encoder alone so q(z) roughly matches the prior's
        first two moments before the main training starts.

        Runs at most 200 steps and stops early once the moment-matching
        loss (see pretrain_loss) drops below 0.1.
        """
        opts = self.opts
        steps_max = 200
        batch_size = opts['e_pretrain_sample_size']
        for step in range(steps_max):
            train_size = data.num_points
            data_ids = np.random.choice(train_size, min(train_size, batch_size),
                                        replace=False)
            # np.float was removed in NumPy 1.24; np.float64 is the type it
            # aliased, so this is behavior-identical on older NumPy too.
            batch_images = data.data[data_ids].astype(np.float64)
            # Draw exactly as many prior samples as real images so both sides
            # of the moment-matching loss use the same batch size (the old
            # code always drew `batch_size` noise vectors, which mismatched
            # whenever the dataset was smaller than the batch).
            batch_noise = self.sample_pz(len(batch_images))

            [_, loss_pretrain] = self.sess.run(
                [self.pretrain_opt, self.loss_pretrain],
                feed_dict={self.sample_points: batch_images,
                           self.sample_noise: batch_noise,
                           self.is_training: True})

            if opts['verbose']:
                logging.error('Step %d/%d, loss=%f' % (step, steps_max, loss_pretrain))
            if loss_pretrain < 0.1:
                break

    def adjust_decay(self, epoch, decay):
        opts = self.opts
        if opts['lr_schedule'] == "none":
            pass
        elif opts['lr_schedule'] == "manual":
            if epoch == 30:
                decay = decay / 2.
            if epoch == 50:
                decay = decay / 5.
            if epoch == 100:
                decay = decay / 10.
        elif opts['lr_schedule'] == "manual_smooth":
            enum = opts['epoch_num']
            decay_t = np.exp(np.log(100.) / enum)
            decay = decay / decay_t
        elif opts['lr_schedule'] != "plateau":
            assert type(opts['lr_schedule']) == float
            decay = 1.0 * 10 ** (-epoch / float(opts['lr_schedule']))
        return decay

    def train(self, data):
        """Main training loop over opts['epoch_num'] epochs.

        Until the first latent distribution (z_dist) is fitted, each batch is
        a plain auto-encoder update. Afterwards each batch performs one
        encoder, one decoder and one discriminator step, with generated
        samples conditioned on labels drawn from the biased class
        distribution. Every 5 epochs z_dist (per-class latent Gaussians) is
        re-fitted from the collected codes and visualized.
        """
        self.set_class_ratios(data)
        opts = self.opts
        if opts['verbose']:
            logging.error(opts)
        rec_losses, match_losses, encoder_losses, decoder_losses, disc_losses = [], [], [], [], []
        z_dist = None  # per-class (means, covariances); None until first fit
        batches_num = math.ceil(data.num_points / opts['batch_size'])
        self.sess.run(self.init)

        if opts['e_pretrain']:
            logging.error('Pretraining the encoder')
            self.pretrain_encoder(data)
            logging.error('Pretraining the encoder done.')

        start_time = time.time()
        decay = 1.
        for epoch in range(opts["epoch_num"]):
            print('Epoch %d:' % epoch)
            # Update learning rate if necessary
            decay = self.adjust_decay(epoch, decay)
            # Iterate over batches
            encoder_loss_total = 0.
            decoder_loss_total = 0.
            disc_loss_total = 0.
            correct_total = 0
            # refit the latent Gaussians every 5th epoch
            update_z_dist_flag = ((epoch + 1) % 5 == 0)
            z_list, z_new_list, y_new_list = [], [], []
            for it in tqdm(range(batches_num)):
                start_idx = it * opts['batch_size']
                end_idx = (it + 1) * opts['batch_size']
                batch_images = data.data[start_idx:end_idx].astype(np.float32)
                batch_labels = data.labels[start_idx:end_idx].astype(np.int32)
                if z_dist is None:
                    # Warm-up phase: pure auto-encoder update with prior noise.
                    batch_noise = self.sample_pz(len(batch_labels))
                    feed_d = {self.sample_points: batch_images, self.sample_noise: batch_noise,
                              self.labels: batch_labels, self.lr_decay: decay, self.is_training: True}
                    [_, z] = self.sess.run([self.ae_opt, self.encoded], feed_dict=feed_d)

                else:
                    # Adversarial phase: sample target labels from the biased
                    # distribution and latent codes from the fitted Gaussians.
                    batch_labels_new = self.biased_sample_labels(len(batch_labels))
                    batch_z = self.sample_pz(len(batch_labels_new), z_dist, batch_labels_new)
                    feed_d = {self.sample_points: batch_images, self.sample_noise: batch_z, self.labels: batch_labels,
                              self.fixed_sample_labels: batch_labels_new, self.lr_decay: decay, self.is_training: True}

                    # One update per component, all on the same feed.
                    [_, encoder_loss, rec_loss, match_loss, z, correct] = self.sess.run(
                        [self.encoder_opt,
                         self.encoder_objective,
                         self.loss_reconstruct,
                         self.penalty,
                         self.encoded, self.correct_sum
                         ],
                        feed_dict=feed_d)
                    [_, decoder_loss] = self.sess.run(
                        [self.decoder_opt,
                         self.decoder_objective,
                         ],
                        feed_dict=feed_d)
                    [_, disc_loss, mask] = self.sess.run(
                        [self.disc_opt,
                         self.disc_objective,
                         self.predict_as_real_mask
                         ],
                        feed_dict=feed_d)

                    rec_losses.append(rec_loss)
                    match_losses.append(match_loss)
                    encoder_losses.append(encoder_loss)
                    decoder_losses.append(decoder_loss)
                    disc_losses.append(disc_loss)
                    correct_total += correct
                    encoder_loss_total += encoder_loss
                    decoder_loss_total += decoder_loss
                    disc_loss_total += disc_loss

                    # Keep generated codes the discriminator classified as
                    # their target class; they augment the next z_dist fit.
                    if np.any(mask):
                        z_new_list.append(batch_z[mask])
                        y_new_list.append(batch_labels_new[mask])
                z_list.append(z)

            training_acc = correct_total / data.num_points
            avg_encoder_loss = encoder_loss_total / batches_num
            avg_decoder_loss = decoder_loss_total / batches_num
            avg_disc_loss = disc_loss_total / batches_num
            self.result_logger.add_training_metrics(avg_encoder_loss, avg_decoder_loss, avg_disc_loss,
                                                    training_acc, time.time() - start_time)
            z = np.concatenate(z_list, axis=0)
            self.result_logger.save_latent_code(epoch, z, data.labels)
            print('Evaluating...')
            self.evaluate(data, epoch)
            if update_z_dist_flag:
                if len(z_new_list) > 0:
                    z_new = np.concatenate(z_new_list, axis=0)
                    y_new = np.concatenate(y_new_list, axis=0)
                    print('Length of z_gen: %d' % len(z_new))
                else:
                    z_new = None
                    y_new = None
                # Refit per-class Gaussians, persist them, and visualize.
                z_dist = means, covariances = self.get_z_dist(z, data.labels, z_new, y_new)
                np.save(opts['work_dir'] + os.sep + 'means_epoch%02d.npy' % epoch, means)
                np.save(opts['work_dir'] + os.sep + 'covariances_epoch%02d.npy' % epoch, covariances)
                self.viz_img_from_z_dist(z_dist, epoch)

            if (epoch + 1) % opts['save_every_epoch'] == 0:
                print('Saving checkpoint...')
                self.saver.save(self.sess, os.path.join(opts['work_dir'], 'checkpoints', 'trained-caleg'),
                                global_step=epoch)

    def viz_img_from_z_dist(self, z_dist, epoch):
        """Decode 10 latent samples per class drawn from the fitted per-class
        Gaussians and save them as one image grid."""
        num_classes = self.opts['n_classes']
        sample_labels = np.repeat(
            np.arange(0, num_classes).reshape(num_classes, 1), 10)
        latent = self.sample_pz(len(sample_labels), z_dist, sample_labels)

        generated = self.sess.run(
            self.decoded,
            feed_dict={self.sample_noise: latent,
                       self.is_training: False})
        # regroup the flat batch into (class, samples-per-class, H, W, C)
        generated = generated.reshape(num_classes,
                                      generated.shape[0] // num_classes,
                                      generated.shape[-3],
                                      generated.shape[-2],
                                      generated.shape[-1])
        utils.save_image_array(
            generated,
            self.opts['work_dir'] + os.sep + "img_from_z_dist_{}.png".format(epoch))

    def get_z_dist(self, z, y, z_new=None, y_new=None):
        """Fit one Gaussian (mean, covariance) per class to latent codes.

        Class c is fitted on z[y == c], optionally augmented with generated
        codes z_new[y_new == c]. Returns (means, covariances) stacked as
        arrays indexed by class.
        """
        opts = self.opts
        print("Computing means and covariances...")
        if z_new is None:
            assert y_new is None
        all_means = []
        all_covs = []
        for cls in tqdm(range(opts['n_classes'])):
            codes = z[y == cls]
            if z_new is not None:
                codes = np.concatenate([codes, z_new[y_new == cls]], axis=0)
            all_covs.append(np.cov(np.transpose(codes)))
            all_means.append(np.mean(codes, axis=0))

        return np.array(all_means), np.array(all_covs)

    def evaluate(self, data, epoch):
        """Run the classifier over the whole test set in batches and log metrics.

        data:  dataset object exposing test_data and test_labels.
        epoch: current epoch index, forwarded to the result logger.
        """
        batch_size = self.opts['batch_size']
        n_batches = int(math.ceil(len(data.test_data) / batch_size))

        t0 = time.time()
        batch_probs = []
        for b in range(n_batches):
            lo = b * batch_size
            batch = data.test_data[lo:lo + batch_size]
            batch_probs.append(self.sess.run(
                self.probs1,
                feed_dict={self.sample_points: batch, self.is_training: False}))

        probs = np.concatenate(batch_probs, axis=0)
        assert probs.shape[1] == self.opts['n_classes']
        predicts = np.argmax(probs, axis=-1)
        self.result_logger.save_prediction(
            epoch, data.test_labels, predicts, probs, time.time() - t0)
        self.result_logger.save_metrics()

    def set_class_ratios(self, data):
        """Compute per-class sampling ratios from the training-label counts.

        data: dataset object exposing `labels` (integer class labels).

        Side effects — sets on self:
          gratio_mode / dratio_mode: copied from opts.
          aug_class_rate: fraction of augmentation budget per class,
                          proportional to each class's deficit vs the majority.
          class_aratio:   empirical (actual) class distribution, as a list.
          class_uratio:   uniform distribution.
          class_gratio / class_dratio: generator / discriminator sampling
                          distributions, either uniform or rebalanced.
        """
        self.gratio_mode = self.opts['gratio_mode']
        self.dratio_mode = self.opts['dratio_mode']
        n_classes = self.opts['n_classes']

        class_count = [np.count_nonzero(data.labels == n) for n in range(n_classes)]
        class_cnt = np.array(class_count)
        max_class_cnt = np.max(class_cnt)
        # How many synthetic samples each class needs to reach the majority count.
        total_aug_nums = max_class_cnt - class_cnt
        aug_total = np.sum(total_aug_nums)
        if aug_total > 0:
            self.aug_class_rate = total_aug_nums / aug_total
        else:
            # Perfectly balanced dataset: avoid 0/0 -> NaN; fall back to uniform.
            self.aug_class_rate = np.full(n_classes, 1.0 / n_classes)

        # Actual (empirical) class distribution.
        self.class_aratio = [per_count / sum(class_count) for per_count in class_count]

        # Uniform target distribution.
        target = 1 / n_classes
        self.class_uratio = np.full(n_classes, target)

        # Generator sampling distribution.
        self.class_gratio = np.full(n_classes, 0.0)
        for c in range(n_classes):
            if self.gratio_mode == "uniform":
                self.class_gratio[c] = target
            elif self.gratio_mode == "rebalance":
                # Over-sample minority classes by mirroring around the uniform target.
                self.class_gratio[c] = 2 * target - self.class_aratio[c]
            else:
                print("Error while training bgan, unknown gmode " + self.gratio_mode)
                exit(1)  # nonzero status: this is a configuration error

        # Discriminator sampling distribution.
        self.class_dratio = np.full(n_classes, 0.0)
        for c in range(n_classes):
            if self.dratio_mode == "uniform":
                self.class_dratio[c] = target
            elif self.dratio_mode == "rebalance":
                self.class_dratio[c] = 2 * target - self.class_aratio[c]
            else:
                print("Error while training bgan, unknown dmode " + self.dratio_mode)
                exit(1)  # nonzero status: this is a configuration error

        # If very unbalanced, the rebalanced gratio can go negative for some
        # classes; clamp to zero and renormalize.
        if self.gratio_mode == "rebalance":
            self.class_gratio[self.class_gratio < 0] = 0
            self.class_gratio = self.class_gratio / sum(self.class_gratio)

        # Same adjustment for the discriminator ratios.
        if self.dratio_mode == "rebalance":
            self.class_dratio[self.class_dratio < 0] = 0
            self.class_dratio = self.class_dratio / sum(self.class_dratio)

    def biased_sample_labels_old(self, num_samples, target_distribution="d"):
        """Draw `num_samples` class labels from one of the stored distributions.

        target_distribution: "d" -> class_dratio, "g" -> class_gratio,
                             anything else -> class_uratio (uniform).

        Returns an int64 array of sampled class labels.
        """
        if target_distribution == "d":
            dist = self.class_dratio
        elif target_distribution == "g":
            dist = self.class_gratio
        else:
            dist = self.class_uratio

        labels = np.zeros(num_samples, dtype=np.int64)
        remaining = np.random.uniform(0, 1, num_samples)
        # Inverse-CDF sampling: walk the classes in order, peeling each class's
        # probability mass off the uniform draws; a draw lands in class i when
        # its remaining mass first falls inside (0, dist[i]].
        for cls in range(self.opts['n_classes']):
            hit = (remaining > 0) & (remaining <= dist[cls])
            labels[hit] = cls
            remaining = remaining - dist[cls]

        return labels

    def biased_sample_labels(self, num_samples):
        """Build augmentation labels proportional to each class's deficit.

        num_samples is scaled by opts['aug_rate'], then split across classes
        according to self.aug_class_rate (rounded per class).

        Returns an int32 array of class labels, grouped in ascending class
        order. When no augmentation is requested — or rounding zeroes out
        every class's share — returns a single majority-class label instead
        of an empty array, so downstream batch assembly never sees an empty
        batch (the original code only guarded the num_samples == 0 case).
        """
        num_samples = math.ceil(num_samples * self.opts['aug_rate'])
        fallback = np.full([1], self.opts['n_classes'] - 1, dtype=np.int32)
        if num_samples == 0:
            return fallback
        aug_num = np.round(self.aug_class_rate * num_samples).astype(np.int32)
        if np.sum(aug_num) == 0:
            # All per-class shares rounded to zero; use the same sentinel as
            # the num_samples == 0 path rather than returning an empty array.
            return fallback
        # Label i repeated aug_num[i] times, e.g. [0, 0, 1, 1, 1, ...].
        return np.repeat(np.arange(self.opts['n_classes'], dtype=np.int32), aug_num)

    def load_ckpt(self, epoch):
        """Restore model weights from the checkpoint saved at the given epoch."""
        ckpt_path = os.path.join(self.opts['work_dir'], 'checkpoints',
                                 'trained-caleg-%d' % epoch)
        self.saver.restore(self.sess, ckpt_path)