def train_and_test(dataset, nb_epochs, weight, degree, random_seed, label, enable_sm, score_method, enable_early_stop): """ Runs the AnoGAN on the specified dataset Note: Saves summaries on tensorboard. To display them, please use cmd line tensorboard --logdir=model.training_logdir() --port=number Args: dataset (str): name of the dataset nb_epochs (int): number of epochs weight (float): weight in the inverting loss function degree (int): degree of the norm in the feature matching random_seed (int): random seed, used to average results over several runs label (int): label which is normal for image experiments enable_sm (bool): allow TF summaries for monitoring the training score_method (str): which metric to use for the ablation study enable_early_stop (bool): allow early stopping for determining the number of epochs """ config = tf.ConfigProto() config.gpu_options.allow_growth = True logger = logging.getLogger("AnoGAN.run.{}.{}".format(dataset, label)) # Import model and data network = importlib.import_module('anogan.{}_utilities'.format(dataset)) data = importlib.import_module("data.{}".format(dataset)) # Parameters starting_lr = network.learning_rate batch_size = network.batch_size latent_dim = network.latent_dim ema_decay = 0.999 # Data logger.info('Data loading...') trainx, trainy = data.get_train(label) if enable_early_stop: validx, validy = data.get_valid(label) nr_batches_valid = int(validx.shape[0] / batch_size) trainx_copy = trainx.copy() testx, testy = data.get_test(label) global_step = tf.Variable(0, name='global_step', trainable=False) # Placeholders x_pl = tf.placeholder(tf.float32, shape=(None, trainx.shape[1]), name="input") is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl") rng = np.random.RandomState(random_seed) nr_batches_train = int(trainx.shape[0] / batch_size) nr_batches_test = int(testx.shape[0] / batch_size) logger.info('Building graph...') logger.warn("The GAN is training with the following parameters:") display_parameters(batch_size, starting_lr, ema_decay, weight, degree, label, trainx.shape[1]) gen = network.generator dis = network.discriminator # Sample noise from random normal distribution random_z = tf.random_normal([batch_size, latent_dim], mean=0.0, stddev=1.0, name='random_z') # Generate images with generator x_gen = gen(x_pl, random_z, is_training=is_training_pl) real_d, inter_layer_real = dis(x_pl, is_training=is_training_pl) fake_d, inter_layer_fake = dis(x_gen, is_training=is_training_pl, reuse=True) with tf.name_scope('loss_functions'): # Calculate separate losses for discriminator with real and fake images real_discriminator_loss = tf.losses.sigmoid_cross_entropy( tf.ones_like(real_d), real_d, scope='real_discriminator_loss') fake_discriminator_loss = tf.losses.sigmoid_cross_entropy( tf.zeros_like(fake_d), fake_d, scope='fake_discriminator_loss') # Add discriminator losses loss_discriminator = real_discriminator_loss + fake_discriminator_loss # Calculate loss for generator by flipping label on discriminator output loss_generator = tf.losses.sigmoid_cross_entropy( tf.ones_like(fake_d), fake_d, scope='generator_loss') with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') gvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') update_ops_gen = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') update_ops_dis =
tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5) with tf.control_dependencies(update_ops_gen): gen_op = optimizer.minimize(loss_generator, var_list=gvars, global_step=global_step) with tf.control_dependencies(update_ops_dis): dis_op = optimizer.minimize(loss_discriminator, var_list=dvars) # Exponential Moving Average for inference def train_op_with_ema_dependency(vars, op): ema = tf.train.ExponentialMovingAverage(decay=ema_decay) maintain_averages_op = ema.apply(vars) with tf.control_dependencies([op]): train_op = tf.group(maintain_averages_op) return train_op, ema train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op) train_dis_op, dis_ema = train_op_with_ema_dependency(dvars, dis_op) ### Testing ### with tf.variable_scope("latent_variable"): z_optim = tf.get_variable( name='z_optim', shape=[batch_size, latent_dim], initializer=tf.truncated_normal_initializer()) reinit_z = z_optim.initializer # EMA x_gen_ema = gen(x_pl, random_z, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) rec_x_ema = gen(x_pl, z_optim, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) # Pass real and fake images into discriminator separately real_d_ema, inter_layer_real_ema = dis(x_pl, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) fake_d_ema, inter_layer_fake_ema = dis(rec_x_ema, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) with tf.name_scope('Testing'): with tf.variable_scope('Reconstruction_loss'): delta = x_pl - rec_x_ema delta_flat = tf.contrib.layers.flatten(delta) reconstruction_score = tf.norm(delta_flat, ord=degree, axis=1, keep_dims=False, name='epsilon') with tf.variable_scope('Discriminator_scores'): if score_method == 'cross-e': dis_score = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(fake_d_ema), logits=fake_d_ema) else: fm = inter_layer_real_ema - inter_layer_fake_ema fm = tf.contrib.layers.flatten(fm) dis_score = tf.norm(fm, ord=degree, axis=1, keep_dims=False, name='d_loss') dis_score = tf.squeeze(dis_score) with tf.variable_scope('Score'): loss_invert = weight * reconstruction_score \ + (1 - weight) * dis_score rec_error_valid = tf.reduce_mean(loss_invert) with tf.variable_scope("Test_learning_rate"): step_lr = tf.Variable(0, trainable=False) learning_rate_invert = 0.001 reinit_lr = tf.variables_initializer( tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="Test_learning_rate")) with tf.name_scope('Test_optimizer'): invert_op = tf.train.AdamOptimizer(learning_rate_invert).\ minimize(loss_invert,global_step=step_lr, var_list=[z_optim], name='optimizer') reinit_optim = tf.variables_initializer( tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Test_optimizer')) reinit_test_graph_op = [reinit_z, reinit_lr, reinit_optim] with tf.name_scope("Scores"): list_scores = loss_invert if enable_sm: with tf.name_scope('training_summary'): with tf.name_scope('dis_summary'): tf.summary.scalar('real_discriminator_loss', real_discriminator_loss, ['dis']) tf.summary.scalar('fake_discriminator_loss', fake_discriminator_loss, ['dis']) tf.summary.scalar('discriminator_loss', loss_discriminator, ['dis']) with tf.name_scope('gen_summary'): tf.summary.scalar('loss_generator', loss_generator, ['gen']) with tf.name_scope('img_summary'): heatmap_pl_latent = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_latent") sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent) with 
tf.name_scope('validation_summary'): tf.summary.scalar('valid', rec_error_valid, ['v']) if dataset in IMAGES_DATASETS: with tf.name_scope('image_summary'): tf.summary.image('reconstruct', x_gen, 8, ['image']) tf.summary.image('input_images', x_pl, 8, ['image']) else: heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_rec") with tf.name_scope('image_summary'): tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op = tf.summary.merge([sum_op_dis, sum_op_gen]) sum_op_im = tf.summary.merge_all('image') sum_op_valid = tf.summary.merge_all('v') logdir = create_logdir(dataset, weight, label, random_seed) sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None, save_model_secs=None) logger.info('Start training...') with sv.managed_session(config=config) as sess: logger.info('Initialization done') writer = tf.summary.FileWriter(logdir, sess.graph) train_batch = 0 epoch = 0 best_valid_loss = 0 while not sv.should_stop() and epoch < nb_epochs: lr = starting_lr begin = time.time() trainx = trainx[rng.permutation( trainx.shape[0])] # shuffling unl dataset trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])] train_loss_dis, train_loss_gen = [0, 0] # training for t in range(nr_batches_train): display_progression_epoch(t, nr_batches_train) # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size # train discriminator feed_dict = { x_pl: trainx[ran_from:ran_to], is_training_pl: True, learning_rate: lr } _, ld, step = sess.run( [train_dis_op, loss_discriminator, global_step], feed_dict=feed_dict) train_loss_dis += ld # train generator feed_dict = { x_pl: trainx_copy[ran_from:ran_to], is_training_pl: True, learning_rate: lr } _, lg = sess.run([train_gen_op, loss_generator], feed_dict=feed_dict) train_loss_gen += lg if enable_sm: sm = sess.run(sum_op, feed_dict=feed_dict) writer.add_summary(sm, step) if t % FREQ_PRINT == 0: # inspect reconstruction # t = np.random.randint(0,400) # ran_from = t # ran_to = t + batch_size # sm = sess.run(sum_op_im, feed_dict={x_pl: trainx[ran_from:ran_to],is_training_pl: False}) # writer.add_summary(sm, train_batch) # data = sess.run(z_gen, feed_dict={ # x_pl: trainx[ran_from:ran_to], # is_training_pl: False}) # data = np.expand_dims(heatmap(data), axis=0) # sml = sess.run(sum_op_latent, feed_dict={ # heatmap_pl_latent: data, # is_training_pl: False}) # # writer.add_summary(sml, train_batch) if dataset in IMAGES_DATASETS: sm = sess.run(sum_op_im, feed_dict={ x_pl: trainx[ran_from:ran_to], is_training_pl: False }) # # else: # data = sess.run(z_gen, feed_dict={ # x_pl: trainx[ran_from:ran_to], # z_pl: np.random.normal( # size=[batch_size, latent_dim]), # is_training_pl: False}) # data = np.expand_dims(heatmap(data), axis=0) # sm = sess.run(sum_op_im, feed_dict={ # heatmap_pl_rec: data, # is_training_pl: False}) writer.add_summary(sm, step) #train_batch) train_batch += 1 train_loss_gen /= nr_batches_train train_loss_dis /= nr_batches_train # logger.info('Epoch terminated') # print("Epoch %d | time = %ds | loss gen = %.4f | loss dis = %.4f " # % (epoch, time.time() - begin, train_loss_gen, train_loss_dis)) ##EARLY STOPPING if (epoch + 1) % FREQ_EV == 0 and enable_early_stop: logger.info('Validation...') inds = rng.permutation(validx.shape[0]) validx = validx[inds] # shuffling dataset validy = validy[inds] # shuffling dataset valid_loss = 0 # Create scores for t in range(nr_batches_valid): # 
construct randomly permuted minibatches display_progression_epoch(t, nr_batches_valid) ran_from = t * batch_size ran_to = (t + 1) * batch_size feed_dict = { x_pl: validx[ran_from:ran_to], is_training_pl: False } for _ in range(STEPS_NUMBER): _ = sess.run(invert_op, feed_dict=feed_dict) vl = sess.run(rec_error_valid, feed_dict=feed_dict) valid_loss += vl sess.run(reinit_test_graph_op) valid_loss /= nr_batches_valid sess.run(reinit_test_graph_op) if enable_sm: sm = sess.run(sum_op_valid, feed_dict=feed_dict) writer.add_summary(sm, step) # train_batch) logger.info('Validation: valid loss {:.4f}'.format(valid_loss)) if valid_loss < best_valid_loss or epoch == FREQ_EV - 1: best_valid_loss = valid_loss logger.info( "Best model - valid loss = {:.4f} - saving...".format( best_valid_loss)) sv.saver.save(sess, logdir + '/model.ckpt', global_step=step) nb_without_improvements = 0 else: nb_without_improvements += FREQ_EV if nb_without_improvements > PATIENCE: sv.request_stop() logger.warning( "Early stopping at epoch {} with weights from epoch {}" .format(epoch, epoch - nb_without_improvements)) epoch += 1 logger.warn('Testing evaluation...') step = sess.run(global_step) sv.saver.save(sess, logdir + '/model.ckpt', global_step=step) rect_x, rec_error, latent, scores = [], [], [], [] inference_time = [] # Create scores for t in range(nr_batches_test): # construct randomly permuted minibatches display_progression_epoch(t, nr_batches_test) ran_from = t * batch_size ran_to = (t + 1) * batch_size begin_val_batch = time.time() feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False} for _ in range(STEPS_NUMBER): _ = sess.run(invert_op, feed_dict=feed_dict) brect_x, brec_error, bscores, blatent = sess.run( [rec_x_ema, reconstruction_score, loss_invert, z_optim], feed_dict=feed_dict) rect_x.append(brect_x) rec_error.append(brec_error) scores.append(bscores) latent.append(blatent) sess.run(reinit_test_graph_op) inference_time.append(time.time() - begin_val_batch) logger.info('Testing : mean inference time is %.4f' % (np.mean(inference_time))) if testx.shape[0] % batch_size != 0: batch, size = batch_fill(testx, batch_size) feed_dict = {x_pl: batch, is_training_pl: False} for _ in range(STEPS_NUMBER): _ = sess.run(invert_op, feed_dict=feed_dict) brect_x, brec_error, bscores, blatent = sess.run( [rec_x_ema, reconstruction_score, loss_invert, z_optim], feed_dict=feed_dict) rect_x.append(brect_x[:size]) rec_error.append(brec_error[:size]) scores.append(bscores[:size]) latent.append(blatent[:size]) sess.run(reinit_test_graph_op) rect_x = np.concatenate(rect_x, axis=0) rec_error = np.concatenate(rec_error, axis=0) scores = np.concatenate(scores, axis=0) latent = np.concatenate(latent, axis=0) save_results(scores, testy, 'anogan', dataset, score_method, weight, label, random_seed)
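# ---------------------------------------------------------------------------
# Illustration (not part of the original code): the anomaly score built in the
# 'Score' scope above is a convex combination of a reconstruction term and a
# discriminator feature term,
#     score(x) = weight * ||x - G(z*)||_degree
#                + (1 - weight) * ||f(x) - f(G(z*))||_degree.
# The minimal NumPy sketch below mirrors that formula for a batch of samples;
# the array arguments are hypothetical stand-ins for rec_x_ema,
# inter_layer_real_ema and inter_layer_fake_ema evaluated in a session.
# ---------------------------------------------------------------------------
import numpy as np  # repeated here so the sketch is self-contained


def anogan_score_np(x, x_rec, feat_x, feat_rec, weight=0.9, degree=1):
    """Per-sample AnoGAN score: weighted reconstruction + feature-matching norm."""
    n = x.shape[0]
    rec_term = np.linalg.norm((x - x_rec).reshape(n, -1), ord=degree, axis=1)
    fm_term = np.linalg.norm((feat_x - feat_rec).reshape(n, -1), ord=degree, axis=1)
    return weight * rec_term + (1.0 - weight) * fm_term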
def train_and_test(dataset, nb_epochs, degree, random_seed, label, allow_zz, enable_sm, score_method, enable_early_stop, do_spectral_norm): """ Runs ALAD on the specified dataset Note: Saves summaries on tensorboard. To display them, please use cmd line tensorboard --logdir=model.training_logdir() --port=number Args: dataset (str): name of the dataset nb_epochs (int): number of epochs degree (int): degree of the norm in the feature matching random_seed (int): random seed, used to average results over several runs label (int): label which is normal for image experiments allow_zz (bool): allow the d_zz discriminator or not for ablation study enable_sm (bool): allow TF summaries for monitoring the training score_method (str): which metric to use for the ablation study enable_early_stop (bool): allow early stopping for determining the number of epochs do_spectral_norm (bool): allow spectral norm or not for ablation study """ config = tf.ConfigProto() config.gpu_options.allow_growth = True logger = logging.getLogger("ALAD.run.{}.{}".format( dataset, label)) # Import model and data network = importlib.import_module('alad.{}_utilities'.format(dataset)) data = importlib.import_module("data.{}".format(dataset)) # Parameters starting_lr = network.learning_rate batch_size = network.batch_size latent_dim = network.latent_dim ema_decay = 0.999 global_step = tf.Variable(0, name='global_step', trainable=False) # Placeholders x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(), name="input_x") z_pl = tf.placeholder(tf.float32, shape=[None, latent_dim], name="input_z") is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl") # Data # labels have already been converted: 0 = normal data, 1 = abnormal data logger.info('Data loading...') trainx, trainy = data.get_train(label) if enable_early_stop: validx, validy = data.get_valid(label) trainx_copy = trainx.copy() testx, testy = data.get_test(label) rng = np.random.RandomState(random_seed) nr_batches_train = int(trainx.shape[0] / batch_size) nr_batches_test = int(testx.shape[0] / batch_size) logger.info('Building graph...') logger.warn("ALAD is training with the following parameters:") display_parameters(batch_size, starting_lr, ema_decay, degree, label, allow_zz, score_method, do_spectral_norm) gen = network.decoder enc = network.encoder dis_xz = network.discriminator_xz dis_xx = network.discriminator_xx dis_zz = network.discriminator_zz with tf.variable_scope('encoder_model'): z_gen = enc(x_pl, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) with tf.variable_scope('generator_model'): x_gen = gen(z_pl, is_training=is_training_pl) rec_x = gen(z_gen, is_training=is_training_pl, reuse=True) with tf.variable_scope('encoder_model'): rec_z = enc(x_gen, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_xz'): l_encoder, inter_layer_inp_xz = dis_xz(x_pl, z_gen, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) l_generator, inter_layer_rct_xz = dis_xz(x_gen, z_pl, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_xx'): x_logit_real, inter_layer_inp_xx = dis_xx(x_pl, x_pl, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) x_logit_fake, inter_layer_rct_xx = dis_xx(x_pl, rec_x, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_zz'): z_logit_real, _ = dis_zz(z_pl, z_pl,
is_training=is_training_pl, do_spectral_norm=do_spectral_norm) z_logit_fake, _ = dis_zz(z_pl, rec_z, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.name_scope('loss_functions'): # discriminator xz loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_encoder),logits=l_encoder)) loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(l_generator),logits=l_generator)) dis_loss_xz = loss_dis_gen + loss_dis_enc # discriminator xx x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_real, labels=tf.ones_like(x_logit_real)) x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake)) dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis) # discriminator zz z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_real, labels=tf.ones_like(z_logit_real)) z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake)) dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis) loss_discriminator = dis_loss_xz + dis_loss_xx + dis_loss_zz if \ allow_zz else dis_loss_xz + dis_loss_xx # generator and encoder gen_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_generator),logits=l_generator)) enc_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(l_encoder), logits=l_encoder)) x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_real, labels=tf.zeros_like(x_logit_real)) x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_fake, labels=tf.ones_like(x_logit_fake)) z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_real, labels=tf.zeros_like(z_logit_real)) z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_fake, labels=tf.ones_like(z_logit_fake)) cost_x = tf.reduce_mean(x_real_gen + x_fake_gen) cost_z = tf.reduce_mean(z_real_gen + z_fake_gen) cycle_consistency_loss = cost_x + cost_z if allow_zz else cost_x loss_generator = gen_loss_xz + cycle_consistency_loss loss_encoder = enc_loss_xz + cycle_consistency_loss with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables tvars = tf.trainable_variables() dxzvars = [var for var in tvars if 'discriminator_model_xz' in var.name] dxxvars = [var for var in tvars if 'discriminator_model_xx' in var.name] dzzvars = [var for var in tvars if 'discriminator_model_zz' in var.name] gvars = [var for var in tvars if 'generator_model' in var.name] evars = [var for var in tvars if 'encoder_model' in var.name] update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)] update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)] update_ops_dis_xz = [x for x in update_ops if ('discriminator_model_xz' in x.name)] update_ops_dis_xx = [x for x in update_ops if ('discriminator_model_xx' in x.name)] update_ops_dis_zz = [x for x in update_ops if ('discriminator_model_zz' in x.name)] optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5) with tf.control_dependencies(update_ops_gen): gen_op = optimizer.minimize(loss_generator, var_list=gvars, global_step=global_step) with tf.control_dependencies(update_ops_enc): enc_op = optimizer.minimize(loss_encoder, var_list=evars) with tf.control_dependencies(update_ops_dis_xz): dis_op_xz = optimizer.minimize(dis_loss_xz, var_list=dxzvars) with 
tf.control_dependencies(update_ops_dis_xx): dis_op_xx = optimizer.minimize(dis_loss_xx, var_list=dxxvars) with tf.control_dependencies(update_ops_dis_zz): dis_op_zz = optimizer.minimize(dis_loss_zz, var_list=dzzvars) # Exponential Moving Average for inference def train_op_with_ema_dependency(vars, op): ema = tf.train.ExponentialMovingAverage(decay=ema_decay) maintain_averages_op = ema.apply(vars) with tf.control_dependencies([op]): train_op = tf.group(maintain_averages_op) return train_op, ema train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op) train_enc_op, enc_ema = train_op_with_ema_dependency(evars, enc_op) train_dis_op_xz, xz_ema = train_op_with_ema_dependency(dxzvars, dis_op_xz) train_dis_op_xx, xx_ema = train_op_with_ema_dependency(dxxvars, dis_op_xx) train_dis_op_zz, zz_ema = train_op_with_ema_dependency(dzzvars, dis_op_zz) with tf.variable_scope('encoder_model'): z_gen_ema = enc(x_pl, is_training=is_training_pl, getter=get_getter(enc_ema), reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('generator_model'): rec_x_ema = gen(z_gen_ema, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) x_gen_ema = gen(z_pl, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) with tf.variable_scope('discriminator_model_xx'): l_encoder_emaxx, inter_layer_inp_emaxx = dis_xx(x_pl, x_pl, is_training=is_training_pl, getter=get_getter(xx_ema), reuse=True, do_spectral_norm=do_spectral_norm) l_generator_emaxx, inter_layer_rct_emaxx = dis_xx(x_pl, rec_x_ema, is_training=is_training_pl, getter=get_getter( xx_ema), reuse=True, do_spectral_norm=do_spectral_norm) with tf.name_scope('Testing'): with tf.variable_scope('Scores'): score_ch = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_generator_emaxx), logits=l_generator_emaxx) score_ch = tf.squeeze(score_ch) rec = x_pl - rec_x_ema rec = tf.contrib.layers.flatten(rec) score_l1 = tf.norm(rec, ord=1, axis=1, keep_dims=False, name='d_loss') score_l1 = tf.squeeze(score_l1) rec = x_pl - rec_x_ema rec = tf.contrib.layers.flatten(rec) score_l2 = tf.norm(rec, ord=2, axis=1, keep_dims=False, name='d_loss') score_l2 = tf.squeeze(score_l2) inter_layer_inp, inter_layer_rct = inter_layer_inp_emaxx, \ inter_layer_rct_emaxx fm = inter_layer_inp - inter_layer_rct fm = tf.contrib.layers.flatten(fm) score_fm = tf.norm(fm, ord=degree, axis=1, keep_dims=False, name='d_loss') score_fm = tf.squeeze(score_fm) if enable_early_stop: rec_error_valid = tf.reduce_mean(score_fm) if enable_sm: with tf.name_scope('summary'): with tf.name_scope('dis_summary'): tf.summary.scalar('loss_discriminator', loss_discriminator, ['dis']) tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis']) tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis']) tf.summary.scalar('loss_dis_xz', dis_loss_xz, ['dis']) tf.summary.scalar('loss_dis_xx', dis_loss_xx, ['dis']) if allow_zz: tf.summary.scalar('loss_dis_zz', dis_loss_zz, ['dis']) with tf.name_scope('gen_summary'): tf.summary.scalar('loss_generator', loss_generator, ['gen']) tf.summary.scalar('loss_encoder', loss_encoder, ['gen']) tf.summary.scalar('loss_encgen_dxx', cost_x, ['gen']) if allow_zz: tf.summary.scalar('loss_encgen_dzz', cost_z, ['gen']) if enable_early_stop: with tf.name_scope('validation_summary'): tf.summary.scalar('valid', rec_error_valid, ['v']) with tf.name_scope('img_summary'): heatmap_pl_latent = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_latent") sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent) if dataset in 
IMAGES_DATASETS: with tf.name_scope('image_summary'): tf.summary.image('reconstruct', rec_x, 8, ['image']) tf.summary.image('input_images', x_pl, 8, ['image']) else: heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_rec") with tf.name_scope('image_summary'): tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op = tf.summary.merge([sum_op_dis, sum_op_gen]) sum_op_im = tf.summary.merge_all('image') sum_op_valid = tf.summary.merge_all('v') logdir = create_logdir(dataset, label, random_seed, allow_zz, score_method, do_spectral_norm) saver = tf.train.Saver(max_to_keep=2) save_model_secs = None if enable_early_stop else 20 sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None, saver=saver, save_model_secs=save_model_secs) logger.info('Start training...') with sv.managed_session(config=config) as sess: step = sess.run(global_step) logger.info('Initialization done at step {}'.format(step/nr_batches_train)) writer = tf.summary.FileWriter(logdir, sess.graph) train_batch = 0 epoch = 0 best_valid_loss = 0 request_stop = False while not sv.should_stop() and epoch < nb_epochs: lr = starting_lr begin = time.time() # construct randomly permuted minibatches trainx = trainx[rng.permutation(trainx.shape[0])] # shuffling dataset trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])] train_loss_dis_xz, train_loss_dis_xx, train_loss_dis_zz, \ train_loss_dis, train_loss_gen, train_loss_enc = [0, 0, 0, 0, 0, 0] # Training for t in range(nr_batches_train): display_progression_epoch(t, nr_batches_train) ran_from = t * batch_size ran_to = (t + 1) * batch_size # train discriminator feed_dict = {x_pl: trainx[ran_from:ran_to], z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl: True, learning_rate:lr} _, _, _, ld, ldxz, ldxx, ldzz, step = sess.run([train_dis_op_xz, train_dis_op_xx, train_dis_op_zz, loss_discriminator, dis_loss_xz, dis_loss_xx, dis_loss_zz, global_step], feed_dict=feed_dict) train_loss_dis += ld train_loss_dis_xz += ldxz train_loss_dis_xx += ldxx train_loss_dis_zz += ldzz # train generator and encoder feed_dict = {x_pl: trainx_copy[ran_from:ran_to], z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl: True, learning_rate:lr} _,_, le, lg = sess.run([train_gen_op, train_enc_op, loss_encoder, loss_generator], feed_dict=feed_dict) train_loss_gen += lg train_loss_enc += le if enable_sm: sm = sess.run(sum_op, feed_dict=feed_dict) writer.add_summary(sm, step) if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS: # inspect reconstruction t = np.random.randint(0, trainx.shape[0]-batch_size) ran_from = t ran_to = t + batch_size feed_dict = {x_pl: trainx[ran_from:ran_to], z_pl: np.random.normal( size=[batch_size, latent_dim]), is_training_pl: False} sm = sess.run(sum_op_im, feed_dict=feed_dict) writer.add_summary(sm, step)#train_batch) train_batch += 1 train_loss_gen /= nr_batches_train train_loss_enc /= nr_batches_train train_loss_dis /= nr_batches_train train_loss_dis_xz /= nr_batches_train train_loss_dis_xx /= nr_batches_train train_loss_dis_zz /= nr_batches_train logger.info('Epoch terminated') if allow_zz: print("Epoch %d | time = %ds | loss gen = %.4f | loss enc = %.4f | " "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | " "loss dis zz = %.4f" % (epoch, time.time() - begin, train_loss_gen, train_loss_enc, train_loss_dis, train_loss_dis_xz, train_loss_dis_xx, train_loss_dis_zz)) else: print("Epoch %d | time = %ds | 
loss gen = %.4f | loss enc = %.4f | " "loss dis = %.4f | loss dis xz = %.4f | loss dis xx = %.4f | " % (epoch, time.time() - begin, train_loss_gen, train_loss_enc, train_loss_dis, train_loss_dis_xz, train_loss_dis_xx)) ##EARLY STOPPING if (epoch + 1) % FREQ_EV == 0 and enable_early_stop: valid_loss = 0 feed_dict = {x_pl: validx, z_pl: np.random.normal(size=[validx.shape[0], latent_dim]), is_training_pl: False} vl, lat = sess.run([rec_error_valid, rec_z], feed_dict=feed_dict) valid_loss += vl if enable_sm: sm = sess.run(sum_op_valid, feed_dict=feed_dict) writer.add_summary(sm, step) # train_batch) logger.info('Validation: valid loss {:.4f}'.format(valid_loss)) if (valid_loss < best_valid_loss or epoch == FREQ_EV-1): best_valid_loss = valid_loss logger.info("Best model - valid loss = {:.4f} - saving...".format(best_valid_loss)) sv.saver.save(sess, logdir+'/model.ckpt', global_step=step) nb_without_improvements = 0 else: nb_without_improvements += FREQ_EV if nb_without_improvements > PATIENCE: sv.request_stop() logger.warning( "Early stopping at epoch {} with weights from epoch {}".format( epoch, epoch - nb_without_improvements)) epoch += 1 sv.saver.save(sess, logdir+'/model.ckpt', global_step=step) logger.warn('Testing evaluation...') scores_ch = [] scores_l1 = [] scores_l2 = [] scores_fm = [] inference_time = [] # Create scores for t in range(nr_batches_test): # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size begin_test_time_batch = time.time() feed_dict = {x_pl: testx[ran_from:ran_to], z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl:False} scores_ch += sess.run(score_ch, feed_dict=feed_dict).tolist() scores_l1 += sess.run(score_l1, feed_dict=feed_dict).tolist() scores_l2 += sess.run(score_l2, feed_dict=feed_dict).tolist() scores_fm += sess.run(score_fm, feed_dict=feed_dict).tolist() inference_time.append(time.time() - begin_test_time_batch) inference_time = np.mean(inference_time) logger.info('Testing : mean inference time is %.4f' % (inference_time)) if testx.shape[0] % batch_size != 0: batch, size = batch_fill(testx, batch_size) feed_dict = {x_pl: batch, z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl: False} bscores_ch = sess.run(score_ch,feed_dict=feed_dict).tolist() bscores_l1 = sess.run(score_l1,feed_dict=feed_dict).tolist() bscores_l2 = sess.run(score_l2,feed_dict=feed_dict).tolist() bscores_fm = sess.run(score_fm,feed_dict=feed_dict).tolist() scores_ch += bscores_ch[:size] scores_l1 += bscores_l1[:size] scores_l2 += bscores_l2[:size] scores_fm += bscores_fm[:size] model = 'alad_sn{}_dzz{}'.format(do_spectral_norm, allow_zz) save_results(scores_ch, testy, model, dataset, 'ch', 'dzzenabled{}'.format(allow_zz), label, random_seed, step) save_results(scores_l1, testy, model, dataset, 'l1', 'dzzenabled{}'.format(allow_zz), label, random_seed, step) save_results(scores_l2, testy, model, dataset, 'l2', 'dzzenabled{}'.format(allow_zz), label, random_seed, step) save_results(scores_fm, testy, model, dataset, 'fm', 'dzzenabled{}'.format(allow_zz), label, random_seed, step)
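# ---------------------------------------------------------------------------
# Illustration (not part of the original code): `save_results`, defined
# elsewhere in this repository, persists the four score lists.  A quick way to
# compare the scoring methods directly is a per-method ROC AUC, assuming the
# convention used here (testy is 1 for abnormal samples, larger score means
# "more anomalous").  `report_aucs` is a hypothetical helper, a sketch only.
# ---------------------------------------------------------------------------
from sklearn.metrics import roc_auc_score


def report_aucs(testy, **score_lists):
    """Print one ROC AUC per ALAD scoring method (ch / l1 / l2 / fm)."""
    for name, scores in sorted(score_lists.items()):
        print("%s score AUC = %.4f" % (name, roc_auc_score(testy, scores)))

# Example call, from inside the managed session above:
# report_aucs(testy, ch=scores_ch, l1=scores_l1, l2=scores_l2, fm=scores_fm)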
def test(dataset, nb_epochs, degree, random_seed, label, allow_zz, enable_sm, score_method, enable_early_stop, do_spectral_norm): config = tf.ConfigProto() config.gpu_options.allow_growth = True logger = logging.getLogger("ALAD.run.{}.{}".format( dataset, label)) # Import model and data network = importlib.import_module('alad.{}_utilities'.format(dataset)) data = importlib.import_module("data.{}".format(dataset)) # Parameters starting_lr = network.learning_rate batch_size = network.batch_size latent_dim = network.latent_dim ema_decay = 0.999 global_step = tf.Variable(0, name='global_step', trainable=False) # Placeholders x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(), name="input_x") z_pl = tf.placeholder(tf.float32, shape=[None, latent_dim], name="input_z") is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl") # Data logger.info('Data loading...') #testx, testy = data.get_test(label) df = pd.DataFrame(columns = ['Flow ID', 'Src IP', 'Src Port', 'Dst IP', 'Dst Port', 'Protocol', 'Timestamp', 'Flow Duration', 'Tot Fwd Pkts', 'Tot Bwd Pkts', 'TotLen Fwd Pkts', 'TotLen Bwd Pkts', 'Fwd Pkt Len Max', 'Fwd Pkt Len Min', 'Fwd Pkt Len Mean', 'Fwd Pkt Len Std', 'Bwd Pkt Len Max', 'Bwd Pkt Len Min', 'Bwd Pkt Len Mean', 'Bwd Pkt Len Std', 'Flow Byts/s', 'Flow Pkts/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Tot', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Tot', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Bwd PSH Flags', 'Fwd URG Flags', 'Bwd URG Flags', 'Fwd Header Len', 'Bwd Header Len', 'Fwd Pkts/s', 'Bwd Pkts/s', 'Pkt Len Min', 'Pkt Len Max', 'Pkt Len Mean', 'Pkt Len Std', 'Pkt Len Var', 'FIN Flag Cnt', 'SYN Flag Cnt', 'RST Flag Cnt', 'PSH Flag Cnt', 'ACK Flag Cnt', 'URG Flag Cnt', 'CWE Flag Count', 'ECE Flag Cnt', 'Down/Up Ratio', 'Pkt Size Avg', 'Fwd Seg Size Avg', 'Bwd Seg Size Avg', 'Fwd Byts/b Avg', 'Fwd Pkts/b Avg', 'Fwd Blk Rate Avg', 'Bwd Byts/b Avg', 'Bwd Pkts/b Avg', 'Bwd Blk Rate Avg', 'Subflow Fwd Pkts', 'Subflow Fwd Byts', 'Subflow Bwd Pkts', 'Subflow Bwd Byts', 'Init Fwd Win Byts', 'Init Bwd Win Byts', 'Fwd Act Data Pkts', 'Fwd Seg Size Min', 'Active Mean', 'Active Std', 'Active Max', 'Active Min', 'Idle Mean', 'Idle Std', 'Idle Max', 'Idle Min', 'Label']) filenames = os.listdir("../package_file_temp_out/") os.chdir('../package_file_temp_out/') for x in filenames: df_temp = pd.read_csv(x) l = [df] l.append(df_temp) df = pd.concat(l, axis=0, ignore_index=True) #df.append(df_temp, ignore_index=True) os.chdir('../') result = [] for x in df.columns: if x != 'Label': result.append(x) record = df.as_matrix(result) df.drop('Flow ID', axis=1, inplace=True) df.drop('Src IP', axis=1, inplace=True) df.drop('Dst IP', axis=1, inplace=True) df.drop('Timestamp', axis=1, inplace=True) df.drop('Flow Byts/s', axis=1, inplace=True) df.drop('Flow Pkts/s', axis=1, inplace=True) element = ['Flow ID', 'Src IP', 'Dst IP', 'Timestamp', 'Flow Byts/s', 'Flow Pkts/s'] for x in element: result.remove(x) testx = df.as_matrix(result).astype(np.float32) rng = np.random.RandomState(random_seed) nr_batches_test = int(testx.shape[0] / batch_size) logger.info('Building graph...') gen = network.decoder enc = network.encoder dis_xz = network.discriminator_xz dis_xx = network.discriminator_xx dis_zz = network.discriminator_zz with tf.variable_scope('encoder_model'): z_gen = enc(x_pl, is_training=is_training_pl, 
do_spectral_norm=do_spectral_norm) with tf.variable_scope('generator_model'): x_gen = gen(z_pl, is_training=is_training_pl) rec_x = gen(z_gen, is_training=is_training_pl, reuse=True) with tf.variable_scope('encoder_model'): rec_z = enc(x_gen, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_xz'): l_encoder, inter_layer_inp_xz = dis_xz(x_pl, z_gen, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) l_generator, inter_layer_rct_xz = dis_xz(x_gen, z_pl, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_xx'): x_logit_real, inter_layer_inp_xx = dis_xx(x_pl, x_pl, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) x_logit_fake, inter_layer_rct_xx = dis_xx(x_pl, rec_x, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('discriminator_model_zz'): z_logit_real, _ = dis_zz(z_pl, z_pl, is_training=is_training_pl, do_spectral_norm=do_spectral_norm) z_logit_fake, _ = dis_zz(z_pl, rec_z, is_training=is_training_pl, reuse=True, do_spectral_norm=do_spectral_norm) with tf.name_scope('loss_functions'): # discriminator xz loss_dis_enc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_encoder),logits=l_encoder)) loss_dis_gen = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(l_generator),logits=l_generator)) dis_loss_xz = loss_dis_gen + loss_dis_enc # discriminator xx x_real_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_real, labels=tf.ones_like(x_logit_real)) x_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_fake, labels=tf.zeros_like(x_logit_fake)) dis_loss_xx = tf.reduce_mean(x_real_dis + x_fake_dis) # discriminator zz z_real_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_real, labels=tf.ones_like(z_logit_real)) z_fake_dis = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_fake, labels=tf.zeros_like(z_logit_fake)) dis_loss_zz = tf.reduce_mean(z_real_dis + z_fake_dis) loss_discriminator = dis_loss_xz + dis_loss_xx + dis_loss_zz if \ allow_zz else dis_loss_xz + dis_loss_xx # generator and encoder gen_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_generator),logits=l_generator)) enc_loss_xz = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(l_encoder), logits=l_encoder)) x_real_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_real, labels=tf.zeros_like(x_logit_real)) x_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=x_logit_fake, labels=tf.ones_like(x_logit_fake)) z_real_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_real, labels=tf.zeros_like(z_logit_real)) z_fake_gen = tf.nn.sigmoid_cross_entropy_with_logits( logits=z_logit_fake, labels=tf.ones_like(z_logit_fake)) cost_x = tf.reduce_mean(x_real_gen + x_fake_gen) cost_z = tf.reduce_mean(z_real_gen + z_fake_gen) cycle_consistency_loss = cost_x + cost_z if allow_zz else cost_x loss_generator = gen_loss_xz + cycle_consistency_loss loss_encoder = enc_loss_xz + cycle_consistency_loss with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables tvars = tf.trainable_variables() dxzvars = [var for var in tvars if 'discriminator_model_xz' in var.name] dxxvars = [var for var in tvars if 'discriminator_model_xx' in var.name] dzzvars = [var for var in tvars if 'discriminator_model_zz' in var.name] gvars = 
[var for var in tvars if 'generator_model' in var.name] evars = [var for var in tvars if 'encoder_model' in var.name] update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)] update_ops_enc = [x for x in update_ops if ('encoder_model' in x.name)] update_ops_dis_xz = [x for x in update_ops if ('discriminator_model_xz' in x.name)] update_ops_dis_xx = [x for x in update_ops if ('discriminator_model_xx' in x.name)] update_ops_dis_zz = [x for x in update_ops if ('discriminator_model_zz' in x.name)] optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5) with tf.control_dependencies(update_ops_gen): gen_op = optimizer.minimize(loss_generator, var_list=gvars, global_step=global_step) with tf.control_dependencies(update_ops_enc): enc_op = optimizer.minimize(loss_encoder, var_list=evars) with tf.control_dependencies(update_ops_dis_xz): dis_op_xz = optimizer.minimize(dis_loss_xz, var_list=dxzvars) with tf.control_dependencies(update_ops_dis_xx): dis_op_xx = optimizer.minimize(dis_loss_xx, var_list=dxxvars) with tf.control_dependencies(update_ops_dis_zz): dis_op_zz = optimizer.minimize(dis_loss_zz, var_list=dzzvars) # Exponential Moving Average for inference def train_op_with_ema_dependency(vars, op): ema = tf.train.ExponentialMovingAverage(decay=ema_decay) maintain_averages_op = ema.apply(vars) with tf.control_dependencies([op]): train_op = tf.group(maintain_averages_op) return train_op, ema train_gen_op, gen_ema = train_op_with_ema_dependency(gvars, gen_op) train_enc_op, enc_ema = train_op_with_ema_dependency(evars, enc_op) train_dis_op_xz, xz_ema = train_op_with_ema_dependency(dxzvars, dis_op_xz) train_dis_op_xx, xx_ema = train_op_with_ema_dependency(dxxvars, dis_op_xx) train_dis_op_zz, zz_ema = train_op_with_ema_dependency(dzzvars, dis_op_zz) with tf.variable_scope('encoder_model'): z_gen_ema = enc(x_pl, is_training=is_training_pl, getter=get_getter(enc_ema), reuse=True, do_spectral_norm=do_spectral_norm) with tf.variable_scope('generator_model'): rec_x_ema = gen(z_gen_ema, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) x_gen_ema = gen(z_pl, is_training=is_training_pl, getter=get_getter(gen_ema), reuse=True) with tf.variable_scope('discriminator_model_xx'): l_encoder_emaxx, inter_layer_inp_emaxx = dis_xx(x_pl, x_pl, is_training=is_training_pl, getter=get_getter(xx_ema), reuse=True, do_spectral_norm=do_spectral_norm) l_generator_emaxx, inter_layer_rct_emaxx = dis_xx(x_pl, rec_x_ema, is_training=is_training_pl, getter=get_getter( xx_ema), reuse=True, do_spectral_norm=do_spectral_norm) with tf.name_scope('Testing'): with tf.variable_scope('Scores'): score_ch = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(l_generator_emaxx), logits=l_generator_emaxx) score_ch = tf.squeeze(score_ch) rec = x_pl - rec_x_ema rec = tf.contrib.layers.flatten(rec) score_l1 = tf.norm(rec, ord=1, axis=1, keep_dims=False, name='d_loss') score_l1 = tf.squeeze(score_l1) rec = x_pl - rec_x_ema rec = tf.contrib.layers.flatten(rec) score_l2 = tf.norm(rec, ord=2, axis=1, keep_dims=False, name='d_loss') score_l2 = tf.squeeze(score_l2) inter_layer_inp, inter_layer_rct = inter_layer_inp_emaxx, \ inter_layer_rct_emaxx fm = inter_layer_inp - inter_layer_rct fm = tf.contrib.layers.flatten(fm) score_fm = tf.norm(fm, ord=degree, axis=1, keep_dims=False, name='d_loss') score_fm = tf.squeeze(score_fm) if enable_early_stop: rec_error_valid = tf.reduce_mean(score_fm) if enable_sm: with tf.name_scope('summary'): with 
tf.name_scope('dis_summary'): tf.summary.scalar('loss_discriminator', loss_discriminator, ['dis']) tf.summary.scalar('loss_dis_encoder', loss_dis_enc, ['dis']) tf.summary.scalar('loss_dis_gen', loss_dis_gen, ['dis']) tf.summary.scalar('loss_dis_xz', dis_loss_xz, ['dis']) tf.summary.scalar('loss_dis_xx', dis_loss_xx, ['dis']) if allow_zz: tf.summary.scalar('loss_dis_zz', dis_loss_zz, ['dis']) with tf.name_scope('gen_summary'): tf.summary.scalar('loss_generator', loss_generator, ['gen']) tf.summary.scalar('loss_encoder', loss_encoder, ['gen']) tf.summary.scalar('loss_encgen_dxx', cost_x, ['gen']) if allow_zz: tf.summary.scalar('loss_encgen_dzz', cost_z, ['gen']) if enable_early_stop: with tf.name_scope('validation_summary'): tf.summary.scalar('valid', rec_error_valid, ['v']) with tf.name_scope('img_summary'): heatmap_pl_latent = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_latent") sum_op_latent = tf.summary.image('heatmap_latent', heatmap_pl_latent) if dataset in IMAGES_DATASETS: with tf.name_scope('image_summary'): tf.summary.image('reconstruct', rec_x, 8, ['image']) tf.summary.image('input_images', x_pl, 8, ['image']) else: heatmap_pl_rec = tf.placeholder(tf.float32, shape=(1, 480, 640, 3), name="heatmap_pl_rec") with tf.name_scope('image_summary'): tf.summary.image('heatmap_rec', heatmap_pl_rec, 1, ['image']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op = tf.summary.merge([sum_op_dis, sum_op_gen]) sum_op_im = tf.summary.merge_all('image') sum_op_valid = tf.summary.merge_all('v') logdir = "ALAD/train_logs/cicids2017/alad_snTrue_dzzTrue/dzzenabledTrue/fm/label0/rd43" saver = tf.train.Saver(max_to_keep=1) save_model_secs = None if enable_early_stop else 20 sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=None, saver=saver, save_model_secs=save_model_secs) with sv.managed_session(config=config) as sess: ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] logger.info(ckpt.model_checkpoint_path) file = os.path.basename(ckpt.model_checkpoint_path) checkpoint_path = os.path.join(logdir, file) saver.restore(sess, checkpoint_path) scores_ch = [] scores_l1 = [] scores_l2 = [] scores_fm = [] inference_time = [] # Create scores for t in range(nr_batches_test): # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size begin_test_time_batch = time.time() feed_dict = {x_pl: testx[ran_from:ran_to], z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl:False} scores_ch += sess.run(score_ch, feed_dict=feed_dict).tolist() scores_l1 += sess.run(score_l1, feed_dict=feed_dict).tolist() scores_l2 += sess.run(score_l2, feed_dict=feed_dict).tolist() scores_fm += sess.run(score_fm, feed_dict=feed_dict).tolist() inference_time.append(time.time() - begin_test_time_batch) inference_time = np.mean(inference_time) logger.info('Testing : mean inference time is %.4f' % (inference_time)) if testx.shape[0] % batch_size != 0: batch, size = batch_fill(testx, batch_size) feed_dict = {x_pl: batch, z_pl: np.random.normal(size=[batch_size, latent_dim]), is_training_pl: False} bscores_ch = sess.run(score_ch,feed_dict=feed_dict).tolist() bscores_l1 = sess.run(score_l1,feed_dict=feed_dict).tolist() bscores_l2 = sess.run(score_l2,feed_dict=feed_dict).tolist() bscores_fm = sess.run(score_fm,feed_dict=feed_dict).tolist() scores_ch += bscores_ch[:size] scores_l1 += bscores_l1[:size] scores_l2 += 
bscores_l2[:size] scores_fm += bscores_fm[:size] model = 'alad_sn{}_dzz{}'.format(do_spectral_norm, allow_zz) save_results(scores_ch, record, model, dataset, 'ch', 'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False) save_results(scores_l1, record, model, dataset, 'l1', 'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False) save_results(scores_l2, record, model, dataset, 'l2', 'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False) save_results(scores_fm, record, model, dataset, 'fm', 'dzzenabled{}'.format(allow_zz), label, random_seed, int(step), False)
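# ---------------------------------------------------------------------------
# Note (not part of the original code): `DataFrame.as_matrix`, used in test()
# above, was deprecated in favour of `.values` / `.to_numpy()` and removed in
# pandas 1.0.  If this script is run against a recent pandas, an equivalent
# column selection can be written as below (shown as a commented sketch only;
# behaviour is otherwise unchanged).
# ---------------------------------------------------------------------------
# record = df[result].to_numpy()
# testx = df[result].to_numpy().astype(np.float32)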
def train_and_test(dataset, nb_epochs, random_seed, label): """ Runs DSEBM on available datasets Note: Saves summaries on tensorboard. To display them, please use cmd line tensorboard --logdir=model.training_logdir() --port=number Args: dataset (str): dataset to run the model on nb_epochs (int): number of epochs random_seed (int): random seed, used to average results over several runs label (int): label which is normal for image experiments """ logger = logging.getLogger("DSEBM.run.{}.{}".format(dataset, label)) # Import model and data network = importlib.import_module('dsebm.{}_utilities'.format(dataset)) data = importlib.import_module("data.{}".format(dataset)) # Parameters starting_lr = network.learning_rate batch_size = network.batch_size # Placeholders x_pl = tf.placeholder(tf.float32, shape=data.get_shape_input(), name="input") is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate = tf.placeholder(tf.float32, shape=(), name="lr_pl") #test y_true = tf.placeholder(tf.int32, shape=[None], name="y_true") logger.info('Building training graph...') logger.warn("The DSEBM is training with the following parameters:") display_parameters(batch_size, starting_lr, label) net = network.network global_step = tf.train.get_or_create_global_step() noise = tf.random_normal(shape=tf.shape(x_pl), mean=0.0, stddev=1., dtype=tf.float32) x_noise = x_pl + noise with tf.variable_scope('network'): b_prime_shape = list(data.get_shape_input()) b_prime_shape[0] = batch_size b_prime = tf.get_variable(name='b_prime', shape=b_prime_shape) #tf.shape(x_pl)) net_out = net(x_pl, is_training=is_training_pl) net_out_noise = net(x_noise, is_training=is_training_pl, reuse=True) with tf.name_scope('energies'): energy = 0.5 * tf.reduce_sum(tf.square(x_pl - b_prime)) \ - tf.reduce_sum(net_out) energy_noise = 0.5 * tf.reduce_sum(tf.square(x_noise - b_prime)) \ - tf.reduce_sum(net_out_noise) with tf.name_scope('reconstructions'): # reconstruction grad = tf.gradients(energy, x_pl) fx = x_pl - tf.gradients(energy, x_pl) fx = tf.squeeze(fx, axis=0) fx_noise = x_noise - tf.gradients(energy_noise, x_noise) with tf.name_scope("loss_function"): # DSEBM for images if len(data.get_shape_input()) == 4: loss = tf.reduce_mean( tf.reduce_sum(tf.square(x_pl - fx_noise), axis=[1, 2, 3])) # DSEBM for tabular data else: loss = tf.reduce_mean(tf.square(x_pl - fx_noise)) with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables tvars = tf.trainable_variables() netvars = [var for var in tvars if 'network' in var.name] update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops_net = [x for x in update_ops if ('network' in x.name)] optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='optimizer') with tf.control_dependencies(update_ops_net): train_op = optimizer.minimize(loss, var_list=netvars, global_step=global_step) with tf.variable_scope('Scores'): with tf.name_scope('Energy_score'): flat = tf.layers.flatten(x_pl - b_prime) if len(data.get_shape_input()) == 4: list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \ - tf.reduce_sum(net_out, axis=[1, 2, 3]) else: list_scores_energy = 0.5 * tf.reduce_sum(tf.square(flat), axis=1) \ - tf.reduce_sum(net_out, axis=1) with tf.name_scope('Reconstruction_score'): delta = x_pl - fx delta_flat =
tf.layers.flatten(delta) list_scores_reconstruction = tf.norm(delta_flat, ord=2, axis=1, keepdims=False, name='reconstruction') # with tf.name_scope('predictions'): # # Highest 20% are anomalous # if dataset=="kdd": # per = tf.contrib.distributions.percentile(list_scores_energy, 80) # else: # per = tf.contrib.distributions.percentile(list_scores_energy, 95) # y_pred = tf.greater_equal(list_scores_energy, per) # # #y_test_true = tf.cast(y_test_true, tf.float32) # cm = tf.confusion_matrix(y_true, y_pred, num_classes=2) # recall = cm[1,1]/(cm[1,0]+cm[1,1]) # precision = cm[1,1]/(cm[0,1]+cm[1,1]) # f1 = 2*precision*recall/(precision + recall) with tf.name_scope('training_summary'): tf.summary.scalar('score_matching_loss', loss, ['net']) tf.summary.scalar('energy', energy, ['net']) if dataset in IMAGES_DATASETS: with tf.name_scope('image_summary'): tf.summary.image('reconstruct', fx, 6, ['image']) tf.summary.image('input_images', x_pl, 6, ['image']) sum_op_im = tf.summary.merge_all('image') sum_op_net = tf.summary.merge_all('net') logdir = create_logdir(dataset, label, random_seed) sv = tf.train.Supervisor(logdir=logdir + "/train", save_summaries_secs=None, save_model_secs=None) # Data logger.info('Data loading...') trainx, trainy = data.get_train(label) trainx_copy = trainx.copy() if dataset in IMAGES_DATASETS: validx, validy = data.get_valid(label) testx, testy = data.get_test(label) rng = np.random.RandomState(RANDOM_SEED) nr_batches_train = int(trainx.shape[0] / batch_size) if dataset in IMAGES_DATASETS: nr_batches_valid = int(validx.shape[0] / batch_size) nr_batches_test = int(testx.shape[0] / batch_size) logger.info("Train: {} samples in {} batches".format( trainx.shape[0], nr_batches_train)) if dataset in IMAGES_DATASETS: logger.info("Valid: {} samples in {} batches".format( validx.shape[0], nr_batches_valid)) logger.info("Test: {} samples in {} batches".format( testx.shape[0], nr_batches_test)) logger.info('Start training...') with sv.managed_session() as sess: logger.info('Initialization done') train_writer = tf.summary.FileWriter(logdir + "/train", sess.graph) valid_writer = tf.summary.FileWriter(logdir + "/valid", sess.graph) train_batch = 0 epoch = 0 best_valid_loss = 0 train_losses = [0] * STRIP_EV while not sv.should_stop() and epoch < nb_epochs: lr = starting_lr begin = time.time() trainx = trainx[rng.permutation( trainx.shape[0])] # shuffling unl dataset trainx_copy = trainx_copy[rng.permutation(trainx.shape[0])] losses, energies = [0, 0] # training for t in range(nr_batches_train): display_progression_epoch(t, nr_batches_train) # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size # train the net feed_dict = { x_pl: trainx[ran_from:ran_to], is_training_pl: True, learning_rate: lr } _, ld, en, sm, step = sess.run( [train_op, loss, energy, sum_op_net, global_step], feed_dict=feed_dict) losses += ld energies += en train_writer.add_summary(sm, step) if t % FREQ_PRINT == 0 and dataset in IMAGES_DATASETS: # inspect reconstruction t = np.random.randint(0, 40) ran_from = t ran_to = t + batch_size sm = sess.run(sum_op_im, feed_dict={ x_pl: trainx[ran_from:ran_to], is_training_pl: False }) train_writer.add_summary(sm, step) train_batch += 1 losses /= nr_batches_train energies /= nr_batches_train # Remembering loss for early stopping train_losses[epoch % STRIP_EV] = losses logger.info('Epoch terminated') print("Epoch %d | time = %ds | loss = %.4f | energy = %.4f " % (epoch, time.time() - begin, losses, energies)) if (epoch + 1) % FREQ_SNAP == 0 
and dataset in IMAGES_DATASETS: print("Take a snap of the reconstructions...") x = trainx[:batch_size] feed_dict = {x_pl: x, is_training_pl: False} rect_x = sess.run(fx, feed_dict=feed_dict) nama_e_wa = "dsebm/reconstructions/{}/{}/" \ "{}_epoch{}".format(dataset, label, random_seed, epoch) nb_imgs = 50 save_grid_plot(x[:nb_imgs], rect_x[:nb_imgs], nama_e_wa, nb_imgs) if (epoch + 1) % FREQ_EV == 0 and dataset in IMAGES_DATASETS: logger.info("Validation") inds = rng.permutation(validx.shape[0]) validx = validx[inds] # shuffling dataset validy = validy[inds] # shuffling dataset valid_loss = 0 for t in range(nr_batches_valid): display_progression_epoch(t, nr_batches_valid) # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size # train the net feed_dict = { x_pl: validx[ran_from:ran_to], y_true: validy[ran_from:ran_to], is_training_pl: False } vl, sm, step = sess.run([loss, sum_op_net, global_step], feed_dict=feed_dict) valid_writer.add_summary(sm, step + t) #train_batch) valid_loss += vl valid_loss /= nr_batches_valid # train the net logger.info("Validation loss at step " + str(step) + ":" + str(valid_loss)) ##EARLY STOPPING #UPDATE WEIGHTS if valid_loss < best_valid_loss or epoch == FREQ_EV - 1: best_valid_loss = valid_loss logger.info("Best model - loss={} - saving...".format( best_valid_loss)) sv.saver.save(sess, logdir + '/train/model.ckpt', global_step=step) nb_without_improvements = 0 else: nb_without_improvements += FREQ_EV if nb_without_improvements > PATIENCE: sv.request_stop() logger.warning( "Early stopping at epoch {} with weights from epoch {}" .format(epoch, epoch - nb_without_improvements)) epoch += 1 logger.warn('Testing evaluation...') step = sess.run(global_step) scores_e = [] scores_r = [] inference_time = [] # Create scores for t in range(nr_batches_test): # construct randomly permuted minibatches ran_from = t * batch_size ran_to = (t + 1) * batch_size begin_val_batch = time.time() feed_dict = {x_pl: testx[ran_from:ran_to], is_training_pl: False} scores_e += sess.run(list_scores_energy, feed_dict=feed_dict).tolist() scores_r += sess.run(list_scores_reconstruction, feed_dict=feed_dict).tolist() inference_time.append(time.time() - begin_val_batch) logger.info('Testing : mean inference time is %.4f' % (np.mean(inference_time))) if testx.shape[0] % batch_size != 0: batch, size = batch_fill(testx, batch_size) feed_dict = {x_pl: batch, is_training_pl: False} batch_score_e = sess.run(list_scores_energy, feed_dict=feed_dict).tolist() batch_score_r = sess.run(list_scores_reconstruction, feed_dict=feed_dict).tolist() scores_e += batch_score_e[:size] scores_r += batch_score_r[:size] save_results(scores_e, testy, 'dsebm', dataset, 'energy', "test", label, random_seed, step) save_results(scores_r, testy, 'dsebm', dataset, 'reconstruction', "test", label, random_seed, step)
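# ---------------------------------------------------------------------------
# Illustration (not part of the original code): the DSEBM energy used above is
#     E(x) = 0.5 * ||x - b'||^2 - sum_j f(x)_j
# and the denoised reconstruction is x - dE/dx.  The NumPy sketch below
# evaluates the per-sample energy score for tabular data; `net_out` stands for
# the network output f(x) and `b_prime` for the learned bias, both assumed to
# be plain arrays here (they are tensors in the graph above).
# ---------------------------------------------------------------------------
import numpy as np  # repeated here so the sketch is self-contained


def dsebm_energy_score_np(x, net_out, b_prime):
    """Per-sample energy score: 0.5 * ||x - b'||^2 - sum(f(x))  (tabular case)."""
    n = x.shape[0]
    flat = (x - b_prime).reshape(n, -1)
    return 0.5 * np.sum(flat ** 2, axis=1) - np.sum(net_out.reshape(n, -1), axis=1)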