Example #1
def train_and_test(dataset, nb_epochs, weight, degree, random_seed, label,
                   enable_sm, score_method, enable_early_stop):
    """ Runs the AnoGAN on the specified dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (str): name of the dataset
        nb_epochs (int): number of epochs
        weight (float): weight in the inverting loss function
        degree (int): degree of the norm in the feature matching
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
        enable_sm (bool): allow TF summaries for monitoring the training
        score_method (str): which metric to use for the ablation study
        enable_early_stop (bool): allow early stopping for determining the number of epochs
    """

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    logger = logging.getLogger("AnoGAN.run.{}.{}".format(dataset, label))

    # Import model and data
    network = importlib.import_module('anogan.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.999

    # Data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    if enable_early_stop:
        validx, validy = data.get_valid(label)
        nr_batches_valid = int(validx.shape[0] / batch_size)
    trainx_copy = trainx.copy()
    testx, testy = data.get_test(label)

    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Placeholders
    x_pl = tf.placeholder(tf.float32,
                          shape=(None, trainx.shape[1]),
                          name="input")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    rng = np.random.RandomState(random_seed)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info('Building graph...')
    logger.warn("The GAN is training with the following parameters:")
    display_parameters(batch_size, starting_lr, ema_decay, weight, degree,
                       label, trainx.shape[1])

    gen = network.generator
    dis = network.discriminator

    # Sample noise from random normal distribution
    random_z = tf.random_normal([batch_size, latent_dim],
                                mean=0.0,
                                stddev=1.0,
                                name='random_z')
    # Generate images with generator
    x_gen = gen(x_pl, random_z, is_training=is_training_pl)

    real_d, inter_layer_real = dis(x_pl, is_training=is_training_pl)
    fake_d, inter_layer_fake = dis(x_gen,
                                   is_training=is_training_pl,
                                   reuse=True)

    with tf.name_scope('loss_functions'):
        # Calculate separate losses for discriminator with real and fake images
        real_discriminator_loss = tf.losses.sigmoid_cross_entropy(
            tf.ones_like(real_d), real_d, scope='real_discriminator_loss')
        fake_discriminator_loss = tf.losses.sigmoid_cross_entropy(
            tf.zeros_like(fake_d), fake_d, scope='fake_discriminator_loss')
        # Add discriminator losses
        loss_discriminator = real_discriminator_loss + fake_discriminator_loss
        # Calculate loss for generator by flipping label on discriminator output
        loss_generator = tf.losses.sigmoid_cross_entropy(
            tf.ones_like(fake_d), fake_d, scope='generator_loss')

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='discriminator')
        gvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='generator')

        update_ops_gen = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='generator')
        update_ops_dis = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='discriminator')

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.5)

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer.minimize(loss_generator,
                                        var_list=gvars,
                                        global_step=global_step)

        with tf.control_dependencies(update_ops_dis):
            dis_op = optimizer.minimize(loss_discriminator, var_list=dvars)

        # Exponential Moving Average for inference
        def train_op_with_ema_dependency(vars, op):
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            maintain_averages_op = ema.apply(vars)
            with tf.control_dependencies([op]):
                train_op = tf.group(maintain_averages_op)
            return train_op, ema

        train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op)
        train_dis_op, dis_ema = train_op_with_ema_dependency(dvars, dis_op)

    ### Testing ###
    with tf.variable_scope("latent_variable"):
        z_optim = tf.get_variable(
            name='z_optim',
            shape=[batch_size, latent_dim],
            initializer=tf.truncated_normal_initializer())
        reinit_z = z_optim.initializer
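        # z_optim is the latent code optimized at test time to invert the
        # generator (AnoGAN inversion); reinit_z resets it before each new
        # batch of samples is scored.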

    # EMA
    x_gen_ema = gen(x_pl,
                    random_z,
                    is_training=is_training_pl,
                    getter=get_getter(gen_ema),
                    reuse=True)
    rec_x_ema = gen(x_pl,
                    z_optim,
                    is_training=is_training_pl,
                    getter=get_getter(gen_ema),
                    reuse=True)
    # Pass real and fake images into discriminator separately
    real_d_ema, inter_layer_real_ema = dis(x_pl,
                                           is_training=is_training_pl,
                                           getter=get_getter(gen_ema),
                                           reuse=True)
    fake_d_ema, inter_layer_fake_ema = dis(rec_x_ema,
                                           is_training=is_training_pl,
                                           getter=get_getter(gen_ema),
                                           reuse=True)

    with tf.name_scope('Testing'):
        with tf.variable_scope('Reconstruction_loss'):
            delta = x_pl - rec_x_ema
            delta_flat = tf.contrib.layers.flatten(delta)
            reconstruction_score = tf.norm(delta_flat,
                                           ord=degree,
                                           axis=1,
                                           keep_dims=False,
                                           name='epsilon')

        with tf.variable_scope('Discriminator_scores'):

            if score_method == 'cross-e':
                dis_score = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(fake_d_ema), logits=fake_d_ema)

            else:
                fm = inter_layer_real_ema - inter_layer_fake_ema
                fm = tf.contrib.layers.flatten(fm)
                dis_score = tf.norm(fm,
                                    ord=degree,
                                    axis=1,
                                    keep_dims=False,
                                    name='d_loss')

            dis_score = tf.squeeze(dis_score)

        with tf.variable_scope('Score'):
            loss_invert = weight * reconstruction_score \
                                  + (1 - weight) * dis_score
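            # Per-sample anomaly score:
            #   weight * ||x - G(z)||_degree + (1 - weight) * dis_score
            # where dis_score is either the cross-entropy or the
            # feature-matching distance selected by score_method.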

    rec_error_valid = tf.reduce_mean(loss_invert)

    with tf.variable_scope("Test_learning_rate"):
        step_lr = tf.Variable(0, trainable=False)
        learning_rate_invert = 0.001
        reinit_lr = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope="Test_learning_rate"))

    with tf.name_scope('Test_optimizer'):
        invert_op = tf.train.AdamOptimizer(learning_rate_invert).\
            minimize(loss_invert,global_step=step_lr, var_list=[z_optim],
                     name='optimizer')
        reinit_optim = tf.variables_initializer(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope='Test_optimizer'))

    reinit_test_graph_op = [reinit_z, reinit_lr, reinit_optim]

    with tf.name_scope("Scores"):
        list_scores = loss_invert

    if enable_sm:
        with tf.name_scope('training_summary'):
            with tf.name_scope('dis_summary'):
                tf.summary.scalar('real_discriminator_loss',
                                  real_discriminator_loss, ['dis'])
                tf.summary.scalar('fake_discriminator_loss',
                                  fake_discriminator_loss, ['dis'])
                tf.summary.scalar('discriminator_loss', loss_discriminator,
                                  ['dis'])

            with tf.name_scope('gen_summary'):
                tf.summary.scalar('loss_generator', loss_generator, ['gen'])

            with tf.name_scope('img_summary'):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image('heatmap_latent',
                                                 heatmap_pl_latent)

            with tf.name_scope('validation_summary'):
                tf.summary.scalar('valid', rec_error_valid, ['v'])

            if dataset in IMAGES_DATASETS:
                with tf.name_scope('image_summary'):
                    tf.summary.image('reconstruct', x_gen, 8, ['image'])
                    tf.summary.image('input_images', x_pl, 8, ['image'])

            else:
                heatmap_pl_rec = tf.placeholder(tf.float32,
                                                shape=(1, 480, 640, 3),
                                                name="heatmap_pl_rec")
                with tf.name_scope('image_summary'):
                    tf.summary.image('heatmap_rec', heatmap_pl_rec, 1,
                                     ['image'])

            sum_op_dis = tf.summary.merge_all('dis')
            sum_op_gen = tf.summary.merge_all('gen')
            sum_op = tf.summary.merge([sum_op_dis, sum_op_gen])
            sum_op_im = tf.summary.merge_all('image')
            sum_op_valid = tf.summary.merge_all('v')

    logdir = create_logdir(dataset, weight, label, random_seed)

    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=None,
                             save_model_secs=None)

    logger.info('Start training...')
    with sv.managed_session(config=config) as sess:

        logger.info('Initialization done')

        writer = tf.summary.FileWriter(logdir, sess.graph)

        train_batch = 0
        epoch = 0
        best_valid_loss = 0

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

            trainx = trainx[rng.permutation(
                trainx.shape[0])]  # shuffling unlabeled dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]

            train_loss_dis, train_loss_gen = [0, 0]
            # training
            for t in range(nr_batches_train):
                display_progression_epoch(t, nr_batches_train)

                # construct randomly permuted minibatches
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train discriminator
                feed_dict = {
                    x_pl: trainx[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, ld, step = sess.run(
                    [train_dis_op, loss_discriminator, global_step],
                    feed_dict=feed_dict)
                train_loss_dis += ld

                # train generator
                feed_dict = {
                    x_pl: trainx_copy[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, lg = sess.run([train_gen_op, loss_generator],
                                 feed_dict=feed_dict)
                train_loss_gen += lg

                if enable_sm:
                    sm = sess.run(sum_op, feed_dict=feed_dict)
                    writer.add_summary(sm, step)

                    if t % FREQ_PRINT == 0:  # inspect reconstruction
                        # t = np.random.randint(0,400)
                        # ran_from = t
                        # ran_to = t + batch_size
                        # sm = sess.run(sum_op_im, feed_dict={x_pl: trainx[ran_from:ran_to],is_training_pl: False})
                        # writer.add_summary(sm, train_batch)

                        # data = sess.run(z_gen, feed_dict={
                        #     x_pl: trainx[ran_from:ran_to],
                        #     is_training_pl: False})
                        # data = np.expand_dims(heatmap(data), axis=0)
                        # sml = sess.run(sum_op_latent, feed_dict={
                        #     heatmap_pl_latent: data,
                        #     is_training_pl: False})
                        #
                        # writer.add_summary(sml, train_batch)

                        if dataset in IMAGES_DATASETS:
                            sm = sess.run(sum_op_im,
                                          feed_dict={
                                              x_pl: trainx[ran_from:ran_to],
                                              is_training_pl: False
                                          })
                            #
                            # else:
                            #     data = sess.run(z_gen, feed_dict={
                            #         x_pl: trainx[ran_from:ran_to],
                            #         z_pl: np.random.normal(
                            #             size=[batch_size, latent_dim]),
                            #         is_training_pl: False})
                            #     data = np.expand_dims(heatmap(data), axis=0)
                            #     sm = sess.run(sum_op_im, feed_dict={
                            #         heatmap_pl_rec: data,
                            #         is_training_pl: False})
                            writer.add_summary(sm, step)  #train_batch)
                train_batch += 1

            train_loss_gen /= nr_batches_train
            train_loss_dis /= nr_batches_train

            # logger.info('Epoch terminated')
            # print("Epoch %d | time = %ds | loss gen = %.4f | loss dis = %.4f "
            #       % (epoch, time.time() - begin, train_loss_gen, train_loss_dis))

            ##EARLY STOPPING
            if (epoch + 1) % FREQ_EV == 0 and enable_early_stop:
                logger.info('Validation...')

                inds = rng.permutation(validx.shape[0])
                validx = validx[inds]  # shuffling  dataset
                validy = validy[inds]  # shuffling  dataset

                valid_loss = 0

                # Create scores
                for t in range(nr_batches_valid):
                    # construct randomly permuted minibatches
                    display_progression_epoch(t, nr_batches_valid)
                    ran_from = t * batch_size
                    ran_to = (t + 1) * batch_size

                    feed_dict = {
                        x_pl: validx[ran_from:ran_to],
                        is_training_pl: False
                    }
                    for _ in range(STEPS_NUMBER):
                        _ = sess.run(invert_op, feed_dict=feed_dict)
                    vl = sess.run(rec_error_valid, feed_dict=feed_dict)
                    valid_loss += vl
                    sess.run(reinit_test_graph_op)

                valid_loss /= nr_batches_valid
                sess.run(reinit_test_graph_op)

                if enable_sm:
                    sm = sess.run(sum_op_valid, feed_dict=feed_dict)
                    writer.add_summary(sm, step)  # train_batch)

                logger.info('Validation: valid loss {:.4f}'.format(valid_loss))

                if valid_loss < best_valid_loss or epoch == FREQ_EV - 1:

                    best_valid_loss = valid_loss
                    logger.info(
                        "Best model - valid loss = {:.4f} - saving...".format(
                            best_valid_loss))
                    sv.saver.save(sess,
                                  logdir + '/model.ckpt',
                                  global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                        "Early stopping at epoch {} with weights from epoch {}"
                        .format(epoch, epoch - nb_without_improvements))

            epoch += 1

        logger.warn('Testing evaluation...')
        step = sess.run(global_step)
        sv.saver.save(sess, logdir + '/model.ckpt', global_step=step)

        rect_x, rec_error, latent, scores = [], [], [], []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):
            # construct randomly permuted minibatches
            display_progression_epoch(t, nr_batches_test)
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False}

            for _ in range(STEPS_NUMBER):
                _ = sess.run(invert_op, feed_dict=feed_dict)

            brect_x, brec_error, bscores, blatent = sess.run(
                [rec_x_ema, reconstruction_score, loss_invert, z_optim],
                feed_dict=feed_dict)
            rect_x.append(brect_x)
            rec_error.append(brec_error)
            scores.append(bscores)
            latent.append(blatent)
            sess.run(reinit_test_graph_op)

            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

        if testx.shape[0] % batch_size != 0:
            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch, is_training_pl: False}
            for _ in range(STEPS_NUMBER):
                _ = sess.run(invert_op, feed_dict=feed_dict)
            brect_x, brec_error, bscores, blatent = sess.run(
                [rec_x_ema, reconstruction_score, loss_invert, z_optim],
                feed_dict=feed_dict)
            rect_x.append(brect_x[:size])
            rec_error.append(brec_error[:size])
            scores.append(bscores[:size])
            latent.append(blatent[:size])
            sess.run(reinit_test_graph_op)

        rect_x = np.concatenate(rect_x, axis=0)
        rec_error = np.concatenate(rec_error, axis=0)
        scores = np.concatenate(scores, axis=0)
        latent = np.concatenate(latent, axis=0)
        save_results(scores, testy, 'anogan', dataset, score_method, weight,
                     label, random_seed)
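
# A minimal usage sketch, not part of the original example: it assumes this
# function lives in a script that also defines FREQ_PRINT, FREQ_EV,
# STEPS_NUMBER, PATIENCE, IMAGES_DATASETS and the helper functions it calls.
# All argument values below are purely illustrative.
if __name__ == '__main__':
    train_and_test(dataset='kdd',        # expects anogan/kdd_utilities.py and data/kdd.py
                   nb_epochs=100,
                   weight=0.9,           # weight in the inverting loss
                   degree=1,             # norm degree for feature matching
                   random_seed=42,
                   label=0,              # class treated as normal
                   enable_sm=True,       # write TensorBoard summaries
                   score_method='fm',    # anything but 'cross-e' uses feature matching
                   enable_early_stop=True)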
Example #2
def train_and_test(dataset, nb_epochs, degree, random_seed, label,
                   allow_zz, enable_sm, score_method,
                   enable_early_stop, do_spectral_norm):
    """ Runs the AliCE on the specified dataset

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (str): name of the dataset
        nb_epochs (int): number of epochs
        degree (int): degree of the norm in the feature matching
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
        allow_zz (bool): allow the d_zz discriminator or not for ablation study
        enable_sm (bool): allow TF summaries for monitoring the training
        score_method (str): which metric to use for the ablation study
        enable_early_stop (bool): allow early stopping for determining the number of epochs
        do_spectral_norm (bool): allow spectral norm or not for ablation study
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    logger = logging.getLogger("ALAD.run.{}.{}".format(
        dataset, label))

    # Import model and data
    network = importlib.import_module('alad.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.999

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Placeholders
    x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(),
                          name="input_x")
    z_pl = tf.placeholder(tf.float32, shape=[None, latent_dim],
                          name="input_z")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    # Data
    # labels have already been converted: 0 = normal data, 1 = abnormal data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    if enable_early_stop: validx, validy = data.get_valid(label)
    trainx_copy = trainx.copy()
    testx, testy = data.get_test(label)

    rng = np.random.RandomState(random_seed)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info('Building graph...')

    logger.warn("ALAD is training with the following parameters:")
    display_parameters(batch_size, starting_lr, ema_decay, degree, label,
                       allow_zz, score_method, do_spectral_norm)

    gen = network.decoder
    enc = network.encoder
    dis_xz = network.discriminator_xz
    dis_xx = network.discriminator_xx
    dis_zz = network.discriminator_zz
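    # ALAD wires together an encoder E, a decoder/generator G and three
    # discriminators: D_xz matches (x, E(x)) against (G(z), z), while D_xx and
    # D_zz enforce cycle consistency in data space and latent space.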

    with tf.variable_scope('encoder_model'):
        z_gen = enc(x_pl, is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        x_gen = gen(z_pl, is_training=is_training_pl)
        rec_x = gen(z_gen, is_training=is_training_pl, reuse=True)

    with tf.variable_scope('encoder_model'):
        rec_z = enc(x_gen, is_training=is_training_pl, reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xz'):
        l_encoder, inter_layer_inp_xz = dis_xz(x_pl, z_gen,
                                            is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        l_generator, inter_layer_rct_xz = dis_xz(x_gen, z_pl,
                                              is_training=is_training_pl,
                                              reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xx'):
        x_logit_real, inter_layer_inp_xx = dis_xx(x_pl, x_pl,
                                                  is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        x_logit_fake, inter_layer_rct_xx = dis_xx(x_pl, rec_x, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_zz'):
        z_logit_real, _ = dis_zz(z_pl, z_pl, is_training=is_training_pl,
                                 do_spectral_norm=do_spectral_norm)
        z_logit_fake, _ = dis_zz(z_pl, rec_z, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.name_scope('loss_functions'):

        # discriminator xz
        loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_encoder),logits=l_encoder))
        loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_generator),logits=l_generator))
        dis_loss_xz = loss_dis_gen + loss_dis_enc

        # discriminator xx
        x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.ones_like(x_logit_real))
        x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake))
        dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis)

        # discriminator zz
        z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.ones_like(z_logit_real))
        z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake))
        dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis)

        loss_discriminator = dis_loss_xz + dis_loss_xx + dis_loss_zz if \
            allow_zz else dis_loss_xz + dis_loss_xx

        # generator and encoder
        gen_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_generator),logits=l_generator))
        enc_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_encoder), logits=l_encoder))
        x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.zeros_like(x_logit_real))
        x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.ones_like(x_logit_fake))
        z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.zeros_like(z_logit_real))
        z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.ones_like(z_logit_fake))

        cost_x = tf.reduce_mean(x_real_gen + x_fake_gen)
        cost_z = tf.reduce_mean(z_real_gen + z_fake_gen)

        cycle_consistency_loss = cost_x + cost_z if allow_zz else cost_x
        loss_generator = gen_loss_xz + cycle_consistency_loss
        loss_encoder = enc_loss_xz + cycle_consistency_loss
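        # Generator and encoder objectives: each combines its adversarial term
        # against D_xz with the cycle-consistency cost from D_xx (cost_x) and,
        # when allow_zz is set, from D_zz (cost_z).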

    with tf.name_scope('optimizers'):

        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dxzvars = [var for var in tvars if 'discriminator_model_xz' in var.name]
        dxxvars = [var for var in tvars if 'discriminator_model_xx' in var.name]
        dzzvars = [var for var in tvars if 'discriminator_model_zz' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]
        evars = [var for var in tvars if 'encoder_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)]
        update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)]
        update_ops_dis_xz = [x for x in update_ops if
                             ('discriminator_model_xz' in x.name)]
        update_ops_dis_xx = [x for x in update_ops if
                             ('discriminator_model_xx' in x.name)]
        update_ops_dis_zz = [x for x in update_ops if
                             ('discriminator_model_zz' in x.name)]

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                  beta1=0.5)

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer.minimize(loss_generator, var_list=gvars,
                                            global_step=global_step)
        with tf.control_dependencies(update_ops_enc):
            enc_op = optimizer.minimize(loss_encoder, var_list=evars)

        with tf.control_dependencies(update_ops_dis_xz):
            dis_op_xz = optimizer.minimize(dis_loss_xz, var_list=dxzvars)

        with tf.control_dependencies(update_ops_dis_xx):
            dis_op_xx = optimizer.minimize(dis_loss_xx, var_list=dxxvars)

        with tf.control_dependencies(update_ops_dis_zz):
            dis_op_zz = optimizer.minimize(dis_loss_zz, var_list=dzzvars)

        # Exponential Moving Average for inference
        def train_op_with_ema_dependency(vars, op):
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            maintain_averages_op = ema.apply(vars)
            with tf.control_dependencies([op]):
                train_op = tf.group(maintain_averages_op)
            return train_op, ema

        train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op)
        train_enc_op, enc_ema = train_op_with_ema_dependency(evars, enc_op)
        train_dis_op_xz, xz_ema = train_op_with_ema_dependency(dxzvars,
                                                               dis_op_xz)
        train_dis_op_xx, xx_ema = train_op_with_ema_dependency(dxxvars,
                                                               dis_op_xx)
        train_dis_op_zz, zz_ema = train_op_with_ema_dependency(dzzvars,
                                                               dis_op_zz)

    with tf.variable_scope('encoder_model'):
        z_gen_ema = enc(x_pl, is_training=is_training_pl,
                        getter=get_getter(enc_ema), reuse=True,
                        do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        rec_x_ema = gen(z_gen_ema, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)
        x_gen_ema = gen(z_pl, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)

    with tf.variable_scope('discriminator_model_xx'):
        l_encoder_emaxx, inter_layer_inp_emaxx = dis_xx(x_pl, x_pl,
                                                    is_training=is_training_pl,
                                                    getter=get_getter(xx_ema),
                                                    reuse=True,
                    do_spectral_norm=do_spectral_norm)

        l_generator_emaxx, inter_layer_rct_emaxx = dis_xx(x_pl, rec_x_ema,
                                                      is_training=is_training_pl,
                                                      getter=get_getter(
                                                          xx_ema),
                                                      reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.name_scope('Testing'):

        with tf.variable_scope('Scores'):


            score_ch = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(l_generator_emaxx),
                    logits=l_generator_emaxx)
            score_ch = tf.squeeze(score_ch)

            rec = x_pl - rec_x_ema
            rec = tf.contrib.layers.flatten(rec)
            score_l1 = tf.norm(rec, ord=1, axis=1,
                            keep_dims=False, name='d_loss')
            score_l1 = tf.squeeze(score_l1)

            rec = x_pl - rec_x_ema
            rec = tf.contrib.layers.flatten(rec)
            score_l2 = tf.norm(rec, ord=2, axis=1,
                            keep_dims=False, name='d_loss')
            score_l2 = tf.squeeze(score_l2)

            inter_layer_inp, inter_layer_rct = inter_layer_inp_emaxx, \
                                               inter_layer_rct_emaxx
            fm = inter_layer_inp - inter_layer_rct
            fm = tf.contrib.layers.flatten(fm)
            score_fm = tf.norm(fm, ord=degree, axis=1,
                             keep_dims=False, name='d_loss')
            score_fm = tf.squeeze(score_fm)
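            # Four per-sample anomaly scores for the ablation study:
            # 'ch'      - cross-entropy on the D_xx logits of (x, rec_x_ema),
            # 'l1'/'l2' - norms of the reconstruction residual x - rec_x_ema,
            # 'fm'      - feature-matching distance between D_xx intermediate layers.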

    if enable_early_stop:
        rec_error_valid = tf.reduce_mean(score_fm)

    if enable_sm:

        with tf.name_scope('summary'):
            with tf.name_scope('dis_summary'):
                tf.summary.scalar('loss_discriminator', loss_discriminator, ['dis'])
                tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis'])
                tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis'])
                tf.summary.scalar('loss_dis_xz', dis_loss_xz, ['dis'])
                tf.summary.scalar('loss_dis_xx', dis_loss_xx, ['dis'])
                if allow_zz:
                    tf.summary.scalar('loss_dis_zz', dis_loss_zz, ['dis'])

            with tf.name_scope('gen_summary'):
                tf.summary.scalar('loss_generator', loss_generator, ['gen'])
                tf.summary.scalar('loss_encoder', loss_encoder, ['gen'])
                tf.summary.scalar('loss_encgen_dxx', cost_x, ['gen'])
                if allow_zz:
                    tf.summary.scalar('loss_encgen_dzz', cost_z, ['gen'])

            if enable_early_stop:
                with tf.name_scope('validation_summary'):
                   tf.summary.scalar('valid', rec_error_valid, ['v'])

            with tf.name_scope('img_summary'):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent)

            if dataset in IMAGES_DATASETS:
                with tf.name_scope('image_summary'):
                    tf.summary.image('reconstruct', rec_x, 8, ['image'])
                    tf.summary.image('input_images', x_pl, 8, ['image'])

            else:
                heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3),
                                            name="heatmap_pl_rec")
                with tf.name_scope('image_summary'):
                    tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image'])

            sum_op_dis = tf.summary.merge_all('dis')
            sum_op_gen = tf.summary.merge_all('gen')
            sum_op = tf.summary.merge([sum_op_dis, sum_op_gen])
            sum_op_im = tf.summary.merge_all('image')
            sum_op_valid = tf.summary.merge_all('v')

    logdir = create_logdir(dataset, label, random_seed, allow_zz, score_method,
                           do_spectral_norm)

    saver = tf.train.Saver(max_to_keep=2)
    save_model_secs = None if enable_early_stop else 20
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None, saver=saver, save_model_secs=save_model_secs) 

    logger.info('Start training...')
    with sv.managed_session(config=config) as sess:

        step = sess.run(global_step)
        logger.info('Initialization done at epoch {}'.format(step // nr_batches_train))
        writer = tf.summary.FileWriter(logdir, sess.graph)
        train_batch = 0
        epoch = 0
        best_valid_loss = 0
        request_stop = False

        while not sv.should_stop() and epoch < nb_epochs:

            lr = starting_lr
            begin = time.time()

            # construct randomly permuted minibatches
            trainx = trainx[rng.permutation(trainx.shape[0])]  # shuffling dataset
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]
            train_loss_dis_xz, train_loss_dis_xx,  train_loss_dis_zz, \
            train_loss_dis, train_loss_gen, train_loss_enc = [0, 0, 0, 0, 0, 0]

            # Training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train discriminator
                feed_dict = {x_pl: trainx[ran_from:ran_to],
                             z_pl: np.random.normal(size=[batch_size, latent_dim]),
                             is_training_pl: True,
                             learning_rate:lr}

                _, _, _, ld, ldxz, ldxx, ldzz, step = sess.run([train_dis_op_xz,
                                                              train_dis_op_xx,
                                                              train_dis_op_zz,
                                                              loss_discriminator,
                                                              dis_loss_xz,
                                                              dis_loss_xx,
                                                              dis_loss_zz,
                                                              global_step],
                                                             feed_dict=feed_dict)
                train_loss_dis += ld
                train_loss_dis_xz += ldxz
                train_loss_dis_xx += ldxx
                train_loss_dis_zz += ldzz

                # train generator and encoder
                feed_dict = {x_pl: trainx_copy[ran_from:ran_to],
                             z_pl: np.random.normal(size=[batch_size, latent_dim]),
                             is_training_pl: True,
                             learning_rate:lr}
                _,_, le, lg = sess.run([train_gen_op,
                                            train_enc_op,
                                            loss_encoder,
                                            loss_generator],
                                           feed_dict=feed_dict)
                train_loss_gen += lg
                train_loss_enc += le

                if enable_sm:
                    sm = sess.run(sum_op, feed_dict=feed_dict)
                    writer.add_summary(sm, step)

                    if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS:  # inspect reconstruction
                        t = np.random.randint(0, trainx.shape[0]-batch_size)
                        ran_from = t
                        ran_to = t + batch_size
                        feed_dict = {x_pl: trainx[ran_from:ran_to],
                            z_pl: np.random.normal(
                                size=[batch_size, latent_dim]),
                            is_training_pl: False}
                        sm = sess.run(sum_op_im, feed_dict=feed_dict)
                        writer.add_summary(sm, step)#train_batch)

                train_batch += 1

            train_loss_gen /= nr_batches_train
            train_loss_enc /= nr_batches_train
            train_loss_dis /= nr_batches_train
            train_loss_dis_xz /= nr_batches_train
            train_loss_dis_xx /= nr_batches_train
            train_loss_dis_zz /= nr_batches_train

            logger.info('Epoch terminated')
            if allow_zz:
                print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | "
                      "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | "
                      "loss dis zz = %.4f"
                      % (epoch, time.time() - begin, train_loss_gen,
                         train_loss_enc, train_loss_dis, train_loss_dis_xz,
                         train_loss_dis_xx, train_loss_dis_zz))
            else:
                print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | "
                      "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | "
                      % (epoch, time.time() - begin, train_loss_gen,
                         train_loss_enc, train_loss_dis, train_loss_dis_xz,
                         train_loss_dis_xx))

            ##EARLY STOPPING
            if (epoch + 1) % FREQ_EV == 0 and enable_early_stop:

                valid_loss = 0
                feed_dict = {x_pl: validx,
                             z_pl: np.random.normal(size=[validx.shape[0], latent_dim]),
                             is_training_pl: False}
                vl, lat = sess.run([rec_error_valid, rec_z], feed_dict=feed_dict)
                valid_loss += vl

                if enable_sm:
                    sm = sess.run(sum_op_valid, feed_dict=feed_dict)
                    writer.add_summary(sm, step)  # train_batch)

                logger.info('Validation: valid loss {:.4f}'.format(valid_loss))

                if (valid_loss < best_valid_loss or epoch == FREQ_EV-1):
                    best_valid_loss = valid_loss
                    logger.info("Best model - valid loss = {:.4f} - saving...".format(best_valid_loss))
                    sv.saver.save(sess, logdir+'/model.ckpt', global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                      "Early stopping at epoch {} with weights from epoch {}".format(
                          epoch, epoch - nb_without_improvements))

            epoch += 1

        sv.saver.save(sess, logdir+'/model.ckpt', global_step=step)

        logger.warn('Testing evaluation...')

        scores_ch = []
        scores_l1 = []
        scores_l2 = []
        scores_fm = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):

            # construct randomly permuted minibatches
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_test_time_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to],
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl:False}

            scores_ch += sess.run(score_ch, feed_dict=feed_dict).tolist()
            scores_l1 += sess.run(score_l1, feed_dict=feed_dict).tolist()
            scores_l2 += sess.run(score_l2, feed_dict=feed_dict).tolist()
            scores_fm += sess.run(score_fm, feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_test_time_batch)


        inference_time = np.mean(inference_time)
        logger.info('Testing : mean inference time is %.4f' % (inference_time))

        if testx.shape[0] % batch_size != 0:

            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch,
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl: False}

            bscores_ch = sess.run(score_ch,feed_dict=feed_dict).tolist()
            bscores_l1 = sess.run(score_l1,feed_dict=feed_dict).tolist()
            bscores_l2 = sess.run(score_l2,feed_dict=feed_dict).tolist()
            bscores_fm = sess.run(score_fm,feed_dict=feed_dict).tolist()


            scores_ch += bscores_ch[:size]
            scores_l1 += bscores_l1[:size]
            scores_l2 += bscores_l2[:size]
            scores_fm += bscores_fm[:size]

        model = 'alad_sn{}_dzz{}'.format(do_spectral_norm, allow_zz)
        save_results(scores_ch, testy, model, dataset, 'ch',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_l1, testy, model, dataset, 'l1',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_l2, testy, model, dataset, 'l2',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, step)
        save_results(scores_fm, testy, model, dataset, 'fm',
                     'dzzenabled{}'.format(allow_zz), label, random_seed,  step)
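
# A minimal usage sketch, not part of the original example; it mirrors the
# AnoGAN call above and assumes the same kind of module-level constants and
# helpers. Argument values are illustrative only.
if __name__ == '__main__':
    train_and_test(dataset='kdd', nb_epochs=100, degree=1, random_seed=42,
                   label=0, allow_zz=True, enable_sm=True, score_method='fm',
                   enable_early_stop=True, do_spectral_norm=True)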
Example #3
def test(dataset, nb_epochs, degree, random_seed, label, allow_zz, enable_sm,
         score_method, enable_early_stop, do_spectral_norm):
    """ Restores a trained ALAD model and scores flow records read from CSV
    files in ../package_file_temp_out/ (arguments mirror train_and_test above).
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    logger = logging.getLogger("ALAD.run.{}.{}".format(
    dataset, label))

    # Import model and data
    network = importlib.import_module('alad.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size
    latent_dim = network.latent_dim
    ema_decay = 0.999

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Placeholders
    x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(),
                          name="input_x")
    z_pl = tf.placeholder(tf.float32, shape=[None, latent_dim],
                          name="input_z")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    # Data
    logger.info('Data loading...')
  
    #testx, testy = data.get_test(label)

    df = pd.DataFrame(columns = ['Flow ID', 'Src IP', 'Src Port', 'Dst IP', 'Dst Port', 'Protocol', 'Timestamp', 'Flow Duration', 'Tot Fwd Pkts', 'Tot Bwd Pkts', 'TotLen Fwd Pkts', 'TotLen Bwd Pkts', 
        'Fwd Pkt Len Max', 'Fwd Pkt Len Min', 'Fwd Pkt Len Mean', 'Fwd Pkt Len Std', 'Bwd Pkt Len Max', 'Bwd Pkt Len Min', 'Bwd Pkt Len Mean', 'Bwd Pkt Len Std', 'Flow Byts/s', 'Flow Pkts/s', 'Flow IAT Mean', 
        'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Tot', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Tot', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 
        'Bwd PSH Flags', 'Fwd URG Flags', 'Bwd URG Flags', 'Fwd Header Len', 'Bwd Header Len', 'Fwd Pkts/s', 'Bwd Pkts/s', 'Pkt Len Min', 'Pkt Len Max', 'Pkt Len Mean', 'Pkt Len Std', 'Pkt Len Var', 'FIN Flag Cnt', 'SYN Flag Cnt', 
        'RST Flag Cnt', 'PSH Flag Cnt', 'ACK Flag Cnt', 'URG Flag Cnt', 'CWE Flag Count', 'ECE Flag Cnt', 'Down/Up Ratio', 'Pkt Size Avg', 'Fwd Seg Size Avg', 'Bwd Seg Size Avg', 'Fwd Byts/b Avg', 'Fwd Pkts/b Avg', 
        'Fwd Blk Rate Avg', 'Bwd Byts/b Avg', 'Bwd Pkts/b Avg', 'Bwd Blk Rate Avg', 'Subflow Fwd Pkts', 'Subflow Fwd Byts', 'Subflow Bwd Pkts', 'Subflow Bwd Byts', 'Init Fwd Win Byts', 'Init Bwd Win Byts', 
        'Fwd Act Data Pkts', 'Fwd Seg Size Min', 'Active Mean', 'Active Std', 'Active Max', 'Active Min', 'Idle Mean', 'Idle Std', 'Idle Max', 'Idle Min', 'Label'])
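    # The columns above appear to be the flow features emitted by CICFlowMeter
    # (as used in the CIC-IDS datasets); the CSVs read below are expected to
    # follow this schema.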


    filenames = os.listdir("../package_file_temp_out/")
    os.chdir('../package_file_temp_out/')
    for x in filenames:
        df_temp = pd.read_csv(x)
        df = pd.concat([df, df_temp], axis=0, ignore_index=True)

    os.chdir('../')
    result = []
    for x in df.columns:
        if x != 'Label':
            result.append(x)
    
    record = df[result].values

    df.drop('Flow ID', axis=1, inplace=True)
    df.drop('Src IP', axis=1, inplace=True)
    df.drop('Dst IP', axis=1, inplace=True)
    df.drop('Timestamp', axis=1, inplace=True)
    df.drop('Flow Byts/s', axis=1, inplace=True)
    df.drop('Flow Pkts/s', axis=1, inplace=True)

    element = ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp', 'Flow Byts/s', 'Flow Pkts/s']
    for x in element:
        result.remove(x)
    testx = df[result].values.astype(np.float32)

    rng = np.random.RandomState(random_seed)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info('Building graph...')

    gen = network.decoder
    enc = network.encoder
    dis_xz = network.discriminator_xz
    dis_xx = network.discriminator_xx
    dis_zz = network.discriminator_zz

    with tf.variable_scope('encoder_model'):
        z_gen = enc(x_pl, is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        x_gen = gen(z_pl, is_training=is_training_pl)
        rec_x = gen(z_gen, is_training=is_training_pl, reuse=True)

    with tf.variable_scope('encoder_model'):
        rec_z = enc(x_gen, is_training=is_training_pl, reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xz'):
        l_encoder, inter_layer_inp_xz = dis_xz(x_pl, z_gen,
                                            is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        l_generator, inter_layer_rct_xz = dis_xz(x_gen, z_pl,
                                              is_training=is_training_pl,
                                              reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_xx'):
        x_logit_real, inter_layer_inp_xx = dis_xx(x_pl, x_pl,
                                                  is_training=is_training_pl,
                    do_spectral_norm=do_spectral_norm)
        x_logit_fake, inter_layer_rct_xx = dis_xx(x_pl, rec_x, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('discriminator_model_zz'):
        z_logit_real, _ = dis_zz(z_pl, z_pl, is_training=is_training_pl,
                                 do_spectral_norm=do_spectral_norm)
        z_logit_fake, _ = dis_zz(z_pl, rec_z, is_training=is_training_pl,
                              reuse=True, do_spectral_norm=do_spectral_norm)

    with tf.name_scope('loss_functions'):

        # discriminator xz
        loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_encoder),logits=l_encoder))
        loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_generator),logits=l_generator))
        dis_loss_xz = loss_dis_gen + loss_dis_enc

        # discriminator xx
        x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.ones_like(x_logit_real))
        x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake))
        dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis)

        # discriminator zz
        z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.ones_like(z_logit_real))
        z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake))
        dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis)

        loss_discriminator = dis_loss_xz + dis_loss_xx + dis_loss_zz if \
            allow_zz else dis_loss_xz + dis_loss_xx

        # generator and encoder
        gen_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(l_generator),logits=l_generator))
        enc_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(l_encoder), logits=l_encoder))
        x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_real, labels=tf.zeros_like(x_logit_real))
        x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit_fake, labels=tf.ones_like(x_logit_fake))
        z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_real, labels=tf.zeros_like(z_logit_real))
        z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=z_logit_fake, labels=tf.ones_like(z_logit_fake))

        cost_x = tf.reduce_mean(x_real_gen + x_fake_gen)
        cost_z = tf.reduce_mean(z_real_gen + z_fake_gen)

        cycle_consistency_loss = cost_x + cost_z if allow_zz else cost_x
        loss_generator = gen_loss_xz + cycle_consistency_loss
        loss_encoder = enc_loss_xz + cycle_consistency_loss

    with tf.name_scope('optimizers'):

        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dxzvars = [var for var in tvars if 'discriminator_model_xz' in var.name]
        dxxvars = [var for var in tvars if 'discriminator_model_xx' in var.name]
        dzzvars = [var for var in tvars if 'discriminator_model_zz' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]
        evars = [var for var in tvars if 'encoder_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)]
        update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)]
        update_ops_dis_xz = [x for x in update_ops if
                             ('discriminator_model_xz' in x.name)]
        update_ops_dis_xx = [x for x in update_ops if
                             ('discriminator_model_xx' in x.name)]
        update_ops_dis_zz = [x for x in update_ops if
                             ('discriminator_model_zz' in x.name)]

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                  beta1=0.5)

        with tf.control_dependencies(update_ops_gen):
            gen_op = optimizer.minimize(loss_generator, var_list=gvars,
                                            global_step=global_step)
        with tf.control_dependencies(update_ops_enc):
            enc_op = optimizer.minimize(loss_encoder, var_list=evars)

        with tf.control_dependencies(update_ops_dis_xz):
            dis_op_xz = optimizer.minimize(dis_loss_xz, var_list=dxzvars)

        with tf.control_dependencies(update_ops_dis_xx):
            dis_op_xx = optimizer.minimize(dis_loss_xx, var_list=dxxvars)

        with tf.control_dependencies(update_ops_dis_zz):
            dis_op_zz = optimizer.minimize(dis_loss_zz, var_list=dzzvars)

        # Exponential Moving Average for inference
        def train_op_with_ema_dependency(vars, op):
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            maintain_averages_op = ema.apply(vars)
            with tf.control_dependencies([op]):
                train_op = tf.group(maintain_averages_op)
            return train_op, ema

        train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op)
        train_enc_op, enc_ema = train_op_with_ema_dependency(evars, enc_op)
        train_dis_op_xz, xz_ema = train_op_with_ema_dependency(dxzvars,
                                                               dis_op_xz)
        train_dis_op_xx, xx_ema = train_op_with_ema_dependency(dxxvars,
                                                               dis_op_xx)
        train_dis_op_zz, zz_ema = train_op_with_ema_dependency(dzzvars,
                                                               dis_op_zz)

    with tf.variable_scope('encoder_model'):
        z_gen_ema = enc(x_pl, is_training=is_training_pl,
                        getter=get_getter(enc_ema), reuse=True,
                        do_spectral_norm=do_spectral_norm)

    with tf.variable_scope('generator_model'):
        rec_x_ema = gen(z_gen_ema, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)
        x_gen_ema = gen(z_pl, is_training=is_training_pl,
                              getter=get_getter(gen_ema), reuse=True)

    with tf.variable_scope('discriminator_model_xx'):
        l_encoder_emaxx, inter_layer_inp_emaxx = dis_xx(x_pl, x_pl,
                                                    is_training=is_training_pl,
                                                    getter=get_getter(xx_ema),
                                                    reuse=True,
                    do_spectral_norm=do_spectral_norm)

        l_generator_emaxx, inter_layer_rct_emaxx = dis_xx(x_pl, rec_x_ema,
                                                      is_training=is_training_pl,
                                                      getter=get_getter(
                                                          xx_ema),
                                                      reuse=True,
                    do_spectral_norm=do_spectral_norm)

    with tf.name_scope('Testing'):

        with tf.variable_scope('Scores'):
            # One anomaly score per test sample.
            # 'ch': cross-entropy of the xx-discriminator on the reconstruction
            score_ch = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(l_generator_emaxx),
                    logits=l_generator_emaxx)
            score_ch = tf.squeeze(score_ch)

            # 'l1' / 'l2': norms of the reconstruction error
            rec = x_pl - rec_x_ema
            rec = tf.contrib.layers.flatten(rec)
            score_l1 = tf.norm(rec, ord=1, axis=1,
                               keep_dims=False, name='d_loss')
            score_l1 = tf.squeeze(score_l1)

            score_l2 = tf.norm(rec, ord=2, axis=1,
                               keep_dims=False, name='d_loss')
            score_l2 = tf.squeeze(score_l2)

            # 'fm': feature-matching distance between the discriminator's
            # intermediate layers, with the norm degree given by `degree`
            fm = inter_layer_inp_emaxx - inter_layer_rct_emaxx
            fm = tf.contrib.layers.flatten(fm)
            score_fm = tf.norm(fm, ord=degree, axis=1,
                               keep_dims=False, name='d_loss')
            score_fm = tf.squeeze(score_fm)

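    # When early stopping is enabled, the mean feature-matching score over a
    # batch is tracked as the validation metric.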
    if enable_early_stop:
        rec_error_valid = tf.reduce_mean(score_fm)

    if enable_sm:

        with tf.name_scope('summary'):
            with tf.name_scope('dis_summary'):
                tf.summary.scalar('loss_discriminator', loss_discriminator, ['dis'])
                tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis'])
                tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis'])
                tf.summary.scalar('loss_dis_xz', dis_loss_xz, ['dis'])
                tf.summary.scalar('loss_dis_xx', dis_loss_xx, ['dis'])
                if allow_zz:
                    tf.summary.scalar('loss_dis_zz', dis_loss_zz, ['dis'])

            with tf.name_scope('gen_summary'):
                tf.summary.scalar('loss_generator', loss_generator, ['gen'])
                tf.summary.scalar('loss_encoder', loss_encoder, ['gen'])
                tf.summary.scalar('loss_encgen_dxx', cost_x, ['gen'])
                if allow_zz:
                    tf.summary.scalar('loss_encgen_dzz', cost_z, ['gen'])

            if enable_early_stop:
                with tf.name_scope('validation_summary'):
                    tf.summary.scalar('valid', rec_error_valid, ['v'])

            with tf.name_scope('img_summary'):
                heatmap_pl_latent = tf.placeholder(tf.float32,
                                                   shape=(1, 480, 640, 3),
                                                   name="heatmap_pl_latent")
                sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent)

            if dataset in IMAGES_DATASETS:
                with tf.name_scope('image_summary'):
                    tf.summary.image('reconstruct', rec_x, 8, ['image'])
                    tf.summary.image('input_images', x_pl, 8, ['image'])

            else:
                heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3),
                                            name="heatmap_pl_rec")
                with tf.name_scope('image_summary'):
                    tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image'])

            sum_op_dis = tf.summary.merge_all('dis')
            sum_op_gen = tf.summary.merge_all('gen')
            sum_op = tf.summary.merge([sum_op_dis, sum_op_gen])
            sum_op_im = tf.summary.merge_all('image')
            sum_op_valid = tf.summary.merge_all('v')

    logdir = "ALAD/train_logs/cicids2017/alad_snTrue_dzzTrue/dzzenabledTrue/fm/label0/rd43"
    saver = tf.train.Saver(max_to_keep=1)
    save_model_secs = None if enable_early_stop else 20
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None, saver=saver, save_model_secs=save_model_secs) 


    with sv.managed_session(config=config) as sess:

        # Restore the latest checkpoint if one exists; keep step = 0 otherwise
        # so that the save_results calls below still receive a valid step.
        step = 0
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            logger.info(ckpt.model_checkpoint_path)
            file = os.path.basename(ckpt.model_checkpoint_path)
            checkpoint_path = os.path.join(logdir, file)
            saver.restore(sess, checkpoint_path)


        scores_ch = []
        scores_l1 = []
        scores_l2 = []
        scores_fm = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):

            # slice the next test minibatch (z is resampled for each batch)
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_test_time_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to],
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl: False}

            scores_ch += sess.run(score_ch, feed_dict=feed_dict).tolist()
            scores_l1 += sess.run(score_l1, feed_dict=feed_dict).tolist()
            scores_l2 += sess.run(score_l2, feed_dict=feed_dict).tolist()
            scores_fm += sess.run(score_fm, feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_test_time_batch)


        inference_time = np.mean(inference_time)
        logger.info('Testing : mean inference time is %.4f' % (inference_time))

        if testx.shape[0] % batch_size != 0:
            # fill the last incomplete batch, then keep only the scores of the
            # real test samples
            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch,
                         z_pl: np.random.normal(size=[batch_size, latent_dim]),
                         is_training_pl: False}

            bscores_ch = sess.run(score_ch,feed_dict=feed_dict).tolist()
            bscores_l1 = sess.run(score_l1,feed_dict=feed_dict).tolist()
            bscores_l2 = sess.run(score_l2,feed_dict=feed_dict).tolist()
            bscores_fm = sess.run(score_fm,feed_dict=feed_dict).tolist()


            scores_ch += bscores_ch[:size]
            scores_l1 += bscores_l1[:size]
            scores_l2 += bscores_l2[:size]
            scores_fm += bscores_fm[:size]

        model = 'alad_sn{}_dzz{}'.format(do_spectral_norm, allow_zz)
        save_results(scores_ch, record, model, dataset, 'ch',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False)
        save_results(scores_l1, record, model, dataset, 'l1',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False)
        save_results(scores_l2, record, model, dataset, 'l2',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False)
        save_results(scores_fm, record, model, dataset, 'fm',
                     'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False)
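
The four save_results calls above persist one anomaly score per test sample (cross-entropy, L1 and L2 reconstruction norms, and feature matching). As a minimal, hypothetical sketch of how such score lists could be evaluated offline, assuming scikit-learn is available, that higher scores indicate anomalies, and that testy holds binary ground-truth labels (the helper evaluate_scores is illustrative and not part of the example above):

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

def evaluate_scores(scores, labels):
    # labels: 1 = anomalous, 0 = normal; higher score = more anomalous
    scores, labels = np.asarray(scores), np.asarray(labels)
    return roc_auc_score(labels, scores), average_precision_score(labels, scores)

# e.g. auc, ap = evaluate_scores(scores_fm, testy)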
Example #4
0
def train_and_test(dataset, nb_epochs, random_seed, label):
    """ Runs DSEBM on available datasets 

    Note:
        Saves summaries on tensorboard. To display them, please use cmd line
        tensorboard --logdir=model.training_logdir() --port=number
    Args:
        dataset (string): dataset to run the model on 
        nb_epochs (int): number of epochs
        random_seed (int): trying different seeds for averaging the results
        label (int): label which is normal for image experiments
        anomaly_type (string): "novelty" for 100% normal samples in the training set
                               "outlier" for a contamined training set 
        anomaly_proportion (float): if "outlier", anomaly proportion in the training set
    """
    logger = logging.getLogger("DSEBM.run.{}.{}".format(dataset, label))

    # Import model and data
    network = importlib.import_module('dsebm.{}_utilities'.format(dataset))
    data = importlib.import_module("data.{}".format(dataset))

    # Parameters
    starting_lr = network.learning_rate
    batch_size = network.batch_size

    # Placeholders
    x_pl = tf.placeholder(tf.float32,
                          shape=data.get_shape_input(),
                          name="input")
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl")

    # placeholder for ground-truth labels (only needed by the commented-out
    # prediction metrics below)
    y_true = tf.placeholder(tf.int32, shape=[None], name="y_true")

    logger.info('Building training graph...')
    logger.warn("The DSEBM is training with the following parameters:")
    display_parameters(batch_size, starting_lr, label)

    net = network.network

    global_step = tf.train.get_or_create_global_step()

    noise = tf.random_normal(shape=tf.shape(x_pl),
                             mean=0.0,
                             stddev=1.,
                             dtype=tf.float32)
    x_noise = x_pl + noise

    with tf.variable_scope('network'):
        # b_prime is a learned bias with the same shape as one batch of inputs
        b_prime_shape = list(data.get_shape_input())
        b_prime_shape[0] = batch_size
        b_prime = tf.get_variable(name='b_prime', shape=b_prime_shape)
        net_out = net(x_pl, is_training=is_training_pl)
        net_out_noise = net(x_noise, is_training=is_training_pl, reuse=True)

    with tf.name_scope('energies'):
        # DSEBM energy: E(x) = 1/2 * ||x - b'||^2 - sum(f(x)),
        # where f is the network output and b' the learned bias above
        energy = 0.5 * tf.reduce_sum(tf.square(x_pl - b_prime)) \
                 - tf.reduce_sum(net_out)

        energy_noise = 0.5 * tf.reduce_sum(tf.square(x_noise - b_prime)) \
                       - tf.reduce_sum(net_out_noise)

    with tf.name_scope('reconstructions'):
        # reconstruction: one step along the negative energy gradient;
        # tf.gradients returns a list, hence the squeeze over the leading axis
        grad = tf.gradients(energy, x_pl)
        fx = x_pl - grad
        fx = tf.squeeze(fx, axis=0)
        fx_noise = x_noise - tf.gradients(energy_noise, x_noise)

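    # Training objective (denoising / score matching): the reconstruction of
    # the noisy input should recover the clean input.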
    with tf.name_scope("loss_function"):
        # DSEBM for images
        if len(data.get_shape_input()) == 4:
            loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(x_pl - fx_noise), axis=[1, 2, 3]))
        # DSEBM for tabular data
        else:
            loss = tf.reduce_mean(tf.square(x_pl - fx_noise))

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        netvars = [var for var in tvars if 'network' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_net = [x for x in update_ops if ('network' in x.name)]

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           name='optimizer')

        with tf.control_dependencies(update_ops_net):
            train_op = optimizer.minimize(loss,
                                          var_list=netvars,
                                          global_step=global_step)

    with tf.variable_scope('Scores'):

        with tf.name_scope('Energy_score'):
            flat = tf.layers.flatten(x_pl - b_prime)
            if len(data.get_shape_input()) == 4:
                list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \
                                           - tf.reduce_sum(net_out, axis=[1, 2, 3])
            else:
                list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \
                           - tf.reduce_sum(net_out, axis=1)
        with tf.name_scope('Reconstruction_score'):
            delta = x_pl - fx
            delta_flat = tf.layers.flatten(delta)
            list_scores_reconstruction = tf.norm(delta_flat,
                                                 ord=2,
                                                 axis=1,
                                                 keepdims=False,
                                                 name='reconstruction')

    #   with tf.name_scope('predictions'):
    #       # Highest 20% are anomalous
    #       if dataset=="kdd":
    #           per = tf.contrib.distributions.percentile(list_scores_energy, 80)
    #       else:
    #           per = tf.contrib.distributions.percentile(list_scores_energy, 95)
    #       y_pred = tf.greater_equal(list_scores_energy, per)
    #
    #       #y_test_true = tf.cast(y_test_true, tf.float32)
    #       cm = tf.confusion_matrix(y_true, y_pred, num_classes=2)
    #       recall = cm[1,1]/(cm[1,0]+cm[1,1])
    #       precision = cm[1,1]/(cm[0,1]+cm[1,1])
    #       f1 = 2*precision*recall/(precision + recall)

    with tf.name_scope('training_summary'):

        tf.summary.scalar('score_matching_loss', loss, ['net'])
        tf.summary.scalar('energy', energy, ['net'])

        if dataset in IMAGES_DATASETS:
            with tf.name_scope('image_summary'):
                tf.summary.image('reconstruct', fx, 6, ['image'])
                tf.summary.image('input_images', x_pl, 6, ['image'])
                sum_op_im = tf.summary.merge_all('image')

        sum_op_net = tf.summary.merge_all('net')

    logdir = create_logdir(dataset, label, random_seed)

    sv = tf.train.Supervisor(logdir=logdir + "/train",
                             save_summaries_secs=None,
                             save_model_secs=None)

    # Data
    logger.info('Data loading...')
    trainx, trainy = data.get_train(label)
    trainx_copy = trainx.copy()
    if dataset in IMAGES_DATASETS: validx, validy = data.get_valid(label)
    testx, testy = data.get_test(label)

    rng = np.random.RandomState(RANDOM_SEED)
    nr_batches_train = int(trainx.shape[0] / batch_size)
    if dataset in IMAGES_DATASETS:
        nr_batches_valid = int(validx.shape[0] / batch_size)
    nr_batches_test = int(testx.shape[0] / batch_size)

    logger.info("Train: {} samples in {} batches".format(
        trainx.shape[0], nr_batches_train))
    if dataset in IMAGES_DATASETS:
        logger.info("Valid: {} samples in {} batches".format(
            validx.shape[0], nr_batches_valid))
    logger.info("Test:  {} samples in {} batches".format(
        testx.shape[0], nr_batches_test))

    logger.info('Start training...')
    with sv.managed_session() as sess:
        logger.info('Initialization done')

        train_writer = tf.summary.FileWriter(logdir + "/train", sess.graph)
        valid_writer = tf.summary.FileWriter(logdir + "/valid", sess.graph)

        train_batch = 0
        epoch = 0
        best_valid_loss = 0
        train_losses = [0] * STRIP_EV

        while not sv.should_stop() and epoch < nb_epochs:
            lr = starting_lr

            begin = time.time()
            trainx = trainx[rng.permutation(
                trainx.shape[0])]  # shuffle the training set
            trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])]

            losses, energies = 0, 0
            # training
            for t in range(nr_batches_train):
                display_progression_epoch(t, nr_batches_train)

                # construct randomly permuted minibatches
                ran_from = t * batch_size
                ran_to = (t + 1) * batch_size

                # train the net
                feed_dict = {
                    x_pl: trainx[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate: lr
                }
                _, ld, en, sm, step = sess.run(
                    [train_op, loss, energy, sum_op_net, global_step],
                    feed_dict=feed_dict)
                losses += ld
                energies += en
                train_writer.add_summary(sm, step)

                if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS:
                    # inspect reconstructions on a small random slice
                    idx = np.random.randint(0, 40)
                    sm = sess.run(sum_op_im,
                                  feed_dict={
                                      x_pl: trainx[idx:idx + batch_size],
                                      is_training_pl: False
                                  })
                    train_writer.add_summary(sm, step)

                train_batch += 1

            losses /= nr_batches_train
            energies /= nr_batches_train
            # Remembering loss for early stopping
            train_losses[epoch % STRIP_EV] = losses

            logger.info('Epoch terminated')
            print("Epoch %d | time = %ds | loss = %.4f | energy = %.4f " %
                  (epoch, time.time() - begin, losses, energies))

            if (epoch + 1) % FREQ_SNAP == 0 and dataset in IMAGES_DATASETS:

                print("Take a snap of the reconstructions...")
                x = trainx[:batch_size]
                feed_dict = {x_pl: x, is_training_pl: False}

                rect_x = sess.run(fx, feed_dict=feed_dict)
                nama_e_wa = "dsebm/reconstructions/{}/{}/" \
                            "{}_epoch{}".format(dataset,
                                                label,
                                                random_seed, epoch)
                nb_imgs = 50
                save_grid_plot(x[:nb_imgs], rect_x[:nb_imgs], nama_e_wa,
                               nb_imgs)

            if (epoch + 1) % FREQ_EV == 0 and dataset in IMAGES_DATASETS:
                logger.info("Validation")
                inds = rng.permutation(validx.shape[0])
                validx = validx[inds]  # shuffling  dataset
                validy = validy[inds]  # shuffling  dataset
                valid_loss = 0
                for t in range(nr_batches_valid):
                    display_progression_epoch(t, nr_batches_valid)

                    # construct randomly permuted minibatches
                    ran_from = t * batch_size
                    ran_to = (t + 1) * batch_size

                    # evaluate the net on a validation minibatch
                    feed_dict = {
                        x_pl: validx[ran_from:ran_to],
                        y_true: validy[ran_from:ran_to],
                        is_training_pl: False
                    }

                    vl, sm, step = sess.run([loss, sum_op_net, global_step],
                                            feed_dict=feed_dict)
                    valid_writer.add_summary(sm, step + t)
                    valid_loss += vl

                valid_loss /= nr_batches_valid

                logger.info("Validation loss at step " + str(step) + ": " +
                            str(valid_loss))
                # Early stopping: save the best model and count evaluation
                # rounds without improvement
                if valid_loss < best_valid_loss or epoch == FREQ_EV - 1:
                    best_valid_loss = valid_loss
                    logger.info("Best model - loss={} - saving...".format(
                        best_valid_loss))
                    sv.saver.save(sess,
                                  logdir + '/train/model.ckpt',
                                  global_step=step)
                    nb_without_improvements = 0
                else:
                    nb_without_improvements += FREQ_EV

                if nb_without_improvements > PATIENCE:
                    sv.request_stop()
                    logger.warning(
                        "Early stopping at epoch {} with weights from epoch {}"
                        .format(epoch, epoch - nb_without_improvements))
            epoch += 1

        logger.warn('Testing evaluation...')

        step = sess.run(global_step)
        scores_e = []
        scores_r = []
        inference_time = []

        # Create scores
        for t in range(nr_batches_test):
            # slice the next test minibatch
            ran_from = t * batch_size
            ran_to = (t + 1) * batch_size
            begin_val_batch = time.time()

            feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False}

            scores_e += sess.run(list_scores_energy,
                                 feed_dict=feed_dict).tolist()

            scores_r += sess.run(list_scores_reconstruction,
                                 feed_dict=feed_dict).tolist()
            inference_time.append(time.time() - begin_val_batch)

        logger.info('Testing : mean inference time is %.4f' %
                    (np.mean(inference_time)))

        if testx.shape[0] % batch_size != 0:
            batch, size = batch_fill(testx, batch_size)
            feed_dict = {x_pl: batch, is_training_pl: False}
            batch_score_e = sess.run(list_scores_energy,
                                     feed_dict=feed_dict).tolist()
            batch_score_r = sess.run(list_scores_reconstruction,
                                     feed_dict=feed_dict).tolist()
            scores_e += batch_score_e[:size]
            scores_r += batch_score_r[:size]

        save_results(scores_e, testy, 'dsebm', dataset, 'energy', "test",
                     label, random_seed, step)
        save_results(scores_r, testy, 'dsebm', dataset, 'reconstruction',
                     "test", label, random_seed, step)