Example #1
    def test_epoch(self):
        self.logger.warn("Testing evaluation...")
        scores_im1 = []
        scores_im2 = []
        scores_z1 = []
        scores_z2 = []
        inference_time = []
        true_labels = []
        summaries = []
        # Create the scores
        test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
        cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
        for _ in test_loop:
            test_batch_begin = time()
            test_batch, test_labels = self.sess.run(
                [self.data.test_image, self.data.test_label])
            test_loop.refresh()  # show the progress-bar update immediately
            sleep(0.01)
            noise = np.random.normal(
                loc=0.0,
                scale=1.0,
                size=[self.config.data_loader.test_batch, self.noise_dim])
            feed_dict = {
                self.model.image_input: test_batch,
                self.model.noise_tensor: noise,
                self.model.is_training: False,
            }
            scores_im1 += self.sess.run(self.model.img_score_l1,
                                        feed_dict=feed_dict).tolist()
            scores_im2 += self.sess.run(self.model.img_score_l2,
                                        feed_dict=feed_dict).tolist()
            scores_z1 += self.sess.run(self.model.z_score_l1,
                                       feed_dict=feed_dict).tolist()
            scores_z2 += self.sess.run(self.model.z_score_l2,
                                       feed_dict=feed_dict).tolist()
            summaries += self.sess.run([self.model.sum_op_im_test],
                                       feed_dict=feed_dict)
            inference_time.append(time() - test_batch_begin)
            true_labels += test_labels.tolist()
        self.summarizer.add_tensorboard(step=cur_epoch,
                                        summaries=summaries,
                                        summarizer="test")
        scores_im1 = np.asarray(scores_im1)
        scores_im2 = np.asarray(scores_im2)
        scores_z1 = np.asarray(scores_z1)
        scores_z2 = np.asarray(scores_z2)

        true_labels = np.asarray(true_labels)
        inference_time = np.mean(inference_time)
        self.logger.info(
            "Testing: Mean inference time is {:.4f}".format(inference_time))
        step = self.sess.run(self.model.global_step_tensor)
        percentiles = np.asarray(self.config.trainer.percentiles)
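        # Each score variant (im1, im2, z1, z2) is saved separately below with
        # identical experiment metadata so the variants can be compared later.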
        save_results(
            self.config.log.result_dir,
            scores_im1,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "im1",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
        save_results(
            self.config.log.result_dir,
            scores_im2,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "im2",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
        save_results(
            self.config.log.result_dir,
            scores_z1,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "z1",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
        save_results(
            self.config.log.result_dir,
            scores_z2,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "z2",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
Example #2
def train_and_test(dataset, nb_epochs, weight, degree, random_seed, label,
                   enable_sm, score_method, enable_early_stop):
    """ Runs the AnoGAN on the specified dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (str): name of the dataset
        nb_epochs (int): number of epochs
        weight (float): weight in the inverting loss function
        degree (int): degree of the norm in the feature matching
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
        enable_sm (bool): allow TF summaries for monitoring the training
        score_method (str): which metric to use for the ablation study
        enable_early_stop (bool): allow early stopping for determining the number of epochs
    """

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    logger = logging.getLogger("AnoGAN.run.{}.{}".format(dataset, label))

    # Import model and data
    network = importlib.import_module('anogan.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.999

    # Data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    if enable_early_stop:
        validx, validy = data.get_valid(label)
        nr_batches_valid = int(validx.shape[0] / batch_size)
    trainx_copy = trainx.copy()
    testx, testy = data.get_test(label)

    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Placeholders
    x_pl = tf.placeholder(tf.float32,
                          shape=(None, trainx.shape[1]),
                          name="input")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    rng = np.random.RandomState(random_seed)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info('Building graph...')
    logger.warn("The GAN is training with the following parameters:")
    display_parameters(batch_size, starting_lr, ema_decay, weight, degree,
                       label, trainx.shape[1])

    gen = network.generator
    dis = network.discriminator

    # Sample noise from random normal distribution
    random_z = tf.random_normal([batch_size, latent_dim],
                                mean=0.0,
                                stddev=1.0,
                                name='random_z')
    # Generate images with generator
    x_gen = gen(x_pl, random_z, is_training=is_training_pl)

    real_d, inter_layer_real = dis(x_pl, is_training=is_training_pl)
    fake_d, inter_layer_fake = dis(x_gen,
                                   is_training=is_training_pl,
                                   reuse=True)

    with tf.name_scope('loss_functions'):
        # Calculate separate losses for discriminator with real and fake images
        real_discriminator_loss = tf.losses.sigmoid_cross_entropy(
            tf.ones_like(real_d), real_d, scope='real_discriminator_loss')
        fake_discriminator_loss = tf.losses.sigmoid_cross_entropy(
            tf.zeros_like(fake_d), fake_d, scope='fake_discriminator_loss')
        # Add discriminator losses
        loss_discriminator = real_discriminator_loss + fake_discriminator_loss
        # Calculate loss for generator by flipping label on discriminator output
        loss_generator = tf.losses.sigmoid_cross_entropy(
            tf.ones_like(fake_d), fake_d, scope='generator_loss')

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='discriminator')
        gvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='generator')

        update_ops_gen = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='generator')
        update_ops_dis = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='discriminator')

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.5)

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer.minimize(loss_generator,
                                        var_list=gvars,
                                        global_step=global_step)

        with tf.control_dependencies(update_ops_dis):
            dis_op = optimizer.minimize(loss_discriminator, var_list=dvars)

        # Exponential Moving Average for inference
        def train_op_with_ema_dependency(vars, op):
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            maintain_averages_op = ema.apply(vars)
            with tf.control_dependencies([op]):
                train_op = tf.group(maintain_averages_op)
            return train_op, ema

        train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op)
        train_dis_op, dis_ema = train_op_with_ema_dependency(dvars, dis_op)
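
        # gen_ema / dis_ema track exponential moving averages of the generator
        # and discriminator weights; the testing graph below reuses the layers
        # with getter=get_getter(...) so the averaged weights are used for
        # inference where they exist.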

    ### Testing ###
    with tf.variable_scope("latent_variable"):
        z_optim = tf.get_variable(
            name='z_optim',
            shape=[batch_size, latent_dim],
            initializer=tf.truncated_normal_initializer())
        reinit_z = z_optim.initializer

    # EMA
    x_gen_ema = gen(x_pl,
                    random_z,
                    is_training=is_training_pl,
                    getter=get_getter(gen_ema),
                    reuse=True)
    rec_x_ema = gen(x_pl,
                    z_optim,
                    is_training=is_training_pl,
                    getter=get_getter(gen_ema),
                    reuse=True)
    # Pass real and fake images into discriminator separately
    real_d_ema, inter_layer_real_ema = dis(x_pl,
                                           is_training=is_training_pl,
                                           getter=get_getter(gen_ema),
                                           reuse=True)
    fake_d_ema, inter_layer_fake_ema = dis(rec_x_ema,
                                           is_training=is_training_pl,
                                           getter=get_getter(gen_ema),
                                           reuse=True)

    with tf.name_scope('Testing'):
        with tf.variable_scope('Reconstruction_loss'):
            delta = x_pl - rec_x_ema
            delta_flat = tf.contrib.layers.flatten(delta)
            reconstruction_score = tf.norm(delta_flat,
                                           ord=degree,
                                           axis=1,
                                           keep_dims=False,
                                           name='epsilon')

        with tf.variable_scope('Discriminator_scores'):

            if score_method == 'cross-e':
                dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(fake_d_ema), logits=fake_d_ema)

            else:
                fm = inter_layer_real_ema - inter_layer_fake_ema
                fm = tf.contrib.layers.flatten(fm)
                dis_score = tf.norm(fm,
                                    ord=degree,
                                    axis=1,
                                    keep_dims=False,
                                    name='d_loss')

            dis_score = tf.squeeze(dis_score)

        with tf.variable_scope('Score'):
            loss_invert = weight * reconstruction_score \
                                  + (1 - weight) * dis_score
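            # loss_invert is the per-sample anomaly score: a convex combination
            # of the reconstruction error and the discriminator-based score,
            # controlled by the `weight` argument.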

    rec_error_valid = tf.reduce_mean(loss_invert)

    with tf.variable_scope("Test_learning_rate"):
        step_lr = tf.Variable(0, trainable=False)
        learning_rate_invert = 0.001
        reinit_lr = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope="Test_learning_rate"))

    with tf.name_scope('Test_optimizer'):
        invert_op = tf.train.AdamOptimizer(learning_rate_invert).\
            minimize(loss_invert,global_step=step_lr, var_list=[z_optim],
                     name='optimizer')
        reinit_optim = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope='Test_optimizer'))

    reinit_test_graph_op = [reinit_z, reinit_lr, reinit_optim]
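    # Running these re-initializers resets z_optim, the test step counter and
    # the inversion optimizer's state, so each evaluated batch starts its
    # STEPS_NUMBER latent-optimization steps from a fresh initialization.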

    with tf.name_scope("Scores"):
        list_scores = loss_invert

    if enable_sm:
        with tf.name_scope('training_summary'):
            with tf.name_scope('dis_summary'):
                tf.summary.scalar('real_discriminator_loss',
                                  real_discriminator_loss, ['dis'])
                tf.summary.scalar('fake_discriminator_loss',
                                  fake_discriminator_loss, ['dis'])
                tf.summary.scalar('discriminator_loss', loss_discriminator,
                                  ['dis'])

            with tf.name_scope('gen_summary'):
                tf.summary.scalar('loss_generator', loss_generator, ['gen'])

            with tf.name_scope('img_summary'):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image('heatmap_latent',
                                                 heatmap_pl_latent)

            with tf.name_scope('validation_summary'):
                tf.summary.scalar('valid', rec_error_valid, ['v'])

            if dataset in IMAGES_DATASETS:
                with tf.name_scope('image_summary'):
                    tf.summary.image('reconstruct', x_gen, 8, ['image'])
                    tf.summary.image('input_images', x_pl, 8, ['image'])

            else:
                heatmap_pl_rec = tf.placeholder(tf.float32,
                                                shape=(1, 480, 640, 3),
                                                name="heatmap_pl_rec")
                with tf.name_scope('image_summary'):
                    tf.summary.image('heatmap_rec', heatmap_pl_rec, 1,
                                     ['image'])

            sum_op_dis = tf.summary.merge_all('dis')
            sum_op_gen = tf.summary.merge_all('gen')
            sum_op = tf.summary.merge([sum_op_dis, sum_op_gen])
            sum_op_im = tf.summary.merge_all('image')
            sum_op_valid = tf.summary.merge_all('v')

    logdir = create_logdir(dataset, weight, label, random_seed)

    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=None,
                             save_model_secs=None)

    logger.info('Start training...')
    with sv.managed_session(config=config) as sess:

        logger.info('Initialization done')

        writer = tf.summary.FileWriter(logdir, sess.graph)

        train_batch = 0
        epoch = 0
        best_valid_loss = 0

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

            trainx = trainx[rng.permutation(
                trainx.shape[0])]  # shuffling unl dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]

            train_loss_dis, train_loss_gen = [0, 0]
            # training
            for t in range(nr_batches_train):
                display_progression_epoch(t, nr_batches_train)

                # construct randomly permuted minibatches
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train discriminator
                feed_dict = {
                    x_pl: trainx[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, ld, step = sess.run(
                    [train_dis_op, loss_discriminator, global_step],
                    feed_dict=feed_dict)
                train_loss_dis += ld

                # train generator
                feed_dict = {
                    x_pl: trainx_copy[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, lg = sess.run([train_gen_op, loss_generator],
                                 feed_dict=feed_dict)
                train_loss_gen += lg

                if enable_sm:
                    sm = sess.run(sum_op, feed_dict=feed_dict)
                    writer.add_summary(sm, step)

                    if t % FREQ_PRINT == 0:  # inspect reconstruction
                        # t = np.random.randint(0,400)
                        # ran_from = t
                        # ran_to = t + batch_size
                        # sm = sess.run(sum_op_im, feed_dict={x_pl: trainx[ran_from:ran_to],is_training_pl: False})
                        # writer.add_summary(sm, train_batch)

                        # data = sess.run(z_gen, feed_dict={
                        #     x_pl: trainx[ran_from:ran_to],
                        #     is_training_pl: False})
                        # data = np.expand_dims(heatmap(data), axis=0)
                        # sml = sess.run(sum_op_latent, feed_dict={
                        #     heatmap_pl_latent: data,
                        #     is_training_pl: False})
                        #
                        # writer.add_summary(sml, train_batch)

                        if dataset in IMAGES_DATASETS:
                            sm = sess.run(sum_op_im,
                                          feed_dict={
                                              x_pl: trainx[ran_from:ran_to],
                                              is_training_pl: False
                                          })
                            #
                            # else:
                            #     data = sess.run(z_gen, feed_dict={
                            #         x_pl: trainx[ran_from:ran_to],
                            #         z_pl: np.random.normal(
                            #             size=[batch_size, latent_dim]),
                            #         is_training_pl: False})
                            #     data = np.expand_dims(heatmap(data), axis=0)
                            #     sm = sess.run(sum_op_im, feed_dict={
                            #         heatmap_pl_rec: data,
                            #         is_training_pl: False})
                            writer.add_summary(sm, step)  #train_batch)
                train_batch += 1

            train_loss_gen /= nr_batches_train
            train_loss_dis /= nr_batches_train

            # logger.info('Epoch terminated')
            # print("Epoch %d | time = %ds | loss gen = %.4f | loss dis = %.4f "
            #       % (epoch, time.time() - begin, train_loss_gen, train_loss_dis))

            ##EARLY STOPPING
            if (epoch + 1) % FREQ_EV == 0 and enable_early_stop:
                logger.info('Validation...')

                inds = rng.permutation(validx.shape[0])
                validx = validx[inds]  # shuffling  dataset
                validy = validy[inds]  # shuffling  dataset

                valid_loss = 0

                # Create scores
                for t in range(nr_batches_valid):
                    # construct randomly permuted minibatches
                    display_progression_epoch(t, nr_batches_valid)
                    ran_from = t * batch_size
                    ran_to = (t + 1) * batch_size

                    feed_dict = {
                        x_pl: validx[ran_from:ran_to],
                        is_training_pl: False
                    }
                    for _ in range(STEPS_NUMBER):
                        _ = sess.run(invert_op, feed_dict=feed_dict)
                    vl = sess.run(rec_error_valid, feed_dict=feed_dict)
                    valid_loss += vl
                    sess.run(reinit_test_graph_op)

                valid_loss /= nr_batches_valid
                sess.run(reinit_test_graph_op)

                if enable_sm:
                    sm = sess.run(sum_op_valid, feed_dict=feed_dict)
                    writer.add_summary(sm, step)  # train_batch)

                logger.info('Validation: valid loss {:.4f}'.format(valid_loss))

                if valid_loss < best_valid_loss or epoch == FREQ_EV - 1:

                    best_valid_loss = valid_loss
                    logger.info(
                        "Best model - valid loss = {:.4f} - saving...".format(
                            best_valid_loss))
                    sv.saver.save(sess,
                                  logdir + '/model.ckpt',
                                  global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                        "Early stopping at epoch {} with weights from epoch {}"
                        .format(epoch, epoch - nb_without_improvements))

            epoch += 1

        logger.warn('Testing evaluation...')
        step = sess.run(global_step)
        sv.saver.save(sess, logdir + '/model.ckpt', global_step=step)

        rect_x, rec_error, latent, scores = [], [], [], []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):
            # construct randomly permuted minibatches
            display_progression_epoch(t, nr_batches_test)
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False}

            for _ in range(STEPS_NUMBER):
                _ = sess.run(invert_op, feed_dict=feed_dict)

            brect_x, brec_error, bscores, blatent = sess.run(
                [rec_x_ema, reconstruction_score, loss_invert, z_optim],
                feed_dict=feed_dict)
            rect_x.append(brect_x)
            rec_error.append(brec_error)
            scores.append(bscores)
            latent.append(blatent)
            sess.run(reinit_test_graph_op)

            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

        if testx.shape[0] % batch_size != 0:
            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch, is_training_pl: False}
            for _ in range(STEPS_NUMBER):
                _ = sess.run(invert_op, feed_dict=feed_dict)
            brect_x, brec_error, bscores, blatent = sess.run(
                [rec_x_ema, reconstruction_score, loss_invert, z_optim],
                feed_dict=feed_dict)
            rect_x.append(brect_x[:size])
            rec_error.append(brec_error[:size])
            scores.append(bscores[:size])
            latent.append(blatent[:size])
            sess.run(reinit_test_graph_op)

        rect_x = np.concatenate(rect_x, axis=0)
        rec_error = np.concatenate(rec_error, axis=0)
        scores = np.concatenate(scores, axis=0)
        latent = np.concatenate(latent, axis=0)
        save_results(scores, testy, 'anogan', dataset, score_method, weight,
                     label, random_seed)
Example #3
    def test_epoch(self):
        self.logger.warn("Testing evaluation...")
        scores_im1 = []
        scores_im2 = []
        scores_comb_im = []
        scores_comb_z = []
        scores_final_1 = []
        scores_final_2 = []
        scores_final_3 = []
        summaries = []
        if self.config.trainer.enable_disc_xx:
            scores_final_4 = []
        if self.config.trainer.enable_disc_zz:
            scores_final_5 = []
            scores_final_6 = []
        inference_time = []
        true_labels = []
        # Create the scores
        test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
        cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
        factor_list = self.config.trainer.feature_match_weight_2
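        # Sweep over the candidate secondary feature-matching weights: scores
        # are recomputed for every factor f and saved with a "_2_{f}" postfix
        # so the resulting files can be told apart.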
        for f in factor_list:
            scores_im1 = []
            scores_im2 = []
            scores_comb_im = []
            scores_comb_z = []
            scores_final_1 = []
            scores_final_2 = []
            scores_final_3 = []
            inference_time = []
            true_labels = []
            for _ in test_loop:
                test_batch_begin = time()
                test_batch, test_labels = self.sess.run([self.data.test_image, self.data.test_label])
                test_loop.refresh()  # show the progress-bar update immediately
                sleep(0.01)
                noise = np.random.normal(
                    loc=0.0, scale=1.0, size=[self.config.data_loader.test_batch, self.noise_dim]
                )
                feature_match1 = f
                feature_match2 = self.config.trainer.feature_match_weight
                feed_dict = {
                    self.model.image_input: test_batch,
                    self.model.noise_tensor: noise,
                    self.model.is_training_gen: False,
                    self.model.is_training_dis: False,
                    self.model.is_training_enc_g: False,
                    self.model.is_training_enc_r: False,
                    self.model.feature_match1: feature_match1,
                    self.model.feature_match2: feature_match2,
                }
                scores_im1 += self.sess.run(self.model.img_score_l1, feed_dict=feed_dict).tolist()
                scores_im2 += self.sess.run(self.model.img_score_l2, feed_dict=feed_dict).tolist()
                scores_comb_im += self.sess.run(self.model.score_comb_im, feed_dict=feed_dict).tolist()
                scores_comb_z += self.sess.run(self.model.score_comb_z, feed_dict=feed_dict).tolist()
                scores_final_1 += self.sess.run(self.model.final_score_1, feed_dict=feed_dict).tolist()
                scores_final_2 += self.sess.run(self.model.final_score_2, feed_dict=feed_dict).tolist()
                scores_final_3 += self.sess.run(self.model.final_score_3, feed_dict=feed_dict).tolist()
                summaries += self.sess.run([self.model.sum_op_im_test], feed_dict=feed_dict)
                if self.config.trainer.enable_disc_xx:
                    scores_final_4 += self.sess.run(
                        self.model.final_score_4, feed_dict=feed_dict
                    ).tolist()
                if self.config.trainer.enable_disc_zz:
                    # scores_final_5 += self.sess.run(
                    #     self.model.final_score_5, feed_dict=feed_dict
                    # ).tolist()
                    scores_final_6 += self.sess.run(
                        self.model.final_score_6, feed_dict=feed_dict
                    ).tolist()
                inference_time.append(time() - test_batch_begin)
                true_labels += test_labels.tolist()
            self.summarizer.add_tensorboard(step=cur_epoch, summaries=summaries, summarizer="test")
            scores_im1 = np.asarray(scores_im1)
            scores_im2 = np.asarray(scores_im2)
            scores_comb_im = np.asarray(scores_comb_im)
            scores_comb_z = np.asarray(scores_comb_z)
            scores_final_1 = np.asarray(scores_final_1)
            scores_final_2 = np.asarray(scores_final_2)
            scores_final_3 = np.asarray(scores_final_3)
            if self.config.trainer.enable_disc_xx:
                scores_final_4 = np.asarray(scores_final_4)
            if self.config.trainer.enable_disc_zz:
                # scores_final_5 = np.asarray(scores_final_5)
                scores_final_6 = np.asarray(scores_final_6)
            true_labels = np.asarray(true_labels)
            inference_time = np.mean(inference_time)
            self.logger.info("Testing: Mean inference time is {:.4f}".format(inference_time))
            step = self.sess.run(self.model.global_step_tensor)
            percentiles = np.asarray(self.config.trainer.percentiles)
            postfix = "_2_{}".format(str(f))
            save_results(
                self.config.log.result_dir,
                scores_im1,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "im1{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_im2,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "im2{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_comb_im,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "comb_im{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_comb_z,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "comb_z{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_final_1,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "final_1{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_final_2,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "final_2{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            save_results(
                self.config.log.result_dir,
                scores_final_3,
                true_labels,
                self.config.model.name,
                self.config.data_loader.dataset_name,
                "final_3{}".format(postfix),
                "paper",
                self.config.trainer.label,
                self.config.data_loader.random_seed,
                self.logger,
                step,
                percentile=percentiles,
                postfix=postfix
            )
            if self.config.trainer.enable_disc_xx:

                save_results(
                    self.config.log.result_dir,
                    scores_final_4,
                    true_labels,
                    self.config.model.name,
                    self.config.data_loader.dataset_name,
                    "final_4",
                    "paper",
                    self.config.trainer.label,
                    self.config.data_loader.random_seed,
                    self.logger,
                    step,
                    percentile=percentiles,
                )
            if self.config.trainer.enable_disc_zz:
                # save_results(
                #     self.config.log.result_dir,
                #     scores_final_5,
                #     true_labels,
                #     self.config.model.name,
                #     self.config.data_loader.dataset_name,
                #     "final_5",
                #     "paper",
                #     self.config.trainer.label,
                #     self.config.data_loader.random_seed,
                #     self.logger,
                #     step,
                #     percentile=percentiles,
                # )
                save_results(
                    self.config.log.result_dir,
                    scores_final_6,
                    true_labels,
                    self.config.model.name,
                    self.config.data_loader.dataset_name,
                    "final_6",
                    "paper",
                    self.config.trainer.label,
                    self.config.data_loader.random_seed,
                    self.logger,
                    step,
                    percentile=percentiles,
                )
Example #4
 def test_epoch(self):
     self.logger.warn("Testing evaluation...")
     scores_1 = []
     scores_2 = []
     inference_time = []
     true_labels = []
     summaries = []
     # Create the scores
     test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
     cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
     for _ in test_loop:
         test_batch_begin = time()
         test_batch, test_labels = self.sess.run(
             [self.data.test_image, self.data.test_label])
         test_loop.refresh()  # show the progress-bar update immediately
         sleep(0.01)
         noise = np.random.normal(
             loc=0.0,
             scale=1.0,
             size=[self.config.data_loader.test_batch, self.noise_dim])
         feed_dict = {
             self.model.image_input: test_batch,
             self.model.noise_tensor: noise,
             self.model.is_training: False,
         }
         scores_1 += self.sess.run(self.model.list_scores_1,
                                   feed_dict=feed_dict).tolist()
         scores_2 += self.sess.run(self.model.list_scores_2,
                                   feed_dict=feed_dict).tolist()
         summaries += self.sess.run([self.model.sum_op_im_test],
                                    feed_dict=feed_dict)
         inference_time.append(time() - test_batch_begin)
         true_labels += test_labels.tolist()
     # Higher anomaly scores indicate anomalous samples. The labels are encoded
     # so that normal images are 0 (no anomaly) and anomalous images are 1
     # (contain an anomalous region); the scores are therefore scaled, and
     # inverted if necessary, to match this label convention.
     scores_1 = np.asarray(scores_1)
     scores_2 = np.asarray(scores_2)
     true_labels = np.asarray(true_labels)
     inference_time = np.mean(inference_time)
     self.summarizer.add_tensorboard(step=cur_epoch,
                                     summaries=summaries,
                                     summarizer="test")
     self.logger.info(
         "Testing: Mean inference time is {:4f}".format(inference_time))
     step = self.sess.run(self.model.global_step_tensor)
     percentiles = np.asarray(self.config.trainer.percentiles)
     save_results(
         self.config.log.result_dir,
         scores_1,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "fm_1",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_2,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "fm_2",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
Example #5
def train_and_test(nb_epochs, weight, method, degree, random_seed):
    """ Runs the Bigan on the KDD dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        nb_epochs (int): number of epochs
        weight (float, optional): weight for the anomaly score composition
        method (str, optional): 'fm' for ``Feature Matching`` or "cross-e"
                                     for ``cross entropy``, "efm" etc.
        anomalous_label (int): int in range 0 to 10, is the class/digit
                                which is considered outlier
    """
    logger = logging.getLogger("BiGAN.train.kdd.{}".format(method))

    # Placeholders
    input_pl = tf.placeholder(tf.float32,
                              shape=data.get_shape_input(),
                              name="input")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    # Data
    trainx, trainy = data.get_train()
    trainx_copy = trainx.copy()
    testx, testy = data.get_test()
    trainx_org, trainy_org = data.get_train_org()

    print('samples:', trainx.shape, testx.shape, trainx_org.shape)

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.9999

    rng = np.random.RandomState(RANDOM_SEED)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)
    nr_batches_train_org = int(trainx_org.shape[0] / batch_size)

    logger.info('Building training graph...')

    logger.warn("The BiGAN is training with the following parameters:")
    display_parameters(batch_size, starting_lr, ema_decay, weight, method,
                       degree)

    gen = network.decoder
    enc = network.encoder
    dis = network.discriminator

    with tf.variable_scope('encoder_model'):
        z_gen = enc(input_pl, is_training=is_training_pl)

    with tf.variable_scope('generator_model'):
        z = tf.random_normal([batch_size, latent_dim])
        x_gen = gen(z, is_training=is_training_pl)

    with tf.variable_scope('discriminator_model'):
        l_encoder, inter_layer_inp = dis(z_gen,
                                         input_pl,
                                         is_training=is_training_pl)
        l_generator, inter_layer_rct = dis(z,
                                           x_gen,
                                           is_training=is_training_pl,
                                           reuse=True)

    with tf.name_scope('loss_functions'):
        # discriminator
        loss_dis_enc = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(l_encoder), logits=l_encoder))
        loss_dis_gen = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(l_generator), logits=l_generator))
        loss_discriminator = loss_dis_gen + loss_dis_enc
        # generator
        loss_generator = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(l_generator), logits=l_generator))
        # encoder
        loss_encoder = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(l_encoder), logits=l_encoder))

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dvars = [var for var in tvars if 'discriminator_model' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]
        evars = [var for var in tvars if 'encoder_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [
            x for x in update_ops if ('generator_model' in x.name)
        ]
        update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)]
        update_ops_dis = [
            x for x in update_ops if ('discriminator_model' in x.name)
        ]

        optimizer_dis = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               beta1=0.5,
                                               name='dis_optimizer')
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               beta1=0.5,
                                               name='gen_optimizer')
        optimizer_enc = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               beta1=0.5,
                                               name='enc_optimizer')

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer_gen.minimize(loss_generator, var_list=gvars)
        with tf.control_dependencies(update_ops_enc):
            enc_op = optimizer_enc.minimize(loss_encoder, var_list=evars)
        with tf.control_dependencies(update_ops_dis):
            dis_op = optimizer_dis.minimize(loss_discriminator, var_list=dvars)

        # Exponential Moving Average for estimation
        dis_ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
        maintain_averages_op_dis = dis_ema.apply(dvars)

        with tf.control_dependencies([dis_op]):
            train_dis_op = tf.group(maintain_averages_op_dis)

        gen_ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
        maintain_averages_op_gen = gen_ema.apply(gvars)

        with tf.control_dependencies([gen_op]):
            train_gen_op = tf.group(maintain_averages_op_gen)

        enc_ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
        maintain_averages_op_enc = enc_ema.apply(evars)

        with tf.control_dependencies([enc_op]):
            train_enc_op = tf.group(maintain_averages_op_enc)

    with tf.name_scope('summary'):
        with tf.name_scope('dis_summary'):
            tf.summary.scalar('loss_discriminator', loss_discriminator,
                              ['dis'])
            tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis'])
            tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis'])

        with tf.name_scope('gen_summary'):
            tf.summary.scalar('loss_generator', loss_generator, ['gen'])
            tf.summary.scalar('loss_encoder', loss_encoder, ['gen'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_gen = tf.summary.merge_all('gen')

    logger.info('Building testing graph...')

    with tf.variable_scope('encoder_model'):
        z_gen_ema = enc(input_pl,
                        is_training=is_training_pl,
                        getter=get_getter(enc_ema),
                        reuse=True)

    with tf.variable_scope('generator_model'):
        reconstruct_ema = gen(z_gen_ema,
                              is_training=is_training_pl,
                              getter=get_getter(gen_ema),
                              reuse=True)

    with tf.variable_scope('discriminator_model'):
        l_encoder_ema, inter_layer_inp_ema = dis(z_gen_ema,
                                                 input_pl,
                                                 is_training=is_training_pl,
                                                 getter=get_getter(dis_ema),
                                                 reuse=True)
        l_generator_ema, inter_layer_rct_ema = dis(z_gen_ema,
                                                   reconstruct_ema,
                                                   is_training=is_training_pl,
                                                   getter=get_getter(dis_ema),
                                                   reuse=True)
    with tf.name_scope('Testing'):
        with tf.variable_scope('Reconstruction_loss'):
            delta = input_pl - reconstruct_ema
            delta_flat = tf.contrib.layers.flatten(delta)
            gen_score = tf.norm(delta_flat,
                                ord=degree,
                                axis=1,
                                keep_dims=False,
                                name='epsilon')

        with tf.variable_scope('Discriminator_loss'):
            if method == "cross-e":
                dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(l_generator_ema),
                    logits=l_generator_ema)

            elif method == "fm":
                fm = inter_layer_inp_ema - inter_layer_rct_ema
                fm = tf.contrib.layers.flatten(fm)
                dis_score = tf.norm(fm,
                                    ord=degree,
                                    axis=1,
                                    keep_dims=False,
                                    name='d_loss')

            dis_score = tf.squeeze(dis_score)

        with tf.variable_scope('Score'):
            list_scores = (1 - weight) * gen_score + weight * dis_score
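            # Per-sample anomaly score: reconstruction error (gen_score) and
            # discriminator-based score (dis_score) mixed by `weight`; higher
            # values are treated as more anomalous below.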

    logdir = create_logdir(weight, method, random_seed)

    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=None,
                             save_model_secs=120)

    logger.info('Start training...')
    with sv.managed_session() as sess:

        logger.info('Initialization done')
        writer = tf.summary.FileWriter(logdir, sess.graph)
        train_batch = 0
        epoch = 0

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

            # construct randomly permuted minibatches
            trainx = trainx[rng.permutation(
                trainx.shape[0])]  # shuffling dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]
            train_loss_dis, train_loss_gen, train_loss_enc = [0, 0, 0]

            # training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train discriminator
                feed_dict = {
                    input_pl: trainx[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }

                _, ld, sm = sess.run(
                    [train_dis_op, loss_discriminator, sum_op_dis],
                    feed_dict=feed_dict)
                train_loss_dis += ld
                writer.add_summary(sm, train_batch)

                # train generator and encoder
                feed_dict = {
                    input_pl: trainx_copy[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, _, le, lg, sm = sess.run([
                    train_gen_op, train_enc_op, loss_encoder, loss_generator,
                    sum_op_gen
                ],
                                            feed_dict=feed_dict)
                train_loss_gen += lg
                train_loss_enc += le
                writer.add_summary(sm, train_batch)

                train_batch += 1

            train_loss_gen /= nr_batches_train
            train_loss_enc /= nr_batches_train
            train_loss_dis /= nr_batches_train

            logger.info('Epoch terminated')
            print(
                "Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | loss dis = %.4f "
                % (epoch, time.time() - begin, train_loss_gen, train_loss_enc,
                   train_loss_dis))

            epoch += 1

        logger.warn('Training evaluation...')

        inds = rng.permutation(trainx_org.shape[0])
        trainx_org = trainx_org[inds]  # shuffling  dataset
        trainy_org = trainy_org[inds]  # shuffling  dataset
        scores = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_train_org):

            # construct randomly permuted minibatches
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {
                input_pl: trainx_org[ran_from:ran_to],
                is_training_pl: False
            }

            scores += sess.run(list_scores, feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

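        # Handle the leftover samples that do not fill a whole batch: pad with
        # dummy rows (width 70, apparently the feature dimension) up to
        # batch_size and keep only the first `size` scores afterwards.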
        ran_from = nr_batches_train_org * batch_size
        ran_to = (nr_batches_train_org + 1) * batch_size
        size = trainx_org[ran_from:ran_to].shape[0]
        fill = np.ones([batch_size - size, 70])

        batch = np.concatenate([trainx_org[ran_from:ran_to], fill], axis=0)
        feed_dict = {input_pl: batch, is_training_pl: False}

        batch_score = sess.run(list_scores, feed_dict=feed_dict).tolist()

        scores += batch_score[:size]

        # Scores above the 80th percentile (the top 20%) are flagged as anomalous
        per = np.percentile(scores, 80)

        y_pred_org = scores.copy()
        y_pred_org = np.array(y_pred_org)

        inds = (y_pred_org < per)
        inds_comp = (y_pred_org >= per)

        y_pred_org[inds] = 0
        y_pred_org[inds_comp] = 1

        roc_auc = do_roc(scores, trainy_org, "train_", "Results/", True)
        prc_auc = do_prc(scores, trainy_org, "train_", "Results/", True)
        #prg_auc = do_prg(scores, trainy_org, "train_", "Results/", True)

        precision, recall, f1, _ = precision_recall_fscore_support(
            trainy_org, y_pred_org, average='binary')

        print("Testing : Prec = %.4f | Rec = %.4f | F1 = %.4f " %
              (precision, recall, f1))

        save_results(scores,
                     trainy_org,
                     'bigan',
                     'sia_train_air_4',
                     'fm',
                     '0.5',
                     '1',
                     2018,
                     'outlier',
                     0.1,
                     step=-1)

        logger.warn('Testing evaluation...')

        inds = rng.permutation(testx.shape[0])
        testx = testx[inds]  # shuffling  dataset
        testy = testy[inds]  # shuffling  dataset
        scores = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):

            # construct randomly permuted minibatches
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {
                input_pl: testx[ran_from:ran_to],
                is_training_pl: False
            }

            scores += sess.run(list_scores, feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

        ran_from = nr_batches_test * batch_size
        ran_to = (nr_batches_test + 1) * batch_size
        size = testx[ran_from:ran_to].shape[0]
        fill = np.ones([batch_size - size, 70])

        batch = np.concatenate([testx[ran_from:ran_to], fill], axis=0)
        feed_dict = {input_pl: batch, is_training_pl: False}

        batch_score = sess.run(list_scores, feed_dict=feed_dict).tolist()

        scores += batch_score[:size]

        # Scores above the 80th percentile (the top 20%) are flagged as anomalous
        per = np.percentile(scores, 80)

        y_pred = scores.copy()
        y_pred = np.array(y_pred)

        inds = (y_pred < per)
        inds_comp = (y_pred >= per)

        y_pred[inds] = 0
        y_pred[inds_comp] = 1

        roc_auc = do_roc(scores, testy, "test_", "Results/", True)
        prc_auc = do_prc(scores, testy, "test_", "Results/", True)
        #prg_auc = do_prg(scores, testy, "test_", "Results/", True)

        precision, recall, f1, _ = precision_recall_fscore_support(
            testy, y_pred, average='binary')

        print("Testing : Prec = %.4f | Rec = %.4f | F1 = %.4f " %
              (precision, recall, f1))

        # SIA metrics calculation
        TN_train, FN_train, FP_train, TP_train = my_confusion_matrix(
            trainy_org, y_pred_org)
        TN_test, FN_test, FP_test, TP_test = my_confusion_matrix(testy, y_pred)

        overall_train, average_train, sens_train, spec_train, ppr_train = derive_metric(
            TN_train, FN_train, FP_train, TP_train)
        overall_test, average_test, sens_test, spec_test, ppr_test = derive_metric(
            TN_test, FN_test, FP_test, TP_test)

        print('Train Metrics:', overall_train, average_train, sens_train,
              spec_train, ppr_train)
        print('Test Metrics:', overall_test, average_test, sens_test,
              spec_test, ppr_test)

        save_results(scores,
                     testy,
                     'bigan',
                     'sia_test_air_4',
                     'fm',
                     '0.5',
                     '1',
                     2018,
                     'outlier',
                     0.1,
                     step=-1)
Example #6
def train_and_test(dataset, nb_epochs, degree, random_seed, label,
                   allow_zz, enable_sm, score_method,
                   enable_early_stop, do_spectral_norm):
    """ Runs the AliCE on the specified dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (str): name of the dataset
        nb_epochs (int): number of epochs
        degree (int): degree of the norm in the feature matching
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
        allow_zz (bool): allow the d_zz discriminator or not for ablation study
        enable_sm (bool): allow TF summaries for monitoring the training
        score_method (str): which metric to use for the ablation study
        enable_early_stop (bool): allow early stopping for determining the number of epochs
        do_spectral_norm (bool): allow spectral norm or not for ablation study
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    logger = logging.getLogger("ALAD.run.{}.{}".format(
        dataset, label))

    # Import model and data
    network = importlib.import_module('alad.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.999

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Placeholders
    x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(),
                          name="input_x")
    z_pl = tf.placeholder(tf.float32, shape=[None, latent_dim],
                          name="input_z")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    # Data
    # labels have already been converted: 0 = normal data / 1 = abnormal data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    if enable_early_stop: validx, validy = data.get_valid(label)
    trainx_copy = trainx.copy()
    testx, testy = data.get_test(label)

    rng = np.random.RandomState(random_seed)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info('Building graph...')

    logger.warn("ALAD is training with the following parameters:")
    display_parameters(batch_size, starting_lr, ema_decay, degree, label,
                       allow_zz, score_method, do_spectral_norm)

    gen = network.decoder
    enc = network.encoder
    dis_xz = network.discriminator_xz
    dis_xx = network.discriminator_xx
    dis_zz = network.discriminator_zz

    with tf.variable_scope('encoder_model'):
        z_gen = enc(x_pl, is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        x_gen = gen(z_pl, is_training=is_training_pl)
        rec_x = gen(z_gen, is_training=is_training_pl, reuse=True)

    with tf.variable_scope('encoder_model'):
        rec_z = enc(x_gen, is_training=is_training_pl, reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xz'):
        l_encoder, inter_layer_inp_xz = dis_xz(x_pl, z_gen,
                                            is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        l_generator, inter_layer_rct_xz = dis_xz(x_gen, z_pl,
                                              is_training=is_training_pl,
                                              reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xx'):
        x_logit_real, inter_layer_inp_xx = dis_xx(x_pl, x_pl,
                                                  is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        x_logit_fake, inter_layer_rct_xx = dis_xx(x_pl, rec_x, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_zz'):
        z_logit_real, _ = dis_zz(z_pl, z_pl, is_training=is_training_pl,
                                 do_spectral_norm=do_spectral_norm)
        z_logit_fake, _ = dis_zz(z_pl, rec_z, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.name_scope('loss_functions'):
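        # Adversarial objectives: dis_xz separates real pairs (x, E(x)) from
        # generated pairs (G(z), z); dis_xx compares (x, x) with (x, G(E(x)));
        # dis_zz compares (z, z) with (z, E(G(z))) for cycle consistency.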

        # discriminator xz
        loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_encoder),logits=l_encoder))
        loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_generator),logits=l_generator))
        dis_loss_xz = loss_dis_gen + loss_dis_enc

        # discriminator xx
        x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.ones_like(x_logit_real))
        x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake))
        dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis)

        # discriminator zz
        z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.ones_like(z_logit_real))
        z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake))
        dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis)

        loss_discriminator = dis_loss_xz + dis_loss_xx + dis_loss_zz if \
            allow_zz else dis_loss_xz + dis_loss_xx

        # generator and encoder
        gen_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_generator),logits=l_generator))
        enc_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_encoder), logits=l_encoder))
        x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.zeros_like(x_logit_real))
        x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.ones_like(x_logit_fake))
        z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.zeros_like(z_logit_real))
        z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.ones_like(z_logit_fake))

        cost_x = tf.reduce_mean(x_real_gen + x_fake_gen)
        cost_z = tf.reduce_mean(z_real_gen + z_fake_gen)

        cycle_consistency_loss = cost_x + cost_z if allow_zz else cost_x
        loss_generator = gen_loss_xz + cycle_consistency_loss
        loss_encoder = enc_loss_xz + cycle_consistency_loss

    with tf.name_scope('optimizers'):

        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dxzvars = [var for var in tvars if 'discriminator_model_xz' in var.name]
        dxxvars = [var for var in tvars if 'discriminator_model_xx' in var.name]
        dzzvars = [var for var in tvars if 'discriminator_model_zz' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]
        evars = [var for var in tvars if 'encoder_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)]
        update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)]
        update_ops_dis_xz = [x for x in update_ops if
                             ('discriminator_model_xz' in x.name)]
        update_ops_dis_xx = [x for x in update_ops if
                             ('discriminator_model_xx' in x.name)]
        update_ops_dis_zz = [x for x in update_ops if
                             ('discriminator_model_zz' in x.name)]

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                  beta1=0.5)

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer.minimize(loss_generator, var_list=gvars,
                                            global_step=global_step)
        with tf.control_dependencies(update_ops_enc):
            enc_op = optimizer.minimize(loss_encoder, var_list=evars)

        with tf.control_dependencies(update_ops_dis_xz):
            dis_op_xz = optimizer.minimize(dis_loss_xz, var_list=dxzvars)

        with tf.control_dependencies(update_ops_dis_xx):
            dis_op_xx = optimizer.minimize(dis_loss_xx, var_list=dxxvars)

        with tf.control_dependencies(update_ops_dis_zz):
            dis_op_zz = optimizer.minimize(dis_loss_zz, var_list=dzzvars)

        # Exponential Moving Average for inference
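        # Shadow (EMA) copies of the trainable variables are maintained during
        # training and used below, via get_getter, to build the test-time graph.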
        def train_op_with_ema_dependency(vars, op):
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            maintain_averages_op = ema.apply(vars)
            with tf.control_dependencies([op]):
                train_op = tf.group(maintain_averages_op)
            return train_op, ema

        train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op)
        train_enc_op, enc_ema = train_op_with_ema_dependency(evars, enc_op)
        train_dis_op_xz, xz_ema = train_op_with_ema_dependency(dxzvars,
                                                               dis_op_xz)
        train_dis_op_xx, xx_ema = train_op_with_ema_dependency(dxxvars,
                                                               dis_op_xx)
        train_dis_op_zz, zz_ema = train_op_with_ema_dependency(dzzvars,
                                                               dis_op_zz)

    with tf.variable_scope('encoder_model'):
        z_gen_ema = enc(x_pl, is_training=is_training_pl,
                        getter=get_getter(enc_ema), reuse=True,
                        do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        rec_x_ema = gen(z_gen_ema, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)
        x_gen_ema = gen(z_pl, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)

    with tf.variable_scope('discriminator_model_xx'):
        l_encoder_emaxx, inter_layer_inp_emaxx = dis_xx(x_pl, x_pl,
                                                    is_training=is_training_pl,
                                                    getter=get_getter(xx_ema),
                                                    reuse=True,
                    do_spectral_norm=do_spectral_norm)

        l_generator_emaxx, inter_layer_rct_emaxx = dis_xx(x_pl, rec_x_ema,
                                                      is_training=is_training_pl,
                                                      getter=get_getter(
                                                          xx_ema),
                                                      reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.name_scope('Testing'):

        with tf.variable_scope('Scores'):
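            # Four anomaly scores: score_ch is the cross-entropy of the EMA
            # xx-discriminator on (x, reconstruction); score_l1 / score_l2 are
            # L1 / L2 reconstruction errors; score_fm is the L-degree distance
            # between intermediate discriminator features of x and its reconstruction.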


            score_ch = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(l_generator_emaxx),
                    logits=l_generator_emaxx)
            score_ch = tf.squeeze(score_ch)

            rec = x_pl - rec_x_ema
            rec = tf.contrib.layers.flatten(rec)
            score_l1 = tf.norm(rec, ord=1, axis=1,
                            keep_dims=False, name='d_loss')
            score_l1 = tf.squeeze(score_l1)

            rec = x_pl - rec_x_ema
            rec = tf.contrib.layers.flatten(rec)
            score_l2 = tf.norm(rec, ord=2, axis=1,
                            keep_dims=False, name='d_loss')
            score_l2 = tf.squeeze(score_l2)

            inter_layer_inp, inter_layer_rct = inter_layer_inp_emaxx, \
                                               inter_layer_rct_emaxx
            fm = inter_layer_inp - inter_layer_rct
            fm = tf.contrib.layers.flatten(fm)
            score_fm = tf.norm(fm, ord=degree, axis=1,
                             keep_dims=False, name='d_loss')
            score_fm = tf.squeeze(score_fm)

    if enable_early_stop:
        rec_error_valid = tf.reduce_mean(score_fm)

    if enable_sm:

        with tf.name_scope('summary'):
            with tf.name_scope('dis_summary'):
                tf.summary.scalar('loss_discriminator', loss_discriminator, ['dis'])
                tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis'])
                tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis'])
                tf.summary.scalar('loss_dis_xz', dis_loss_xz, ['dis'])
                tf.summary.scalar('loss_dis_xx', dis_loss_xx, ['dis'])
                if allow_zz:
                    tf.summary.scalar('loss_dis_zz', dis_loss_zz, ['dis'])

            with tf.name_scope('gen_summary'):
                tf.summary.scalar('loss_generator', loss_generator, ['gen'])
                tf.summary.scalar('loss_encoder', loss_encoder, ['gen'])
                tf.summary.scalar('loss_encgen_dxx', cost_x, ['gen'])
                if allow_zz:
                    tf.summary.scalar('loss_encgen_dzz', cost_z, ['gen'])

            if enable_early_stop:
                with tf.name_scope('validation_summary'):
                   tf.summary.scalar('valid', rec_error_valid, ['v'])

            with tf.name_scope('img_summary'):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent)

            if dataset in IMAGES_DATASETS:
                with tf.name_scope('image_summary'):
                    tf.summary.image('reconstruct', rec_x, 8, ['image'])
                    tf.summary.image('input_images', x_pl, 8, ['image'])

            else:
                heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3),
                                            name="heatmap_pl_rec")
                with tf.name_scope('image_summary'):
                    tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image'])

            sum_op_dis = tf.summary.merge_all('dis')
            sum_op_gen = tf.summary.merge_all('gen')
            sum_op = tf.summary.merge([sum_op_dis, sum_op_gen])
            sum_op_im = tf.summary.merge_all('image')
            sum_op_valid = tf.summary.merge_all('v')

    logdir = create_logdir(dataset, label, random_seed, allow_zz, score_method,
                           do_spectral_norm)

    saver = tf.train.Saver(max_to_keep=2)
    save_model_secs = None if enable_early_stop else 20
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None,
                             saver=saver, save_model_secs=save_model_secs)

    logger.info('Start training...')
    with sv.managed_session(config=config) as sess:

        step = sess.run(global_step)
        logger.info('Initialization done at epoch {}'.format(step // nr_batches_train))
        writer = tf.summary.FileWriter(logdir, sess.graph)
        train_batch = 0
        epoch = 0
        best_valid_loss = 0
        request_stop = False

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

             # construct randomly permuted minibatches
            trainx = trainx[rng.permutation(trainx.shape[0])]  # shuffling dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]
            train_loss_dis_xz, train_loss_dis_xx,  train_loss_dis_zz, \
            train_loss_dis, train_loss_gen, train_loss_enc = [0, 0, 0, 0, 0, 0]

            # Training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train discriminator
                feed_dict = {x_pl: trainx[ran_from:ran_to],
                             z_pl: np.random.normal(size=[batch_size, latent_dim]),
                             is_training_pl: True,
                             learning_rate:lr}

                _, _, _, ld, ldxz, ldxx, ldzz, step = sess.run([train_dis_op_xz,
                                                              train_dis_op_xx,
                                                              train_dis_op_zz,
                                                              loss_discriminator,
                                                              dis_loss_xz,
                                                              dis_loss_xx,
                                                              dis_loss_zz,
                                                              global_step],
                                                             feed_dict=feed_dict)
                train_loss_dis += ld
                train_loss_dis_xz += ldxz
                train_loss_dis_xx += ldxx
                train_loss_dis_zz += ldzz

                # train generator and encoder
                feed_dict = {x_pl: trainx_copy[ran_from:ran_to],
                             z_pl: np.random.normal(size=[batch_size, latent_dim]),
                             is_training_pl: True,
                             learning_rate:lr}
                _,_, le, lg = sess.run([train_gen_op,
                                            train_enc_op,
                                            loss_encoder,
                                            loss_generator],
                                           feed_dict=feed_dict)
                train_loss_gen += lg
                train_loss_enc += le

                if enable_sm:
                    sm = sess.run(sum_op, feed_dict=feed_dict)
                    writer.add_summary(sm, step)

                    if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS:  # inspect reconstruction
                        t = np.random.randint(0, trainx.shape[0]-batch_size)
                        ran_from = t
                        ran_to = t + batch_size
                        feed_dict = {x_pl: trainx[ran_from:ran_to],
                            z_pl: np.random.normal(
                                size=[batch_size, latent_dim]),
                            is_training_pl: False}
                        sm = sess.run(sum_op_im, feed_dict=feed_dict)
                        writer.add_summary(sm, step)#train_batch)

                train_batch += 1

            train_loss_gen /= nr_batches_train
            train_loss_enc /= nr_batches_train
            train_loss_dis /= nr_batches_train
            train_loss_dis_xz /= nr_batches_train
            train_loss_dis_xx /= nr_batches_train
            train_loss_dis_zz /= nr_batches_train

            logger.info('Epoch terminated')
            if allow_zz:
                print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | "
                      "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | "
                      "loss dis zz = %.4f"
                      % (epoch, time.time() - begin, train_loss_gen,
                         train_loss_enc, train_loss_dis, train_loss_dis_xz,
                         train_loss_dis_xx, train_loss_dis_zz))
            else:
                print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | "
                      "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | "
                      % (epoch, time.time() - begin, train_loss_gen,
                         train_loss_enc, train_loss_dis, train_loss_dis_xz,
                         train_loss_dis_xx))

            ##EARLY STOPPING
            if (epoch + 1) % FREQ_EV == 0 and enable_early_stop:

                valid_loss = 0
                feed_dict = {x_pl: validx,
                             z_pl: np.random.normal(size=[validx.shape[0], latent_dim]),
                             is_training_pl: False}
                vl, lat = sess.run([rec_error_valid, rec_z], feed_dict=feed_dict)
                valid_loss += vl

                if enable_sm:
                    sm = sess.run(sum_op_valid, feed_dict=feed_dict)
                    writer.add_summary(sm, step)  # train_batch)

                logger.info('Validation: valid loss {:.4f}'.format(valid_loss))

                if (valid_loss < best_valid_loss or epoch == FREQ_EV-1):
                    best_valid_loss = valid_loss
                    logger.info("Best model - valid loss = {:.4f} - saving...".format(best_valid_loss))
                    sv.saver.save(sess, logdir+'/model.ckpt', global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                      "Early stopping at epoch {} with weights from epoch {}".format(
                          epoch, epoch - nb_without_improvements))

            epoch += 1

        sv.saver.save(sess, logdir+'/model.ckpt', global_step=step)

        logger.warn('Testing evaluation...')

        scores_ch = []
        scores_l1 = []
        scores_l2 = []
        scores_fm = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):

            # construct randomly permuted minibatches
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_test_time_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to],
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl:False}

            scores_ch += sess.run(score_ch, feed_dict=feed_dict).tolist()
            scores_l1 += sess.run(score_l1, feed_dict=feed_dict).tolist()
            scores_l2 += sess.run(score_l2, feed_dict=feed_dict).tolist()
            scores_fm += sess.run(score_fm, feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_test_time_batch)


        inference_time = np.mean(inference_time)
        logger.info('Testing : mean inference time is %.4f' % (inference_time))

        if testx.shape[0] % batch_size != 0:

            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch,
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl: False}

            bscores_ch = sess.run(score_ch,feed_dict=feed_dict).tolist()
            bscores_l1 = sess.run(score_l1,feed_dict=feed_dict).tolist()
            bscores_l2 = sess.run(score_l2,feed_dict=feed_dict).tolist()
            bscores_fm = sess.run(score_fm,feed_dict=feed_dict).tolist()


            scores_ch += bscores_ch[:size]
            scores_l1 += bscores_l1[:size]
            scores_l2 += bscores_l2[:size]
            scores_fm += bscores_fm[:size]

        model = 'alad_sn{}_dzz{}'.format(do_spectral_norm, allow_zz)
        save_results(scores_ch, testy, model, dataset, 'ch',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_l1, testy, model, dataset, 'l1',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_l2, testy, model, dataset, 'l2',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_fm, testy, model, dataset, 'fm',
                     'dzzenabled{}'.format(allow_zz), label, random_seed,  step)
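The function above builds the complete ALAD graph, trains it, and writes four anomaly-score variants with save_results. A minimal invocation sketch (the argument values are illustrative assumptions, not taken from the source):

import tensorflow as tf

tf.reset_default_graph()
train_and_test(dataset='kdd', nb_epochs=100, degree=1, random_seed=42,
               label=0, allow_zz=True, enable_sm=True, score_method='fm',
               enable_early_stop=True, do_spectral_norm=True)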
Example #7
0
 def test_epoch(self):
     # Evaluation for the testing
     self.logger.info("Testing evaluation...")
     scores_ch = []
     scores_l1 = []
     scores_l2 = []
     scores_fm = []
     inference_time = []
     true_labels = []
     summaries = []
     # Create the scores
     test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
     cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
     for _ in test_loop:
         test_batch_begin = time()
         test_batch, test_labels = self.sess.run(
             [self.data.test_image, self.data.test_label])
         test_loop.refresh()  # to show immediately the update
         sleep(0.01)
         noise = np.random.normal(
             loc=0.0,
             scale=1.0,
             size=[self.config.data_loader.test_batch, self.noise_dim])
         feed_dict = {
             self.model.image_tensor: test_batch,
             self.model.noise_tensor: noise,
             self.model.is_training: False,
         }
         scores_ch += self.sess.run(self.model.score_ch,
                                    feed_dict=feed_dict).tolist()
         scores_l1 += self.sess.run(self.model.score_l1,
                                    feed_dict=feed_dict).tolist()
         scores_l2 += self.sess.run(self.model.score_l2,
                                    feed_dict=feed_dict).tolist()
         scores_fm += self.sess.run(self.model.score_fm,
                                    feed_dict=feed_dict).tolist()
         summaries += self.sess.run([self.model.sum_op_im_test],
                                    feed_dict=feed_dict)
         inference_time.append(time() - test_batch_begin)
         true_labels += test_labels.tolist()
     scores_ch = np.asarray(scores_ch)
     scores_l1 = np.asarray(scores_l1)
     scores_l2 = np.asarray(scores_l2)
     scores_fm = np.asarray(scores_fm)
     true_labels = np.asarray(true_labels)
     inference_time = np.mean(inference_time)
     self.summarizer.add_tensorboard(step=cur_epoch,
                                     summaries=summaries,
                                     summarizer="test")
     self.logger.info(
         "Testing: Mean inference time is {:4f}".format(inference_time))
     # TODO BATCH FILL ?
     model = "alad_sn{}_dzz{}".format(self.config.trainer.do_spectral_norm,
                                      self.config.trainer.allow_zz)
     random_seed = 42
     label = 0
     step = self.sess.run(self.model.global_step_tensor)
     percentiles = np.asarray(self.config.trainer.percentiles)
     save_results(
         self.config.log.result_dir,
         scores_ch,
         true_labels,
         model,
         self.config.data_loader.dataset_name,
         "ch",
         "dzzenabled{}".format(self.config.trainer.allow_zz),
         label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_l1,
         true_labels,
         model,
         self.config.data_loader.dataset_name,
         "l1",
         "dzzenabled{}".format(self.config.trainer.allow_zz),
         label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_l2,
         true_labels,
         model,
         self.config.data_loader.dataset_name,
         "l2",
         "dzzenabled{}".format(self.config.trainer.allow_zz),
         label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_fm,
         true_labels,
         model,
         self.config.data_loader.dataset_name,
         "fm",
         "dzzenabled{}".format(self.config.trainer.allow_zz),
         label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
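save_results is a project helper that is not shown in these snippets; judging from the call sites it receives the result directory, the scores, the ground-truth labels, run metadata, a logger, the global step and a percentile list. A rough stand-in sketch of what such a helper typically computes (the function name, file layout and exact metrics below are assumptions, not the project's implementation):

import os
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support

def save_results_sketch(result_dir, scores, true_labels, model, dataset,
                        method, desc, label, random_seed, logger, step,
                        percentile=(80,)):
    # Threshold the scores at each percentile, log ROC AUC and
    # precision/recall/F1, and append everything to a text file.
    os.makedirs(result_dir, exist_ok=True)
    roc_auc = roc_auc_score(true_labels, scores)
    lines = ["{} {} {} step={} roc_auc={:.4f}".format(
        model, dataset, method, step, roc_auc)]
    for p in np.atleast_1d(percentile):
        y_pred = (scores >= np.percentile(scores, p)).astype(int)
        prec, rec, f1, _ = precision_recall_fscore_support(
            true_labels, y_pred, average="binary")
        lines.append("p{}: precision={:.4f} recall={:.4f} f1={:.4f}".format(
            p, prec, rec, f1))
    logger.info(" | ".join(lines))
    with open(os.path.join(result_dir, "results.txt"), "a") as f:
        f.write("\n".join(lines) + "\n")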
Example #8
0
def train_and_test(dataset, nb_epochs, random_seed, label):
    """ Runs DSEBM on available datasets 

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (string): dataset to run the model on 
        nb_epochs (int): number of epochs
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
    """
    logger = logging.getLogger("DSEBM.run.{}.{}".format(dataset, label))

    # Import model and data
    network = importlib.import_module('dsebm.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size

    # Placeholders
    x_pl = tf.placeholder(tf.float32,
                          shape=data.get_shape_input(),
                          name="input")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    #test
    y_true = tf.placeholder(tf.int32, shape=[None], name="y_true")

    logger.info('Building training graph...')
    logger.warn("The DSEBM is training with the following parameters:")
    display_parameters(batch_size, starting_lr, label)

    net = network.network

    global_step = tf.train.get_or_create_global_step()

    noise = tf.random_normal(shape=tf.shape(x_pl),
                             mean=0.0,
                             stddev=1.,
                             dtype=tf.float32)
    x_noise = x_pl + noise

    with tf.variable_scope('network'):
        b_prime_shape = list(data.get_shape_input())
        b_prime_shape[0] = batch_size
        b_prime = tf.get_variable(name='b_prime',
                                  shape=b_prime_shape)  #tf.shape(x_pl))
        net_out = net(x_pl, is_training=is_training_pl)
        net_out_noise = net(x_noise, is_training=is_training_pl, reuse=True)

    with tf.name_scope('energies'):
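        # DSEBM energy: E(x) = 0.5 * ||x - b'||^2 - sum(net(x)); inputs the
        # network models well are assigned low energy.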
        energy = 0.5 * tf.reduce_sum(tf.square(x_pl - b_prime)) \
                 - tf.reduce_sum(net_out)

        energy_noise = 0.5 * tf.reduce_sum(tf.square(x_noise - b_prime)) \
                       - tf.reduce_sum(net_out_noise)

    with tf.name_scope('reconstructions'):
        # reconstruction
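        # The reconstruction is one gradient step on the energy, f(x) = x - dE/dx;
        # the training loss below pulls f(x + noise) back to the clean x
        # (a denoising, score-matching style objective).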
        grad = tf.gradients(energy, x_pl)
        fx = x_pl - tf.gradients(energy, x_pl)
        fx = tf.squeeze(fx, axis=0)
        fx_noise = x_noise - tf.gradients(energy_noise, x_noise)

    with tf.name_scope("loss_function"):
        # DSEBM for images
        if len(data.get_shape_input()) == 4:
            loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(x_pl - fx_noise), axis=[1, 2, 3]))
        # DSEBM for tabular data
        else:
            loss = tf.reduce_mean(tf.square(x_pl - fx_noise))

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        netvars = [var for var in tvars if 'network' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_net = [x for x in update_ops if ('network' in x.name)]

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           name='optimizer')

        with tf.control_dependencies(update_ops_net):
            train_op = optimizer.minimize(loss,
                                          var_list=netvars,
                                          global_step=global_step)

    with tf.variable_scope('Scores'):

        with tf.name_scope('Energy_score'):
            flat = tf.layers.flatten(x_pl - b_prime)
            if len(data.get_shape_input()) == 4:
                list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \
                                           - tf.reduce_sum(net_out, axis=[1, 2, 3])
            else:
                list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \
                           - tf.reduce_sum(net_out, axis=1)
        with tf.name_scope('Reconstruction_score'):
            delta = x_pl - fx
            delta_flat = tf.layers.flatten(delta)
            list_scores_reconstruction = tf.norm(delta_flat,
                                                 ord=2,
                                                 axis=1,
                                                 keepdims=False,
                                                 name='reconstruction')

    #   with tf.name_scope('predictions'):
    #       # Highest 20% are anomalous
    #       if dataset=="kdd":
    #           per = tf.contrib.distributions.percentile(list_scores_energy, 80)
    #       else:
    #           per = tf.contrib.distributions.percentile(list_scores_energy, 95)
    #       y_pred = tf.greater_equal(list_scores_energy, per)
    #
    #       #y_test_true = tf.cast(y_test_true, tf.float32)
    #       cm = tf.confusion_matrix(y_true, y_pred, num_classes=2)
    #       recall = cm[1,1]/(cm[1,0]+cm[1,1])
    #       precision = cm[1,1]/(cm[0,1]+cm[1,1])
    #       f1 = 2*precision*recall/(precision + recall)

    with tf.name_scope('training_summary'):

        tf.summary.scalar('score_matching_loss', loss, ['net'])
        tf.summary.scalar('energy', energy, ['net'])

        if dataset in IMAGES_DATASETS:
            with tf.name_scope('image_summary'):
                tf.summary.image('reconstruct', fx, 6, ['image'])
                tf.summary.image('input_images', x_pl, 6, ['image'])
                sum_op_im = tf.summary.merge_all('image')

        sum_op_net = tf.summary.merge_all('net')

    logdir = create_logdir(dataset, label, random_seed)

    sv = tf.train.Supervisor(logdir=logdir + "/train",
                             save_summaries_secs=None,
                             save_model_secs=None)

    # Data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    trainx_copy = trainx.copy()
    if dataset in IMAGES_DATASETS: validx, validy = data.get_valid(label)
    testx, testy = data.get_test(label)

    rng = np.random.RandomState(RANDOM_SEED)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    if dataset in IMAGES_DATASETS:
        nr_batches_valid = int(validx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info("Train: {} samples in {} batches".format(
        trainx.shape[0], nr_batches_train))
    if dataset in IMAGES_DATASETS:
        logger.info("Valid: {} samples in {} batches".format(
            validx.shape[0], nr_batches_valid))
    logger.info("Test:  {} samples in {} batches".format(
        testx.shape[0], nr_batches_test))

    logger.info('Start training...')
    with sv.managed_session() as sess:
        logger.info('Initialization done')

        train_writer = tf.summary.FileWriter(logdir + "/train", sess.graph)
        valid_writer = tf.summary.FileWriter(logdir + "/valid", sess.graph)

        train_batch = 0
        epoch = 0
        best_valid_loss = 0
        train_losses = [0] * STRIP_EV

        while not sv.should_stop() and epoch < nb_epochs:
            lr = starting_lr

            begin = time.time()
            trainx = trainx[rng.permutation(
                trainx.shape[0])]  # shuffling unlabeled dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]

            losses, energies = [0, 0]
            # training
            for t in range(nr_batches_train):
                display_progression_epoch(t, nr_batches_train)

                # construct randomly permuted minibatches
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train the net
                feed_dict = {
                    x_pl: trainx[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, ld, en, sm, step = sess.run(
                    [train_op, loss, energy, sum_op_net, global_step],
                    feed_dict=feed_dict)
                losses += ld
                energies += en
                train_writer.add_summary(sm, step)

                if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS:  # inspect reconstruction
                    t = np.random.randint(0, 40)
                    ran_from = t
                    ran_to = t + batch_size
                    sm = sess.run(sum_op_im,
                                  feed_dict={
                                      x_pl: trainx[ran_from:ran_to],
                                      is_training_pl: False
                                  })
                    train_writer.add_summary(sm, step)

                train_batch += 1

            losses /= nr_batches_train
            energies /= nr_batches_train
            # Remembering loss for early stopping
            train_losses[epoch % STRIP_EV] = losses

            logger.info('Epoch terminated')
            print("Epoch %d | time = %ds | loss = %.4f | energy = %.4f " %
                  (epoch, time.time() - begin, losses, energies))

            if (epoch + 1) % FREQ_SNAP == 0 and dataset in IMAGES_DATASETS:

                print("Take a snap of the reconstructions...")
                x = trainx[:batch_size]
                feed_dict = {x_pl: x, is_training_pl: False}

                rect_x = sess.run(fx, feed_dict=feed_dict)
                nama_e_wa = "dsebm/reconstructions/{}/{}/" \
                            "{}_epoch{}".format(dataset,
                                                label,
                                                random_seed, epoch)
                nb_imgs = 50
                save_grid_plot(x[:nb_imgs], rect_x[:nb_imgs], nama_e_wa,
                               nb_imgs)

            if (epoch + 1) % FREQ_EV == 0 and dataset in IMAGES_DATASETS:
                logger.info("Validation")
                inds = rng.permutation(validx.shape[0])
                validx = validx[inds]  # shuffling  dataset
                validy = validy[inds]  # shuffling  dataset
                valid_loss = 0
                for t in range(nr_batches_valid):
                    display_progression_epoch(t, nr_batches_valid)

                    # construct randomly permuted minibatches
                    ran_from = t * batch_size
                    ran_to = (t + 1) * batch_size

                    # train the net
                    feed_dict = {
                        x_pl: validx[ran_from:ran_to],
                        y_true: validy[ran_from:ran_to],
                        is_training_pl: False
                    }

                    vl, sm, step = sess.run([loss, sum_op_net, global_step],
                                            feed_dict=feed_dict)
                    valid_writer.add_summary(sm, step + t)  #train_batch)
                    valid_loss += vl

                valid_loss /= nr_batches_valid

                # train the net

                logger.info("Validation loss at step " + str(step) + ":" +
                            str(valid_loss))
                ##EARLY STOPPING
                #UPDATE WEIGHTS
                if valid_loss < best_valid_loss or epoch == FREQ_EV - 1:
                    best_valid_loss = valid_loss
                    logger.info("Best model - loss={} - saving...".format(
                        best_valid_loss))
                    sv.saver.save(sess,
                                  logdir + '/train/model.ckpt',
                                  global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                        "Early stopping at epoch {} with weights from epoch {}"
                        .format(epoch, epoch - nb_without_improvements))
            epoch += 1

        logger.warn('Testing evaluation...')

        step = sess.run(global_step)
        scores_e = []
        scores_r = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):
            # construct randomly permuted minibatches
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False}

            scores_e += sess.run(list_scores_energy,
                                 feed_dict=feed_dict).tolist()

            scores_r += sess.run(list_scores_reconstruction,
                                 feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

        if testx.shape[0] % batch_size != 0:
            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch, is_training_pl: False}
            batch_score_e = sess.run(list_scores_energy,
                                     feed_dict=feed_dict).tolist()
            batch_score_r = sess.run(list_scores_reconstruction,
                                     feed_dict=feed_dict).tolist()
            scores_e += batch_score_e[:size]
            scores_r += batch_score_r[:size]

        save_results(scores_e, testy, 'dsebm', dataset, 'energy', "test",
                     label, random_seed, step)
        save_results(scores_r, testy, 'dsebm', dataset, 'reconstruction',
                     "test", label, random_seed, step)
 def test_epoch(self):
     self.logger.warn("Testing evaluation...")
     scores_im1 = []
     scores_im2 = []
     scores_comb = []
     scores_mask1 = []
     scores_mask2 = []
     scores_pipe = []
     scores_pipe_2 = []
     # only populated when the corresponding extra discriminators are enabled
     scores_final_4 = []
     scores_final_6 = []
     inference_time = []
     true_labels = []
     # Create the scores
     test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
     for _ in test_loop:
         test_batch_begin = time()
         test_batch, test_labels = self.sess.run(
             [self.data.test_image, self.data.test_label])
         test_loop.refresh()  # to show immediately the update
         sleep(0.01)
         noise = np.random.normal(
             loc=0.0,
             scale=1.0,
             size=[self.config.data_loader.test_batch, self.noise_dim])
         feed_dict = {
             self.model.image_input: test_batch,
             self.model.noise_tensor: noise,
             self.model.is_training_gen: False,
             self.model.is_training_dis: False,
             self.model.is_training_enc_g: False,
             self.model.is_training_enc_r: False,
         }
         scores_im1 += self.sess.run(self.model.img_score_l1,
                                     feed_dict=feed_dict).tolist()
         scores_im2 += self.sess.run(self.model.img_score_l2,
                                     feed_dict=feed_dict).tolist()
         scores_comb += self.sess.run(self.model.score_comb,
                                      feed_dict=feed_dict).tolist()
         scores_mask1 += self.sess.run(self.model.mask_score_1,
                                       feed_dict=feed_dict).tolist()
         scores_mask2 += self.sess.run(self.model.mask_score_2,
                                       feed_dict=feed_dict).tolist()
         scores_pipe += self.sess.run(self.model.pipe_score,
                                      feed_dict=feed_dict).tolist()
         scores_pipe_2 += self.sess.run(self.model.pipe_score_2,
                                        feed_dict=feed_dict).tolist()
         if self.config.trainer.enable_disc_xx:
             # scores_final_3 += self.sess.run(
             #     self.model.final_score_3, feed_dict=feed_dict
             # ).tolist()
             scores_final_4 += self.sess.run(self.model.final_score_4,
                                             feed_dict=feed_dict).tolist()
         if self.config.trainer.enable_disc_zz:
             # scores_final_5 += self.sess.run(
             #     self.model.final_score_5, feed_dict=feed_dict
             # ).tolist()
             scores_final_6 += self.sess.run(self.model.final_score_6,
                                             feed_dict=feed_dict).tolist()
         inference_time.append(time() - test_batch_begin)
         true_labels += test_labels.tolist()
     scores_im1 = np.asarray(scores_im1)
     scores_im2 = np.asarray(scores_im2)
     scores_comb = np.asarray(scores_comb)
     scores_pipe = np.asarray(scores_pipe)
     scores_pipe_2 = np.asarray(scores_pipe_2)
     scores_mask1 = np.asarray(scores_mask1)
     scores_mask2 = np.asarray(scores_mask2)
     if self.config.trainer.enable_disc_xx:
         #scores_final_3 = np.asarray(scores_final_3)
         scores_final_4 = np.asarray(scores_final_4)
     if self.config.trainer.enable_disc_zz:
         #scores_final_5 = np.asarray(scores_final_5)
         scores_final_6 = np.asarray(scores_final_6)
     true_labels = np.asarray(true_labels)
     inference_time = np.mean(inference_time)
     self.logger.info(
         "Testing: Mean inference time is {:4f}".format(inference_time))
     step = self.sess.run(self.model.global_step_tensor)
     percentiles = np.asarray(self.config.trainer.percentiles)
     save_results(
         self.config.log.result_dir,
         scores_im1,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "im1",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_im2,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "im2",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_comb,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "comb",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_mask1,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "mask_1",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_mask2,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "mask_2",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_pipe,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "scores_pipe_1",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
     save_results(
         self.config.log.result_dir,
         scores_pipe_2,
         true_labels,
         self.config.model.name,
         self.config.data_loader.dataset_name,
         "scores_pipe_2",
         "paper",
         self.config.trainer.label,
         self.config.data_loader.random_seed,
         self.logger,
         step,
         percentile=percentiles,
     )
Example #10
0
def train_and_test(dataset, nb_epochs, K, l1, l2, label,
                   random_seed):

    """ Runs the DAGMM on the specified dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (str): name of the dataset
        nb_epochs (int): number of epochs
        K (int): number of mixture components in the GMM
        l1 (float): weight of the energy term in the loss
        l2 (float): weight of the covariance penalty in the loss
        label (int): label which is normal for image experiments
        random_seed (int): trying different seeds for averaging the results
    """
    logger = logging.getLogger("DAGMM.train.{}.{}".format(dataset,label))

    # Import model and data
    model = importlib.import_module('dagmm.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = model.params["learning_rate"]
    batch_size = model.params["batch_size"]
    if l1==-1: l1 = model.params["l1"]
    if l2==-1: l2 = model.params["l2"]
    if K==-1: K = model.params["K"]

    # Placeholders

    x_pl = tf.placeholder(tf.float32, data.get_shape_input())
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    logger.info('Building training graph...')

    logger.warning("The DAGMM is training with the following parameters:")
    display_parameters(batch_size, starting_lr, l1, l2,
                       label)


    global_step = tf.Variable(0, name='global_step', trainable=False)

    enc = model.encoder
    dec = model.decoder
    feat_ex = model.feature_extractor
    est = model.estimator

    #feature extraction for images
    if model.params["is_image"] and not METHOD=="pca":
         x_features = image_features.extract_features(x_pl)
    else:
         x_features = x_pl
    n_features = x_features.shape[1]

    with tf.variable_scope('encoder_model'):
        z_c = enc(x_features, is_training=is_training_pl)  
    
    with tf.variable_scope('decoder_model'):
        x_rec = dec(z_c, n_features, is_training=is_training_pl)

    with tf.variable_scope('feature_extractor_model'):
        x_flat = tf.layers.flatten(x_features)
        x_rec_flat = tf.layers.flatten(x_rec)
        z_r = feat_ex(x_flat, x_rec_flat)

    z = tf.concat([z_c, z_r], axis=1)

    with tf.variable_scope('estimator_model'):
        gamma = est(z, K, is_training=is_training_pl)

    with tf.variable_scope('gmm'):
        energy, penalty = gmm.compute_energy_and_penalty(z, gamma, is_training_pl)
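        # energy: per-sample negative log-likelihood of z under the GMM whose
        # parameters are estimated from the soft assignments gamma; penalty
        # discourages near-singular covariance matrices.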

    with tf.name_scope('loss_functions'):
        # reconstruction error
        rec_error = reconstruction_error(x_flat, x_rec_flat)
        loss_rec = tf.reduce_mean(rec_error)

        # probabilities to observe
        loss_energy = tf.reduce_mean(energy)
 
        # full loss
        full_loss = loss_rec + l1*loss_energy + l2*penalty



    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.5, name='dis_optimizer')

        train_op = optimizer.minimize(full_loss, global_step=global_step)

        with tf.name_scope('predictions'):
            # Highest 20% are anomalous
            if dataset=="kdd":
                per = tf.contrib.distributions.percentile(energy, 80)
            else:
                per = tf.contrib.distributions.percentile(energy, 80)
            y_pred = tf.greater_equal(energy, per)
           
    with tf.name_scope('summary'):
        with tf.name_scope('loss_summary'):
            tf.summary.scalar('loss_rec', loss_rec, ['loss'])
            tf.summary.scalar('mean_energy', loss_energy, ['loss'])
            tf.summary.scalar('penalty', penalty, ['loss'])
            tf.summary.scalar('full_loss', full_loss, ['loss'])

        sum_op_loss = tf.summary.merge_all('loss')

    # Data
    logger.info('Data loading...')

    trainx, trainy = data.get_train(label)
    trainx_copy = trainx.copy()
    testx, testy = data.get_test(label)
    
    if model.params["is_image"] and METHOD=="pca":
       logger.info('PCA...')
       trainx = trainx.reshape([trainx.shape[0], -1])
       testx = testx.reshape([testx.shape[0], -1])
       trainx, testx = image_features.pca(trainx, testx, 20)
       logger.info('Done')

    rng = np.random.RandomState(RANDOM_SEED)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)
    

    logdir = create_logdir(dataset, K, l1, l2, label, random_seed)

    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None,
                             save_model_secs=10)

    logger.info('Start training...')
    with sv.managed_session() as sess:

        logger.info('Initialization done')
        writer = tf.summary.FileWriter(logdir, sess.graph)
        train_batch = 0
        epoch = 0

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

             # construct randomly permuted minibatches
            trainx = trainx[rng.permutation(trainx.shape[0])]  # shuffling dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]
            train_loss_rec, train_loss = [0, 0]

            # training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * batch_size
                ran_to = ran_from + batch_size
                feed_dict = {x_pl: trainx[ran_from:ran_to], 
                             is_training_pl: True,
                             learning_rate:lr}

                _, lrec, loss, sm, step = sess.run([train_op,
                                              loss_rec,
                                              full_loss,
                                              sum_op_loss,
                                              global_step],
                                              feed_dict=feed_dict)
                train_loss_rec += lrec
                train_loss += loss
                writer.add_summary(sm, step)#train_batch)

                if np.isnan(loss):
                    logger.info("Loss is nan - Stopping")
                    break

                train_batch += 1
       
            if np.isnan(loss):
                logger.info("Loss is nan - Stopping")
                break

            train_loss_rec /= nr_batches_train
            train_loss /= nr_batches_train

            logger.info('Epoch terminated')
            print("Epoch %d | time = %ds | loss rec = %.4f "
                  "| loss = %.4f"
                  % (epoch, time.time() - begin, train_loss_rec,
                     train_loss))

            epoch += 1


        logger.warning('Testing evaluation...')
         
        inds = rng.permutation(testx.shape[0])
    
        ##TESTING PER BATCHS
        inference_time = [] 
        scores = []
        for t in range(nr_batches_test+1):
            ran_from = t * batch_size
            ran_to = min(ran_from + batch_size, testx.shape[0])
            feed_dict = {x_pl: testx[ran_from:ran_to], 
                         is_training_pl: False}
            begin_val = time.time()
            if l1>0:
                scoresb, step = sess.run([energy, global_step], feed_dict=feed_dict)
            else:
                scoresb, step = sess.run([rec_error, global_step], feed_dict=feed_dict)
            scores.append(scoresb)
            inference_time.append(time.time() - begin_val)
        scores = np.concatenate(scores, axis = 0)

        logger.warning('Testing : inference time is %.4f' % (
            np.mean(inference_time)))

        #scores = np.array(scores)
        save_results(scores, testy, 'dagmm/K{}'.format(K), dataset, None, str(l1)+"_"+str(l2), label,
                     random_seed, step)
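gmm.compute_energy_and_penalty is a project module that is not shown in these snippets. In DAGMM it estimates mixture parameters from the estimator's soft assignments and scores each sample by its negative log-likelihood under that mixture, plus a penalty on degenerate covariances. A rough NumPy sketch of that computation (the function name, the epsilon smoothing, and returning a summed penalty are assumptions):

import numpy as np

def gmm_energy_sketch(z, gamma, eps=1e-6):
    # z: (N, D) latent vectors, gamma: (N, K) soft membership probabilities.
    N, D = z.shape
    phi = gamma.mean(axis=0)                                  # mixture weights (K,)
    mu = gamma.T @ z / gamma.sum(axis=0)[:, None]             # component means (K, D)
    diff = z[:, None, :] - mu[None, :, :]                     # (N, K, D)
    cov = np.einsum('nk,nkd,nke->kde', gamma, diff, diff)
    cov /= gamma.sum(axis=0)[:, None, None]                   # covariances (K, D, D)
    cov += eps * np.eye(D)[None, :, :]                        # numerical stability
    inv = np.linalg.inv(cov)
    # log N(z; mu_k, cov_k) for every sample / component pair
    maha = np.einsum('nkd,kde,nke->nk', diff, inv, diff)
    _, logdet = np.linalg.slogdet(cov)                        # (K,)
    log_prob = -0.5 * (maha + logdet[None, :] + D * np.log(2 * np.pi))
    # energy = -log sum_k phi_k N(z; mu_k, cov_k)
    energy = -np.log(np.exp(log_prob) @ phi + eps)            # (N,)
    # penalty on small diagonal covariance entries
    penalty = np.sum(1.0 / np.diagonal(cov, axis1=1, axis2=2))
    return energy, penalty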
Example #11
0
    def train_epoch(self):
        begin = time()
        # Attach the epoch loop to a variable
        loop = tqdm(range(self.config.data_loader.num_iter_per_epoch))
        # Define the lists for summaries and losses
        gen_losses = []
        disc_losses = []
        disc_xz_losses = []
        disc_xx_losses = []
        disc_zz_losses = []
        summaries = []
        # Get the current epoch counter
        cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
        image = self.data.image
        for _ in loop:
            loop.set_description("Epoch:{}".format(cur_epoch + 1))
            loop.refresh()  # to show immediately the update
            sleep(0.01)
            lg, ld, ldxz, ldxx, ldzz, sum_g, sum_d = self.train_step(
                image, cur_epoch)
            gen_losses.append(lg)
            disc_losses.append(ld)
            disc_xz_losses.append(ldxz)
            disc_xx_losses.append(ldxx)
            disc_zz_losses.append(ldzz)
            summaries.append(sum_g)
            summaries.append(sum_d)
        self.logger.info("Epoch {} terminated".format(cur_epoch))
        self.summarizer.add_tensorboard(step=cur_epoch, summaries=summaries)
        # Check for reconstruction
        if cur_epoch % self.config.log.frequency_test == 0:
            image_eval = self.sess.run(image)
            feed_dict = {
                self.model.image_input: image_eval,
                self.model.is_training: False
            }
            reconstruction = self.sess.run(self.model.sum_op_im,
                                           feed_dict=feed_dict)
            self.summarizer.add_tensorboard(step=cur_epoch,
                                            summaries=[reconstruction])
        # Get the means of the loss values to display
        gl_m = np.mean(gen_losses)
        dl_m = np.mean(disc_losses)
        dlxz_m = np.mean(disc_xz_losses)
        dlxx_m = np.mean(disc_xx_losses)
        dlzz_m = np.mean(disc_zz_losses)
        if self.config.trainer.allow_zz:
            self.logger.info(
                "Epoch {} | time = {} | loss gen = {:4f} |"
                "loss dis = {:4f} | loss dis xz = {:4f} | loss dis xx = {:4f} | "
                "loss dis zz = {:4f}".format(cur_epoch,
                                             time() - begin, gl_m, dl_m,
                                             dlxz_m, dlxx_m, dlzz_m))
        else:
            self.logger.info(
                "Epoch {} | time = {} | loss gen = {:4f} | "
                "loss dis = {:4f} | loss dis xz = {:4f} | loss dis xx = {:4f} | "
                .format(cur_epoch,
                        time() - begin, gl_m, dl_m, dlxz_m, dlxx_m))
        # Save the model state
        # self.model.save(self.sess)

        if (
                cur_epoch + 1
        ) % self.config.trainer.frequency_eval == 0 and self.config.trainer.enable_early_stop:
            valid_loss = 0
            image_valid = self.sess.run(self.data.valid_image)

            feed_dict = {
                self.model.image_input: image_valid,
                self.model.is_training: False
            }
            vl = self.sess.run([self.model.rec_error_valid],
                               feed_dict=feed_dict)
            valid_loss += vl[0]
            if self.config.log.enable_summary:
                sm = self.sess.run(self.model.sum_op_valid,
                                   feed_dict=feed_dict)
                self.summarizer.add_tensorboard(step=cur_epoch,
                                                summaries=[sm],
                                                summarizer="valid")

            self.logger.info(
                "Validation: valid loss {:.4f}".format(valid_loss))
            if (valid_loss < self.best_valid_loss
                    or cur_epoch == self.config.trainer.frequency_eval - 1):
                self.best_valid_loss = valid_loss
                self.logger.info(
                    "Best model - valid loss = {:.4f} - saving...".format(
                        self.best_valid_loss))
                # Save the model state
                self.model.save(self.sess)
                self.nb_without_improvements = 0
            else:
                self.nb_without_improvements += self.config.trainer.frequency_eval
            if self.nb_without_improvements > self.config.trainer.patience:
                self.patience_lost = True
                self.logger.warning(
                    "Early stopping at epoch {} with weights from epoch {}".
                    format(cur_epoch,
                           cur_epoch - self.nb_without_improvements))

        self.logger.warning("Testing evaluation...")
        scores = []
        inference_time = []
        true_labels = []
        # Create the scores
        test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
        for _ in test_loop:
            test_batch_begin = time()
            test_batch, test_labels = self.sess.run(
                [self.data.test_image, self.data.test_label])
            test_loop.refresh()  # to show immediately the update
            sleep(0.01)
            feed_dict = {
                self.model.image_input: test_batch,
                self.model.is_training: False
            }
            scores += self.sess.run(self.model.score,
                                    feed_dict=feed_dict).tolist()
            inference_time.append(time() - test_batch_begin)
            true_labels += test_labels.tolist()
        true_labels = np.asarray(true_labels)
        inference_time = np.mean(inference_time)
        self.logger.info(
            "Testing: Mean inference time is {:4f}".format(inference_time))
        scores = np.asarray(scores)
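        # Min-max scale the anomaly scores to [0, 1] before saving.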
        scores_scaled = (scores - scores.min()) / (scores.max() - scores.min())
        step = self.sess.run(self.model.global_step_tensor)
        save_results(
            self.config.log.result_dir,
            scores_scaled,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "fm",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
        )
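Example #11 interleaves the early-stopping bookkeeping with the training and test loops. The sketch below isolates that logic in a standalone helper, assuming validation runs every frequency_eval epochs; the class name EarlyStopping and its update method are illustrative and not part of the original trainer.

# Minimal sketch of the early-stopping bookkeeping used in train_epoch above,
# extracted as a standalone helper. Names are illustrative only.
class EarlyStopping:
    def __init__(self, patience, frequency_eval):
        self.patience = patience
        self.frequency_eval = frequency_eval
        self.best_valid_loss = float("inf")
        self.nb_without_improvements = 0
        self.patience_lost = False

    def update(self, valid_loss):
        """Return True if the current model is the best so far and should be saved."""
        if valid_loss < self.best_valid_loss:
            self.best_valid_loss = valid_loss
            self.nb_without_improvements = 0
            return True
        # Validation only runs every frequency_eval epochs, so count in that unit.
        self.nb_without_improvements += self.frequency_eval
        if self.nb_without_improvements > self.patience:
            self.patience_lost = True
        return False

In the trainer above, a call such as update(valid_loss) would replace the manual best_valid_loss / nb_without_improvements bookkeeping, with patience_lost signalling early termination.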
Example #12
0
    def test_epoch(self):
        # Evaluation for the testing
        self.logger.info("Testing evaluation...")
        rect_x, rec_error, rec_error2, latent, scores_1, scores_2 = [], [], [], [], [], []
        inference_time = []
        true_labels = []
        summaries = []
        # Create the scores
        test_loop = tqdm(range(self.config.data_loader.num_iter_per_test))
        cur_epoch = self.model.cur_epoch_tensor.eval(self.sess)
        for _ in test_loop:
            begin_val_batch = time()
            test_batch, test_labels = self.sess.run(
                [self.data.test_image, self.data.test_label])
            test_loop.refresh()  # to show immediately the update
            sleep(0.01)
            noise = np.random.normal(
                loc=0.0,
                scale=1.0,
                size=[self.config.data_loader.test_batch, self.noise_dim])
            feed_dict = {
                self.model.image_input: test_batch,
                self.model.noise_tensor: noise,
                self.model.is_training: False,
            }
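            # Optimise the latent code for this batch with repeated inversion steps.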
            for _ in range(self.config.trainer.steps_number):
                _ = self.sess.run(self.model.invert_op, feed_dict=feed_dict)

            brect_x, brec_error, brec_error2, bscores_1, bscores_2, blatent = self.sess.run(
                [
                    self.model.rec_gen_ema,
                    self.model.reconstruction_score_1,
                    self.model.reconstruction_score_2,
                    self.model.loss_invert_1,
                    self.model.loss_invert_2,
                    self.model.z_optim,
                ],
                feed_dict=feed_dict,
            )
            rect_x.append(brect_x)
            rec_error.append(brec_error)
            rec_error2.append(brec_error2)
            scores_1.append(bscores_1)
            scores_2.append(bscores_2)
            latent.append(blatent)
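            # Re-initialise the test-time inversion graph before the next batch.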
            self.sess.run(self.model.reinit_test_graph_op)
            inference_time.append(time() - begin_val_batch)
            true_labels += test_labels.tolist()
            summaries += self.sess.run([self.model.sum_op_im_test],
                                       feed_dict=feed_dict)
        true_labels = np.asarray(true_labels)
        inference_time = np.mean(inference_time)
        self.summarizer.add_tensorboard(step=cur_epoch,
                                        summaries=summaries,
                                        summarizer="test")
        self.logger.info(
            "Testing: Mean inference time is {:4f}".format(inference_time))
        rect_x = np.concatenate(rect_x, axis=0)
        rec_error = np.concatenate(rec_error, axis=0)
        rec_error2 = np.concatenate(rec_error2, axis=0)
        scores_1 = np.concatenate(scores_1, axis=0)
        scores_2 = np.concatenate(scores_2, axis=0)
        latent = np.concatenate(latent, axis=0)
        step = self.sess.run(self.model.global_step_tensor)
        percentiles = np.asarray(self.config.trainer.percentiles)
        save_results(
            self.config.log.result_dir,
            scores_1,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "fm_1",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
        save_results(
            self.config.log.result_dir,
            scores_2,
            true_labels,
            self.config.model.name,
            self.config.data_loader.dataset_name,
            "fm_2",
            "paper",
            self.config.trainer.label,
            self.config.data_loader.random_seed,
            self.logger,
            step,
            percentile=percentiles,
        )
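Example #12 relies on invert_op and z_optim to iteratively optimise a latent code so that the generator's reconstruction matches each test batch before the reconstruction scores are read out. The numpy sketch below illustrates that inversion idea with a stand-in linear generator; the function name invert_latent, the linear model and the step size are assumptions, whereas the real code optimises through a trained GAN generator inside the TensorFlow graph.

# Illustrative numpy sketch of latent inversion: gradient descent on z so
# that G(z) matches the input batch. The linear "generator" (z @ weights)
# and the learning rate are assumptions made for this sketch.
import numpy as np

def invert_latent(x, weights, steps=200, lr=0.05):
    """Find z minimising ||x - z @ weights||^2 for a linear stand-in generator."""
    rng = np.random.default_rng(0)
    z = rng.normal(size=(x.shape[0], weights.shape[0]))
    for _ in range(steps):
        residual = z @ weights - x                 # reconstruction error per sample
        grad = 2.0 * residual @ weights.T          # gradient of squared error w.r.t. z
        z -= lr * grad
    rec_error = np.sum((z @ weights - x) ** 2, axis=1)
    return z, rec_error

The per-sample rec_error returned here plays the role of reconstruction_score_1 / reconstruction_score_2 in the example: after the inversion loop, samples that the generator cannot reconstruct well receive high scores and are treated as anomalies.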