Example #1
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # GPU memory settings: allocate GPU memory on demand instead of
    # reserving it all up front.
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    #run_config.gpu_options.per_process_gpu_memory_fraction = 0.4

    # tf.Session runs TensorFlow operations; config= applies the settings above.
    with tf.Session(config=run_config) as sess:
        # Build the DCGAN model; MNIST is conditioned on labels (y_dim=10), other datasets are not.
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          test_batch_size=FLAGS.test_batch_size,
                          sample_num=FLAGS.batch_size,
                          y_dim=10,
                          z_dim=FLAGS.generate_test_images,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          test_dir=FLAGS.test_dir)
        else:
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          test_batch_size=FLAGS.test_batch_size,
                          sample_num=FLAGS.batch_size,
                          z_dim=FLAGS.generate_test_images,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          test_dir=FLAGS.test_dir)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

        if FLAGS.anomaly_test:
            dcgan.anomaly_detector()
            # assert: raises AssertionError if no test images were loaded
            assert len(dcgan.test_data_names) > 0
            fp = open('anomaly_score_record.txt', 'a+')
            for idx in range(len(dcgan.test_data_names)):
                # np.expand_dims adds a leading batch dimension: (H, W, C) -> (1, H, W, C)
                test_input = np.expand_dims(dcgan.test_data[idx], axis=0)
                test_name = dcgan.test_data_names[idx]
                dcgan.train_anomaly_detector(FLAGS, test_input, test_name)
        # Visualize generated samples
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION)
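For context, this example assumes command-line flags and the standard TF 1.x entry point defined elsewhere in the script. A minimal sketch of what those definitions could look like, inferred from the flag names used above; the default values and help strings are illustrative assumptions, not the originals:

import pprint
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("input_height", 108, "Height of the input images")
flags.DEFINE_integer("input_width", None, "Width of the input images (defaults to input_height)")
flags.DEFINE_integer("output_height", 64, "Height of the generated images")
flags.DEFINE_integer("output_width", None, "Width of the generated images (defaults to output_height)")
flags.DEFINE_integer("batch_size", 64, "Training batch size")
flags.DEFINE_integer("test_batch_size", 64, "Batch size used at test time")
flags.DEFINE_integer("generate_test_images", 100, "Dimension of the latent vector z")
flags.DEFINE_string("dataset", "mnist", "Name of the dataset")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern for input images")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory for checkpoints")
flags.DEFINE_string("sample_dir", "samples", "Directory for generated samples")
flags.DEFINE_string("test_dir", "test", "Directory for test outputs")
flags.DEFINE_boolean("crop", False, "Center-crop the input images")
flags.DEFINE_boolean("train", False, "Train the model")
flags.DEFINE_boolean("anomaly_test", False, "Run the anomaly detector on the test data")
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()

if __name__ == '__main__':
    tf.app.run()  # parses the flags and calls main(_)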
Example #2
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.output_height is None:
        FLAGS.output_height = FLAGS.input_height
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # tf.ConfigProto expects a tf.GPUOptions message here, not a plain dict.
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.9)

    run_config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True,
                                gpu_options=gpu_options)

    with tf.Session(config=run_config) as sess:
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          test_batch_size=FLAGS.test_batch_size,
                          sample_num=FLAGS.batch_size,
                          y_dim=10,
                          data_dir=FLAGS.data_dir,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          max_to_keep=FLAGS.max_to_keep,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          test_dir=FLAGS.test_dir)
        else:
            dcgan = DCGAN(sess,
                          input_width=FLAGS.input_width,
                          input_height=FLAGS.input_height,
                          output_width=FLAGS.output_width,
                          output_height=FLAGS.output_height,
                          batch_size=FLAGS.batch_size,
                          test_batch_size=FLAGS.test_batch_size,
                          sample_num=FLAGS.batch_size,
                          data_dir=FLAGS.data_dir,
                          dataset_name=FLAGS.dataset,
                          input_fname_pattern=FLAGS.input_fname_pattern,
                          crop=FLAGS.crop,
                          max_to_keep=FLAGS.max_to_keep,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          test_dir=FLAGS.test_dir)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

        if FLAGS.anomaly_test:
            dcgan.anomaly_detector()
            assert len(dcgan.test_data_names) > 0
            for idx in range(len(dcgan.test_data_names)):
                test_input = np.expand_dims(dcgan.test_data[idx], axis=0)
                test_name = dcgan.test_data_names[idx]
                dcgan.train_anomaly_detector(FLAGS, test_input, test_name)
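Examples #1 and #2 configure GPU memory in two equivalent ways; a minimal side-by-side sketch, assuming TF 1.x:

import tensorflow as tf

# Example #1 style: mutate the options on an existing ConfigProto.
config_a = tf.ConfigProto()
config_a.gpu_options.allow_growth = True

# Example #2 style: pass a tf.GPUOptions message to the constructor.
config_b = tf.ConfigProto(
    allow_soft_placement=True,   # fall back to CPU if an op has no GPU kernel
    log_device_placement=True,   # log the device chosen for each op
    gpu_options=tf.GPUOptions(allow_growth=True,
                              per_process_gpu_memory_fraction=0.9))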
Example #3
img_path = "temp"
if not os.path.isdir("temp"):
    os.mkdir("temp")

with tf.Session(config=sess_config) as sess:
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)
    #summary_op = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    sess.run(init)

    model_checkpoint_name = config.PATH_CHECKPOINT + "/model.ckpt"
    saver.restore(sess, model_checkpoint_name)
    print('Loaded latest model checkpoint')

    model.anomaly_detector()

    #lr = tf.train.exponential_decay(0.01, model.global_step, decay_steps=8000, decay_rate=0.9)
    
    #with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='anomaly_detector')):
        #train_Z = tf.train.AdamOptimizer(learning_rate=lr).minimize(model.anomaly_score, global_step=model.global_step, var_list=model.z_vars)
    # Optimize only the latent variables z against the anomaly score;
    # the generator/discriminator weights restored above stay fixed.
    train_Z = tf.train.AdamOptimizer(0.0002, beta1=0.5).minimize(
        model.anomaly_score, var_list=model.z_vars)

    # Initialize only the variables that were not restored from the checkpoint
    # (e.g. the Adam slot variables just created by the train_Z op).
    global_vars = tf.global_variables()
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for v, f in zip(global_vars, is_initialized) if not f]
    print([str(v.name) for v in not_initialized_vars])
    sess.run(tf.variables_initializer(not_initialized_vars))

    for epoch in range(config.EPOCH):
        for idx in range(num_iters):
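The tf.is_variable_initialized probe above is a common TF 1.x idiom: after restoring a checkpoint, initialize only the variables the checkpoint did not cover (such as freshly created optimizer slots). A reusable sketch of the same idiom; the helper name is my own, not from the original code:

def initialize_uninitialized(sess):
    """Initialize only the global variables that are not yet initialized."""
    global_vars = tf.global_variables()
    initialized = sess.run([tf.is_variable_initialized(v) for v in global_vars])
    missing = [v for v, ok in zip(global_vars, initialized) if not ok]
    if missing:
        sess.run(tf.variables_initializer(missing))
    return missing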