Example #1
def _demo():
    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder `%s'" %
                                (FLAGS.checkpoint_dir, ))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    filenames = [random.choice(prepare_dirs(delete_train_dir=False))]

    # Setup async input queues
    features, labels = srez_input.setup_inputs(sess, filenames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, features, labels)

    # Restore variables from checkpoint
    #     saver = tf.train.Saver()
    #     filename = 'checkpoint_new.txt'
    #     filename = os.path.join(FLAGS.checkpoint_dir, filename)
    #     saver.restore(sess, filename)

    filename = 'checkpoint_new.txt.meta'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    saver = tf.train.import_meta_graph(filename)
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Execute demo
    train_data = TrainData(locals())
    srez_demo.demo1(train_data)
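Most of these examples pack the enclosing local namespace into a TrainData object via TrainData(locals()). In the srez sources this is essentially a thin attribute container; a minimal sketch under that assumption:

class TrainData(object):
    # Thin container: promotes a dict (typically locals()) to attributes,
    # so downstream code can refer to td.sess, td.gene_loss, and so on.
    def __init__(self, dictionary):
        self.__dict__.update(dictionary)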
Example #2
    def _demo(self):
        # Load checkpoint
        if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
            raise FileNotFoundError("Could not find folder `%s'" %
                                    (FLAGS.checkpoint_dir, ))

        # Setup global tensorflow state
        sess, summary_writer = setup_tensorflow()

        # Prepare directories
        filenames = prepare_dirs(delete_train_dir=False)

        # Setup async input queues
        features, labels = srez_input.setup_inputs(sess, filenames)

        # Create and initialize model
        [gene_minput, gene_moutput,
         gene_output, gene_var_list,
         disc_real_output, disc_fake_output, disc_var_list] = \
                srez_model.create_model(sess, features, labels)

        # Restore variables from checkpoint
        saver = tf.train.Saver()
        filename = 'checkpoint_new.txt'
        filename = os.path.join(FLAGS.checkpoint_dir, filename)
        saver.restore(sess, filename)

        # Execute demo
        srez_demo.demo1(sess)
Example #3
def _demo():
    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.checkpoint_dir,))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    filenames = prepare_dirs(delete_train_dir=False)

    # Setup async input queues
    features, labels = srez_input.setup_inputs(sess, filenames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, features, labels)

    # Restore variables from checkpoint
    saver = tf.train.Saver()
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    saver.restore(sess, filename)

    # Execute demo
    srez_demo.demo1(sess)
Example #4
def srez_output(input_fn, output_fn, checkpoint_file):
  input_image, sz, hackfn = get_input_feature(input_fn)
  # dummy files to satisfy input pipeline
  # TODO: remove input pipeline
  filenames = ['/input/data/sample.jpg']

  sess, summary_writer = setup_tensorflow()
  try:
    features, labels = srez_input.setup_inputs(sess, filenames, sz*4)
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, features, labels)

    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_file)

    feed_dict = {gene_minput: input_image}
    gene_output = sess.run(gene_moutput, feed_dict=feed_dict)

    gene_output = gene_output.reshape(list(gene_output.shape)[1:])
    misc.toimage(gene_output, cmin=0., cmax=1.).save(output_fn)
  except tf.errors.CancelledError:
    pass
  finally:
    sess.close()
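Note that misc.toimage, used above to save the generator output, was deprecated in SciPy 1.0 and removed in 1.2. A rough stand-in for newer installs, assuming the output holds float values in [0, 1], is a small Pillow helper:

import numpy as np
from PIL import Image

def save_unit_image(array, path):
    # Clip to [0, 1], scale to 8-bit and save; approximates
    # misc.toimage(array, cmin=0., cmax=1.).save(path)
    array = np.clip(array, 0.0, 1.0)
    Image.fromarray((array * 255.0).round().astype(np.uint8)).save(path)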
Example #5
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    # all_filenames = 
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    # if not specified, use the same as input
    if FLAGS.dataset_output == '':
        FLAGS.dataset_output = FLAGS.dataset_input
    filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # Separate training and test sets
    # train_filenames = all_filenames[:-FLAGS.test_vectors]
    # test_filenames  = all_filenames[-FLAGS.test_vectors:]
    train_filenames_input = filenames_input[:-FLAGS.test_vectors]
    test_filenames_input  = filenames_input[-FLAGS.test_vectors:]
    train_filenames_output = filenames_output[:-FLAGS.test_vectors]
    test_filenames_output  = filenames_output[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs_one_sources(sess, train_filenames_input, train_filenames_output, 
                                                                        image_size=FLAGS.sample_size, axis_undersample=FLAGS.axis_undersample)
    test_features,  test_labels  = srez_input.setup_inputs_one_sources(sess, test_filenames_input, test_filenames_output,
                                                                        image_size=FLAGS.sample_size, axis_undersample=FLAGS.axis_undersample)
    
    # sample size
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     gene_layers, gene_mlayers] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features, train_labels)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    
    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data, num_sample_train, num_sample_test)
Example #6
def retro_infere(sess):
    """Generate image based on model"""

    # Setup async input queues
    filenames = tf.gfile.ListDirectory(FLAGS.source_fake_dir)
    filenames = sorted(filenames)
    filenames = [
        os.path.join(FLAGS.source_dir, f) for f in filenames
        if f[-4:] == '.jpg'
    ]

    labelnames = tf.gfile.ListDirectory(FLAGS.source_real_dir)
    labelnames = sorted(labelnames)
    labelnames = [
        os.path.join(FLAGS.source_real_dir, f) for f in labelnames
        if f[-4:] == '.jpg'
    ]
    features, labels = srez_input.setup_inputs(sess, filenames, labelnames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, features, labels)

    # Restore variables from checkpoint
    saver = tf.train.Saver()
    chk_filename = 'checkpoint_new.txt'
    chk_filename = os.path.join(FLAGS.checkpoint_dir, chk_filename)
    saver.restore(sess, chk_filename)

    # Run inference on an image
    feature, label = sess.run([features, labels])
    feed_dict = {gene_minput: feature}
    gene_output = sess.run(gene_moutput, feed_dict=feed_dict)

    size = [label.shape[1], label.shape[2]]

    nearest = tf.image.resize_nearest_neighbor(feature, size)
    nearest = tf.maximum(tf.minimum(nearest, 1.0), 0.0)

    bicubic = tf.image.resize_bicubic(feature, size)
    bicubic = tf.maximum(tf.minimum(bicubic, 1.0), 0.0)

    clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

    image = tf.concat(2, [nearest, bicubic, clipped, label])
    #image   = tf.concat(0, [image[0,:,:,:]])
    print("There are ", len(filenames), "files")
    image = tf.concat(0, [image[i, :, :, :] for i in range(len(filenames))])

    image = sess.run(image)

    res_filename = 'check_results.png'
    res_filename = os.path.join(FLAGS.result_dir, res_filename)
    scipy.misc.toimage(image, cmin=0., cmax=1.).save(res_filename)
    print("    Saved result")
Example #7
def _train():
    # Setup global tensorflow state
    sess = setup_tensorflow()

    # Prepare directories
    # all_filenames = prepare_dirs(delete_train_dir=True)
    all_filenames = prepare_dirs(delete_train_dir=False)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     dropout] = \
        srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
        srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
        srez_model.create_optimizers(gene_loss, gene_var_list,
                                     disc_loss, disc_var_list)

    # Restore variables from checkpoint if EXISTS
    # if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
    #     filename = 'checkpoint_new.txt'
    #     filename = os.path.join(FLAGS.checkpoint_dir, filename)
    #     saver = tf.train.Saver()
    #     if tf.gfile.Exists(filename):
    #         saver.restore(tf.Session(), filename)
    #         print("Restored previous checkpoint. "
    #               "Warning, Batch number restarted.")

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example #8
def _test16(onefilename=False):
    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder `%s'" %
                                (FLAGS.checkpoint_dir, ))

    # Load test set
    if not tf.gfile.IsDirectory(FLAGS.test_dir):
        raise FileNotFoundError("Could not find folder `%s'" %
                                (FLAGS.test_dir, ))

    # Setup global tensorflow state
    sess = setup_tensorflow()

    # Prepare directories
    if os.path.isfile(onefilename):
        filenames = [onefilename]
    elif os.path.isdir(onefilename):
        filenames = [
            os.path.join(onefilename, f) for f in os.listdir(onefilename)
            if os.path.isfile(os.path.join(onefilename, f))
        ]
    else:
        # Fail early: otherwise `filenames` stays unbound below.
        raise FileNotFoundError("Could not find file or folder `%s'" %
                                (onefilename, ))

    # Was to check the size input.
    # im = Image.open(onefilename)
    # size = im.size

    # Setup async input queues
    test_features, test_labels = srez_input.test_inputs(sess, filenames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     dropout] = \
        srez_model.create_model(sess, test_features, test_labels)

    # Restore variables from checkpoint
    saver = tf.train.Saver()
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)

    saver.restore(sess, filename)

    # Loop through and run/predict each picture.
    for file in filenames:
        test_features, test_labels = srez_input.test_inputs(sess, [file])

        test_feature, test_label = sess.run([test_features, test_labels])
        feed_dict = {gene_minput: test_label, dropout: 1.0}
        gene_output = sess.run(gene_moutput, feed_dict=feed_dict)

        srez_test.predict_one(sess, test_feature, test_label, gene_output,
                              file)
Example #9
def _evaluate():
    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder `%s'" %
                                (FLAGS.checkpoint_dir, ))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    Tfilenames = prepare_dirs(delete_train_dir=False)

    features, labels = srez_input.setup_inputs(sess, Tfilenames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, features, labels)

    # Restore variables from checkpoint_dir
    saver = tf.train.Saver()
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    saver.restore(sess, filename)

    print("RESTORE model")

    #region prediction test
    predict_restore = gene_moutput

    # Prepare directories
    filenames = prepare_dirs(delete_train_dir=False)

    # UPDATE: Evaluate 4 images. NEEDS to be 4.
    test_filenames = [
        'dataset/101287.jpg', 'dataset/101288.jpg', 'dataset/101289.jpg',
        'dataset/101290.jpg'
    ]

    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)
    test_img4_input, test_img4_original = sess.run(
        [test_features, test_labels])

    test_img5 = (tf.convert_to_tensor(test_img4_input)).eval(session=sess)
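    # Note: the convert_to_tensor/eval round trip above effectively just
    # copies test_img4_input back into a numpy array.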
    feed_dict = {gene_minput: test_img5}
    prob = sess.run(predict_restore, feed_dict)
    #endregion prediction test

    td = TrainData(locals())
    srez_train._summarize_progress(td, test_img4_input, test_img4_original,
                                   prob, 69, 'out')

    print("Finish EVALUATING")
Example #10
def _train():
    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    #saver = tf.train.Saver()
    #filename = 'checkpoint'
    #filename = os.path.join(FLAGS.checkpoint_dir, filename)
    #saver.restore(sess,tf.train.latest_checkpoint("./checkpoint/"))
    #print("Model restored from file: %s" % FLAGS.checkpoint_dir)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example #11
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    rn.shuffle(all_filenames)
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess,
                                                           train_filenames,
                                                           image_size=32,
                                                           crop_size=128)
    test_features, test_labels = srez_input.setup_inputs(sess,
                                                         test_filenames,
                                                         image_size=32,
                                                         crop_size=128)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example #12
def _test(onefilename=False):
    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder `%s'" %
                                (FLAGS.checkpoint_dir, ))

    # Prepare directories
    if onefilename:
        filenames = [onefilename]
    else:
        # Load test set
        if not tf.gfile.IsDirectory(FLAGS.test_dir):
            raise FileNotFoundError("Could not find folder `%s'" %
                                    (FLAGS.test_dir, ))
        filenames = prepare_test_dir()

    # Setup global tensorflow state
    sess = setup_tensorflow()

    # Setup async input queues
    test_features, test_labels = srez_input.test_inputs(sess, filenames)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     dropout] = \
        srez_model.create_model(sess, test_features, test_labels)
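    # NOTE: this re-binds `dropout` to a fresh placeholder, so the feed_dict
    # below targets this new tensor rather than the one returned by
    # create_model above.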
    dropout = tf.placeholder(tf.float32)

    # Restore variables from checkpoint
    saver = tf.train.Saver()
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)

    saver.restore(sess, filename)

    test_data = TrainData(locals())
    td = test_data
    test_feature, test_label = sess.run([test_features, test_labels])
    feed_dict = {gene_minput: test_feature, dropout: 1.0}
    gene_output = sess.run(gene_moutput, feed_dict=feed_dict)

    if onefilename:
        srez_test.predict_one(test_data, gene_output)
    else:
        srez_test.predict(test_data, test_feature, test_label, gene_output)
Example #13
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=False)

    # Separate training and test sets

    if FLAGS.specific_test:
        train_filenames = all_filenames[:]
        test_filenames = prepare_test_dirs()[:]
    else:
        train_filenames = all_filenames[:-FLAGS.test_vectors]
        test_filenames = all_filenames[-FLAGS.test_vectors:]

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss, gene_l1_loss, gene_ce_loss = srez_model.create_generator_loss(
        disc_fake_output, gene_output, train_features, train_labels)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(2 * FLAGS.disc_real_factor * disc_real_loss,
                       2 * (1 - FLAGS.disc_real_factor) * disc_fake_loss,
                       name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)
    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example #14
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames  = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess, train_filenames)
    test_features,  test_labels  = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    
    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example #15
def _get_inference_data():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Load single image to use for inference
    if FLAGS.infile is None:
        raise ValueError(
            'Must specify inference input file through `--infile <filename>` command line argument'
        )

    if not tf.gfile.Exists(FLAGS.infile) or tf.gfile.IsDirectory(FLAGS.infile):
        raise FileNotFoundError('File `%s` does not exist or is a directory' %
                                (FLAGS.infile, ))

    filenames = [FLAGS.infile]
    infer_images = srez_input.setup_inputs(sess, filenames)

    print('Loading model...')
    # Create inference model
    infer_model = srez_model.create_model(sess, infer_images)

    # Load model parameters from checkpoint
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    try:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint.model_checkpoint_path)
        del saver
        del checkpoint
    except:
        raise RuntimeError('Unable to read checkpoint from `%s`' %
                           (FLAGS.checkpoint_dir, ))
    print('Done.')

    # Pack all for convenience
    infer_data = srez_utils.Container(locals())

    return infer_data
Example #16
def _demo2(demo_func=srez_demo.demo2):
    time_start = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("START. Time is {}".format(time_start))
    mkdirp(FLAGS.train_dir)
    parameters = save_parameters(time_start=time_start)

    # Load checkpoint
    if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
        raise FileNotFoundError("Could not find folder '{}'".format(
            FLAGS.checkpoint_dir))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # data directories
    if FLAGS.preset_files:
        test_filenames_input = prepare_preset_files(FLAGS.dataset,
                                                    test_files,
                                                    shuffle=False)
        test_filenames_output = prepare_preset_files(FLAGS.dataset,
                                                     test_files,
                                                     shuffle=False)
    else:
        test_filenames_input = get_filenames(dir_file=FLAGS.dataset,
                                             shuffle_filename=False)
        test_filenames_output = get_filenames(dir_file=FLAGS.dataset,
                                              shuffle_filename=False)

    if FLAGS.subsample_test > 0:
        index_sample_test_selected = random.sample(
            range(len(test_filenames_input)), FLAGS.subsample_test)
        if not FLAGS.permutation_test:
            index_sample_test_selected = sorted(index_sample_test_selected)
        test_filenames_input = [
            test_filenames_input[x] for x in index_sample_test_selected
        ]
        test_filenames_output = [
            test_filenames_output[x] for x in index_sample_test_selected
        ]

    # image_size
    if FLAGS.sample_size_y > 0:
        image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
    else:
        image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # get undersample mask
    from scipy import io as sio
    try:
        content_mask = sio.loadmat(FLAGS.sampling_pattern)
        key_mask = [x for x in content_mask.keys() if not x.startswith('_')]
        mask = content_mask[key_mask[0]]
    except:
        mask = None

    # Setup async input queues
    features, labels, masks = srez_input.setup_inputs_one_sources(
        sess,
        test_filenames_input,
        test_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)

    # Create and initialize model
    [gene_minput, gene_moutput, gene_moutput_complex, gene_output,
     gene_output_complex, gene_var_list, gene_layers, gene_mlayers,
     disc_real_output, disc_fake_output, disc_moutput,
     disc_var_list, disc_layers, disc_mlayers] = \
            srez_model.create_model(sess, features, labels, masks)

    # Restore variables from checkpoint
    saver = tf.train.Saver()
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    saver.restore(sess, filename)

    # Execute demo
    test_data = TrainData(locals())
    if FLAGS.subsample_test > 0:
        num_sample = FLAGS.subsample_test
    elif FLAGS.sample_test > 0:
        num_sample = FLAGS.sample_test
    else:
        num_sample = FLAGS.batch_size
    demo_func(test_data, num_sample)

    time_ended = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("ENDED. Time is {}".format(time_ended))
    # Overwrite log file now that we are complete
    save_parameters(use_flags=False,
                    existing=parameters,
                    time_ended=time_ended)
Example #17
def _train():
    time_start = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("START. Time is {}".format(time_start))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # image_size
    if FLAGS.sample_size_y > 0:
        image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
    else:
        image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # Prepare train and test directories (SEPARATE FOLDER)
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    if FLAGS.cv_index >= 0:
        # Cross-validation
        filenames_input_train = []
        filenames_output_train = []
        for i in range(FLAGS.cv_groups):
            if i == FLAGS.cv_index:
                continue
            train_dir = os.path.join(FLAGS.dataset, str(i))
            filenames = get_filenames(dir_file=train_dir,
                                      shuffle_filename=True)
            filenames_input_train.extend(filenames)
            filenames_output_train.extend(filenames)
        test_dir = os.path.join(FLAGS.dataset, str(FLAGS.cv_index))
        filenames_input_test = get_filenames(dir_file=test_dir,
                                             shuffle_filename=True)
        filenames_output_test = get_filenames(dir_file=test_dir,
                                              shuffle_filename=True)
    else:
        filenames_input_train = get_filenames(dir_file=FLAGS.dataset_train,
                                              shuffle_filename=True)
        filenames_output_train = get_filenames(dir_file=FLAGS.dataset_train,
                                               shuffle_filename=True)
        filenames_input_test = get_filenames(dir_file=FLAGS.dataset_test,
                                             shuffle_filename=False)
        filenames_output_test = get_filenames(dir_file=FLAGS.dataset_test,
                                              shuffle_filename=False)

    # Record parameters
    parameters = save_parameters(time_start=time_start)

    ## Prepare directories (SAME FOLDER)
    #prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    #filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    ## if not specify use the same as input
    #if FLAGS.dataset_output == '':
    #FLAGS.dataset_output = FLAGS.dataset_input
    #filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # check input and output sample number matches (SEPARATE FOLDER)
    assert (len(filenames_input_train) == len(filenames_output_train))
    num_filename_train = len(filenames_input_train)
    assert (len(filenames_input_test) == len(filenames_output_test))
    num_filename_test = len(filenames_input_test)

    print(num_filename_train)
    print(num_filename_test)

    # check input and output sample number matches (SAME FOLDER)
    #assert(len(filenames_input)==len(filenames_output))
    #num_filename_all = len(filenames_input)

    # Permutate train and test split (SEPARATE FOLDERS)
    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_train),
                                                num_filename_train)
        filenames_input_train = [
            filenames_input_train[x] for x in index_permutation_split
        ]
        filenames_output_train = [
            filenames_output_train[x] for x in index_permutation_split
        ]
        #print(np.shape(filenames_input_train))

    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_test),
                                                num_filename_test)
        filenames_input_test = [
            filenames_input_test[x] for x in index_permutation_split
        ]
        filenames_output_test = [
            filenames_output_test[x] for x in index_permutation_split
        ]
    #print('filenames_input[:20]',filenames_input[:20])

    # Permutate test split (SAME FOLDERS)
    #if FLAGS.permutation_split:
    #index_permutation_split = random.sample(num_filename_test, num_filename_test)
    #filenames_input_test = [filenames_input_test[x] for x in index_permutation_split]
    #filenames_output_test = [filenames_output_test[x] for x in index_permutation_split]
    #print('filenames_input[:20]',filenames_input[:20])

    # Separate training and test sets (SEPARATE FOLDERS)
    sample_train = (len(filenames_input_train)
                    if FLAGS.sample_train <= 0 else FLAGS.sample_train)
    sample_test = (len(filenames_input_test)
                   if FLAGS.sample_test <= 0 else FLAGS.sample_test)

    train_filenames_input = filenames_input_train[:sample_train]
    train_filenames_output = filenames_output_train[:sample_train]

    # TODO: with separate folders the index should be `:sample_test`;
    # using `-sample_test:` here is a carry-over from the same-folder split.
    test_filenames_input = filenames_input_test[
        -sample_test:]  # filenames_input_test[:sample_test]
    test_filenames_output = filenames_output_test[
        -sample_test:]  # filenames_output_test[:sample_test]
    #print('test_filenames_input', test_filenames_input)
    #print('train_filenames_input', train_filenames_input)

    # Separate training and test sets (SAME FOLDERS)
    #train_filenames_input = filenames_input[:-FLAGS.sample_test]
    #train_filenames_output = filenames_output[:-FLAGS.sample_test]
    #test_filenames_input  = filenames_input[-FLAGS.sample_test:]
    #test_filenames_output  = filenames_output[-FLAGS.sample_test:]
    #print('test_filenames_input[:20]',test_filenames_input[:20])

    # randomly subsample for train
    if FLAGS.subsample_train > 0:
        index_sample_train_selected = random.sample(
            range(len(train_filenames_input)), FLAGS.subsample_train)
        if not FLAGS.permutation_train:
            index_sample_train_selected = sorted(index_sample_train_selected)
        train_filenames_input = [
            train_filenames_input[x] for x in index_sample_train_selected
        ]
        train_filenames_output = [
            train_filenames_output[x] for x in index_sample_train_selected
        ]
        print('randomly sampled {0} from {1} train samples'.format(
            len(train_filenames_input),
            len(train_filenames_input[:FLAGS.sample_train])))

    # randomly sub-sample for test
    if FLAGS.subsample_test > 0:
        index_sample_test_selected = random.sample(
            range(len(test_filenames_input)), FLAGS.subsample_test)
        if not FLAGS.permutation_test:
            index_sample_test_selected = sorted(index_sample_test_selected)
        test_filenames_input = [
            test_filenames_input[x] for x in index_sample_test_selected
        ]
        test_filenames_output = [
            test_filenames_output[x] for x in index_sample_test_selected
        ]
        #print('randomly sampled {0} from {1} test samples'.format(len(test_filenames_input), len(filenames_input[:-FLAGS.sample_test])))

    #print('test_filenames_input',test_filenames_input)

    # get undersample mask
    from scipy import io as sio
    try:
        content_mask = sio.loadmat(FLAGS.sampling_pattern)
        key_mask = [x for x in content_mask.keys() if not x.startswith('_')]
        mask = content_mask[key_mask[0]]
    except:
        mask = None

    # Setup async input queues
    train_features, train_labels, train_masks = srez_input.setup_inputs_one_sources(
        sess,
        train_filenames_input,
        train_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)
    test_features, test_labels, test_masks = srez_input.setup_inputs_one_sources(
        sess,
        test_filenames_input,
        test_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)

    print('train_features_queue', train_features.get_shape())
    print('train_labels_queue', train_labels.get_shape())
    print('train_masks_queue', train_masks.get_shape())

    #train_masks = tf.cast(sess.run(train_masks), tf.float32)
    #test_masks = tf.cast(sess.run(test_masks), tf.float32)

    # sample train and test
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(
        num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput, gene_moutput_complex, \
     gene_output, gene_output_complex, gene_var_list, gene_layers, gene_mlayers, \
     disc_real_output, disc_fake_output, disc_moutput, disc_var_list, disc_layers, disc_mlayers] = \
            srez_model.create_model(sess, noisy_train_features, train_labels, train_masks, architecture=FLAGS.architecture)

    gene_loss, gene_dc_loss, gene_ls_loss, list_gene_losses, gene_mse_factor = srez_model.create_generator_loss(
        disc_fake_output, gene_output, gene_output_complex, train_features,
        train_labels, train_masks)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    # add gradient on disc loss
    disc_gradients = tf.gradients(
        disc_loss, [disc_fake_output, disc_real_output, gene_output])
    print('disc loss gradients:', [x.shape for x in disc_gradients])

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data, num_sample_train, num_sample_test)

    time_ended = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("ENDED. Time is {}".format(time_ended))

    # Overwrite log file now that we are complete
    save_parameters(use_flags=False,
                    existing=parameters,
                    time_ended=time_ended)
Example #18
# Setup global tensorflow state
sess, summary_writer = srez_main.setup_tensorflow()

# Prepare directories
filenames = srez_main.prepare_test_dirs()
assert len(filenames) >= FLAGS.max_samples, "Not enough test images"

# Setup async input queues
features, labels = srez_input.setup_inputs(sess, filenames)

# Create and initialize model
[gene_minput, gene_moutput,
 gene_output, gene_var_list,
 disc_real_output, disc_fake_output, disc_var_list] = \
    srez_model.create_model(sess, features, labels)

# Restore variables from checkpoint
saver = tf.train.Saver()
ckpt_filename = 'checkpoint_new.txt'
ckpt_filepath = os.path.join(FLAGS.checkpoint_dir, ckpt_filename)
saver.restore(sess, ckpt_filepath)

# Run inference using pretrained model
feature, label = sess.run([features, labels])

feed_dict = {gene_minput: feature}
gene_output = sess.run(gene_moutput, feed_dict=feed_dict)

size = [label.shape[1], label.shape[2]]
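The snippet stops after computing size; a continuation in the spirit of Example #6 (upsample the low-resolution feature, clip everything to [0, 1], and tile it next to the label for comparison) would look roughly like this, reusing the variables defined above:

nearest = tf.maximum(tf.minimum(tf.image.resize_nearest_neighbor(feature, size), 1.0), 0.0)
bicubic = tf.maximum(tf.minimum(tf.image.resize_bicubic(feature, size), 1.0), 0.0)
clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)
# Side-by-side comparison image (TF >= 1.0 argument order for tf.concat)
image = sess.run(tf.concat([nearest, bicubic, clipped, label], axis=2))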
Example #19
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    # train_filenames = all_filenames[:-FLAGS.test_vectors]
    # test_filenames  = all_filenames[-FLAGS.test_vectors:]


    # We chose a pre-determined set of faces for the convenience of comparing results across models
    determined_test = [73883-1, 110251-1, 36510-1, 132301-1, 57264-1, 152931-1,
                       93861-1, 124938-1, 79512-1, 106152-1, 127384-1, 134028-1,
                       67874-1, 10613-1, 198694-1, 100990-1]
    all_filenames = np.array(all_filenames)
    train_filenames = list(np.delete(all_filenames, determined_test))

#     test_filenames = list(all_filenames[determined_test])

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess, train_filenames)
    
    # test_features,  test_labels  = srez_input.setup_inputs(sess, test_filenames)
    
    # Test sets are stored in 'testset_label.npy'
    test_labels = np.load('testset_label.npy')
    test_labels = tf.convert_to_tensor(test_labels, dtype = tf.float32)

    if FLAGS.input == 'scaled':
        test_features = tf.image.resize_area(test_labels, [16, 16])
    elif FLAGS.input == 'noise':
        test_features = tf.random_uniform(shape=[16, FLAGS.noise_dimension, FLAGS.noise_dimension, 3],minval= -1., maxval=1.)

    # Add some noise during training (think denoising autoencoders)
    noise_level = FLAGS.train_noise
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    [gene_minput, gene_moutput, gene_output, gene_var_list,
     disc_real_output, disc_fake_output, gradients, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)
  
    # >>> add summary scalars for test set
    max_samples = 10 # output 10 test images
    gene_output_clipped = tf.maximum(tf.minimum(gene_moutput, 1.0), 0.)
    
    # Calculate the L1 error between output samples and labels as an objective measure of image quality
    if FLAGS.input != 'noise':
      l1_quality  = tf.reduce_sum(tf.abs(gene_output_clipped - test_labels), [1,2,3])
      l1_quality = tf.reduce_mean(l1_quality[:max_samples])
      mse_quality  = tf.reduce_sum(tf.square(gene_output_clipped - test_labels), [1,2,3])
      mse_quality = tf.reduce_mean(mse_quality[:max_samples])
      tf.summary.scalar('l1_quality', l1_quality, collections=['test_scalars'])
      tf.summary.scalar('mse_quality', mse_quality, collections=['test_scalars'])


    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)

    # Different training objectives
    if FLAGS.loss_func == 'dcgan':
        # for DCGAN
        disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    elif FLAGS.loss_func == 'wgan':
        # for WGAN
        disc_loss = tf.subtract(disc_real_loss, disc_fake_loss, name='disc_loss')
    elif FLAGS.loss_func == 'wgangp':
        # for WGANGP
        disc_loss = tf.subtract(disc_real_loss, disc_fake_loss)
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = tf.reduce_mean((slopes-1.)**2)
        disc_loss = tf.add(disc_loss, FLAGS.LAMBDA*gradient_penalty, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize, d_clip) = \
            srez_model.create_optimizers(gene_loss, gene_var_list, disc_loss, disc_var_list)

    # For tensorboard
    tf.summary.scalar('generator_loss', gene_loss)
    tf.summary.scalar('discriminator_real_loss', disc_real_loss)
    tf.summary.scalar('discriminator_fake_loss', disc_fake_loss)
    tf.summary.scalar('discriminator_tot_loss', disc_loss)


    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
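In the wgangp branch above, the gradients tensor is produced inside the modified create_model and is only turned into a penalty here. A minimal sketch of how such gradients are typically formed for WGAN-GP, with a hypothetical critic_fn callable (not part of the srez_model API):

def wgan_gp_gradients(critic_fn, real_images, fake_images, batch_size):
    # Interpolate uniformly between real and generated samples, then
    # differentiate the critic output with respect to the interpolates.
    alpha = tf.random_uniform([batch_size, 1, 1, 1], minval=0., maxval=1.)
    interpolates = real_images + alpha * (fake_images - real_images)
    critic_out = critic_fn(interpolates)
    return tf.gradients(critic_out, [interpolates])[0]

The slopes and gradient_penalty lines above would then apply directly to the tensor this helper returns.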
Example #20
def _train():
    # Setup global tensorflow state
    sess, _oldwriter = setup_tensorflow()

    # image_size
    if FLAGS.sample_size_y > 0:
        image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
    else:
        image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # Prepare train and test directories (SEPARATE FOLDER)
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    filenames_input_train = get_filenames(dir_file=FLAGS.dataset_train,
                                          shuffle_filename=True)
    filenames_output_train = get_filenames(dir_file=FLAGS.dataset_train,
                                           shuffle_filename=True)
    filenames_input_test = get_filenames(dir_file=FLAGS.dataset_test,
                                         shuffle_filename=False)
    filenames_output_test = get_filenames(dir_file=FLAGS.dataset_test,
                                          shuffle_filename=False)

    ## Prepare directories (SAME FOLDER)
    #prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    #filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    ## if not specify use the same as input
    #if FLAGS.dataset_output == '':
    #FLAGS.dataset_output = FLAGS.dataset_input
    #filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # check input and output sample number matches (SEPARATE FOLDER)
    assert (len(filenames_input_train) == len(filenames_output_train))
    num_filename_train = len(filenames_input_train)
    assert (len(filenames_input_test) == len(filenames_output_test))
    num_filename_test = len(filenames_input_test)

    #print(num_filename_train)
    #print(num_filename_test)
    #print(filenames_output_test)

    # check input and output sample number matches (SAME FOLDER)
    #assert(len(filenames_input)==len(filenames_output))
    #num_filename_all = len(filenames_input)

    # Permutate train and test split (SEPARATE FOLDERS)
    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_train),
                                                num_filename_train)
        filenames_input_train = [
            filenames_input_train[x] for x in index_permutation_split
        ]
        filenames_output_train = [
            filenames_output_train[x] for x in index_permutation_split
        ]
        #print(np.shape(filenames_input_train))

    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_test),
                                                num_filename_test)
        filenames_input_test = [
            filenames_input_test[x] for x in index_permutation_split
        ]
        filenames_output_test = [
            filenames_output_test[x] for x in index_permutation_split
        ]
    #print('filenames_input[:20]',filenames_input[:20])

    # Permutate test split (SAME FOLDERS)
    #if FLAGS.permutation_split:
    #index_permutation_split = random.sample(num_filename_test, num_filename_test)
    #filenames_input_test = [filenames_input_test[x] for x in index_permutation_split]
    #filenames_output_test = [filenames_output_test[x] for x in index_permutation_split]
    #print('filenames_input[:20]',filenames_input[:20])

    # Separate training and test sets (SEPARATE FOLDERS)
    train_filenames_input = filenames_input_train[:FLAGS.sample_train]
    train_filenames_output = filenames_output_train[:FLAGS.sample_train]

    test_filenames_input = filenames_input_test[:FLAGS.sample_test]
    test_filenames_output = filenames_output_test[:FLAGS.sample_test]
    #print('test_filenames_input', test_filenames_input)
    #print('train_filenames_input', train_filenames_input)

    # Separate training and test sets (SAME FOLDERS)
    #train_filenames_input = filenames_input[:-FLAGS.sample_test]
    #train_filenames_output = filenames_output[:-FLAGS.sample_test]
    #test_filenames_input  = filenames_input[-FLAGS.sample_test:]
    #test_filenames_output  = filenames_output[-FLAGS.sample_test:]
    #print('test_filenames_input[:20]',test_filenames_input[:20])

    # randomly subsample for train
    if FLAGS.subsample_train > 0:

        index_sample_train_selected = random.sample(
            range(len(train_filenames_input)), FLAGS.subsample_train)
        if not FLAGS.permutation_train:
            index_sample_train_selected = sorted(index_sample_train_selected)
        train_filenames_input = [
            train_filenames_input[x] for x in index_sample_train_selected
        ]
        train_filenames_output = [
            train_filenames_output[x] for x in index_sample_train_selected
        ]
        print('randomly sampled {0} from {1} train samples'.format(
            len(train_filenames_input),
            len(filenames_input_train[:-FLAGS.sample_test])))

    # randomly sub-sample for test
    if FLAGS.subsample_test > 0:
        index_sample_test_selected = random.sample(
            range(len(test_filenames_input)), FLAGS.subsample_test)
        print(len(test_filenames_input))
        print(FLAGS.subsample_test)
        if not FLAGS.permutation_test:
            index_sample_test_selected = sorted(index_sample_test_selected)
        test_filenames_input = [
            test_filenames_input[x] for x in index_sample_test_selected
        ]
        test_filenames_output = [
            test_filenames_output[x] for x in index_sample_test_selected
        ]
        print('randomly sampled {0} from {1} test samples'.format(
            len(test_filenames_input),
            len(test_filenames_input[:-FLAGS.sample_test])))

    #print('test_filenames_input',test_filenames_input)

    # get undersample mask
    from scipy import io as sio
    try:
        content_mask = sio.loadmat(FLAGS.sampling_pattern)
        key_mask = [x for x in content_mask.keys() if not x.startswith('_')]
        mask = content_mask[key_mask[0]]
    except:
        mask = None

    print(len(train_filenames_input))
    print(len(train_filenames_output))
    print(len(test_filenames_input))
    print(len(test_filenames_output))

    # Setup async input queues
    train_features, train_labels, train_masks = srez_input.setup_inputs_one_sources(
        sess,
        train_filenames_input,
        train_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)
    test_features, test_labels, test_masks = srez_input.setup_inputs_one_sources(
        sess,
        test_filenames_input,
        test_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)

    print('features_size', train_features.get_shape())
    print('labels_size', train_labels.get_shape())
    print('masks_size', train_masks.get_shape())

    # sample train and test
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(
        num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, label_minput, gene_moutput, gene_moutput_list, \
     gene_output, gene_output_list, gene_var_list, gene_layers_list, gene_mlayers_list, gene_mask_list, gene_mask_list_0, \
     disc_real_output, disc_fake_output, disc_var_list, train_phase,disc_layers, eta, nmse, kappa] = \
            srez_model.create_model(sess, noisy_train_features, train_labels, train_masks, architecture=FLAGS.architecture)

    #train_phase = tf.placeholder(tf.bool, [])

    gene_loss, gene_dc_loss, gene_ls_loss, gene_mse_loss, list_gene_losses, gene_mse_factor = srez_model.create_generator_loss(
        disc_fake_output, gene_output, gene_output_list, train_features,
        train_labels, train_masks)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # tensorboard
    summary_op = tf.summary.merge_all()

    # Restore variables from checkpoint
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    metafile = filename + '.meta'
    """
    if tf.gfile.Exists(metafile):
        saver = tf.train.Saver()
        print("Loading checkpoint from file `%s'" % (filename,))
        saver.restore(sess, filename)
    else:
        print("No checkpoint `%s', train from scratch" % (filename,))
        sess.run(tf.global_variables_initializer())
"""

    print("No checkpoint `%s', train from scratch" % (filename, ))
    print(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
    sess.run(tf.global_variables_initializer())

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(sess, train_data, num_sample_train, num_sample_test)