Example 1
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    # all_filenames = 
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    # if not specified, use the same as the input
    if FLAGS.dataset_output == '':
        FLAGS.dataset_output = FLAGS.dataset_input
    filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # Separate training and test sets
    # train_filenames = all_filenames[:-FLAGS.test_vectors]
    # test_filenames  = all_filenames[-FLAGS.test_vectors:]
    train_filenames_input = filenames_input[:-FLAGS.test_vectors]
    test_filenames_input  = filenames_input[-FLAGS.test_vectors:]
    train_filenames_output = filenames_output[:-FLAGS.test_vectors]
    test_filenames_output  = filenames_output[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs_one_sources(sess, train_filenames_input, train_filenames_output, 
                                                                        image_size=FLAGS.sample_size, axis_undersample=FLAGS.axis_undersample)
    test_features,  test_labels  = srez_input.setup_inputs_one_sources(sess, test_filenames_input, test_filenames_output,
                                                                        image_size=FLAGS.sample_size, axis_undersample=FLAGS.axis_undersample)
    
    # sample size
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     gene_layers, gene_mlayers] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features, train_labels)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    
    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data, num_sample_train, num_sample_test)
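Example 1 (and every example that follows) ends by bundling the local variables into a TrainData object before handing them to srez_train.train_model. In the srez codebase these snippets derive from, TrainData is usually just a thin attribute wrapper around a dict; a minimal sketch, assuming that convention:

class TrainData(object):
    """Expose every entry of the passed-in dict (e.g. locals()) as an attribute."""
    def __init__(self, dictionary):
        self.__dict__.update(dictionary)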
Example 2
def _train():
    # Setup global tensorflow state
    sess = setup_tensorflow()

    # Prepare directories
    # all_filenames = prepare_dirs(delete_train_dir=True)
    all_filenames = prepare_dirs(delete_train_dir=False)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list,
     dropout] = \
        srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
        srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
        srez_model.create_optimizers(gene_loss, gene_var_list,
                                     disc_loss, disc_var_list)

    # Restore variables from checkpoint if it exists
    # if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
    #     filename = 'checkpoint_new.txt'
    #     filename = os.path.join(FLAGS.checkpoint_dir, filename)
    #     saver = tf.train.Saver()
    #     if tf.gfile.Exists(filename):
    #         saver.restore(tf.Session(), filename)
    #         print("Restored previous checkpoint. "
    #               "Warning, Batch number restarted.")

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
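The commented-out restore block in Example 2 would restore variables into a brand-new tf.Session() rather than the sess returned by setup_tensorflow(). A minimal sketch of the standard TF1 restore pattern it appears to be aiming for, assuming checkpoints were written by tf.train.Saver into FLAGS.checkpoint_dir:

saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if ckpt is not None:
    # Restore into the existing session; the batch counter restarts, as the original comment warns.
    saver.restore(sess, ckpt)
    print("Restored previous checkpoint from %s" % ckpt)
else:
    sess.run(tf.global_variables_initializer())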
Example 3
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    rn.shuffle(all_filenames)
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess,
                                                           train_filenames,
                                                           image_size=32,
                                                           crop_size=128)
    test_features, test_labels = srez_input.setup_inputs(sess,
                                                         test_filenames,
                                                         image_size=32,
                                                         crop_size=128)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example 4
def _train():
    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    #saver = tf.train.Saver()
    #filename = 'checkpoint'
    #filename = os.path.join(FLAGS.checkpoint_dir, filename)
    #saver.restore(sess,tf.train.latest_checkpoint("./checkpoint/"))
    #print("Model restored from file: %s" % FLAGS.checkpoint_dir)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output,
                                                 train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example 5
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=False)

    # Separate training and test sets

    if FLAGS.specific_test:
        train_filenames = all_filenames[:]
        test_filenames = prepare_test_dirs()[:]
    else:
        train_filenames = all_filenames[:-FLAGS.test_vectors]
        test_filenames = all_filenames[-FLAGS.test_vectors:]

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(
        sess, train_filenames)
    test_features, test_labels = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss, gene_l1_loss, gene_ce_loss = srez_model.create_generator_loss(
        disc_fake_output, gene_output, train_features, train_labels)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(2 * FLAGS.disc_real_factor * disc_real_loss,
                       2 * (1 - FLAGS.disc_real_factor) * disc_fake_loss,
                       name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)
    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
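Example 5 weights the two discriminator terms with FLAGS.disc_real_factor. A quick check of the arithmetic: with disc_real_factor = 0.5 the weighted sum reduces to the unweighted disc_real_loss + disc_fake_loss used in the other examples, and for any factor f the total weight stays at 2*f + 2*(1-f) = 2, so the flag only shifts the balance between the real and fake terms.

    2 * 0.5 * disc_real_loss + 2 * (1 - 0.5) * disc_fake_loss
        == disc_real_loss + disc_fake_loss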
Example 6
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    train_filenames = all_filenames[:-FLAGS.test_vectors]
    test_filenames  = all_filenames[-FLAGS.test_vectors:]

    # TBD: Maybe download dataset here

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess, train_filenames)
    test_features,  test_labels  = srez_input.setup_inputs(sess, test_filenames)

    # Add some noise during training (think denoising autoencoders)
    noise_level = .03
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput,
     gene_output, gene_var_list,
     disc_real_output, disc_fake_output, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)

    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    
    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
Example 7
def _train():
    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # Prepare directories
    all_filenames = prepare_dirs(delete_train_dir=True)

    # Separate training and test sets
    # train_filenames = all_filenames[:-FLAGS.test_vectors]
    # test_filenames  = all_filenames[-FLAGS.test_vectors:]


    # We chose a pre-determined set of faces for the convenience of comparing results across models
    determined_test = [73883-1, 110251-1, 36510-1, 132301-1, 57264-1, 152931-1, 93861-1,
                       124938-1, 79512-1, 106152-1, 127384-1, 134028-1, 67874-1,
                       10613-1, 198694-1, 100990-1]
    all_filenames = np.array(all_filenames)
    train_filenames = list(np.delete(all_filenames, determined_test))

#     test_filenames = list(all_filenames[determined_test])

    # Setup async input queues
    train_features, train_labels = srez_input.setup_inputs(sess, train_filenames)
    
    # test_features,  test_labels  = srez_input.setup_inputs(sess, test_filenames)
    
    # Test sets are stored in 'testset_label.npy'
    test_labels = np.load('testset_label.npy')
    test_labels = tf.convert_to_tensor(test_labels, dtype = tf.float32)

    if FLAGS.input == 'scaled':
        test_features = tf.image.resize_area(test_labels, [16, 16])
    elif FLAGS.input == 'noise':
        test_features = tf.random_uniform(shape=[16, FLAGS.noise_dimension, FLAGS.noise_dimension, 3], minval=-1., maxval=1.)

    # Add some noise during training (think denoising autoencoders)
    noise_level = FLAGS.train_noise
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    [gene_minput, gene_moutput, gene_output, gene_var_list,
     disc_real_output, disc_fake_output, gradients, disc_var_list] = \
            srez_model.create_model(sess, noisy_train_features, train_labels)
  
    # >>> add summary scalars for test set
    max_samples = 10 # output 10 test images
    gene_output_clipped = tf.maximum(tf.minimum(gene_moutput, 1.0), 0.)
    
    # Calculate the L1 error between output samples and labels as an objective measure of image quality
    if FLAGS.input != 'noise':
      l1_quality  = tf.reduce_sum(tf.abs(gene_output_clipped - test_labels), [1,2,3])
      l1_quality = tf.reduce_mean(l1_quality[:max_samples])
      mse_quality  = tf.reduce_sum(tf.square(gene_output_clipped - test_labels), [1,2,3])
      mse_quality = tf.reduce_mean(mse_quality[:max_samples])
      tf.summary.scalar('l1_quality', l1_quality, collections=['test_scalars'])
      tf.summary.scalar('mse_quality', mse_quality, collections=['test_scalars'])


    gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)

    # Different training objectives
    if FLAGS.loss_func == 'dcgan':
        # for DCGAN
        disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')
    elif FLAGS.loss_func == 'wgan':
        # for WGAN
        disc_loss = tf.subtract(disc_real_loss, disc_fake_loss, name='disc_loss')
    elif FLAGS.loss_func == 'wgangp':
        # for WGANGP
        disc_loss = tf.subtract(disc_real_loss, disc_fake_loss)
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = tf.reduce_mean((slopes-1.)**2)
        disc_loss = tf.add(disc_loss, FLAGS.LAMBDA*gradient_penalty, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize, d_clip) = \
            srez_model.create_optimizers(gene_loss, gene_var_list, disc_loss, disc_var_list)

    # For tensorboard
    tf.summary.scalar('generator_loss', gene_loss)
    tf.summary.scalar('discriminator_real_loss', disc_real_loss)
    tf.summary.scalar('discriminator_fake_loss', disc_fake_loss)
    tf.summary.scalar('discriminator_tot_loss', disc_loss)


    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data)
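Example 7's 'wgangp' branch builds its gradient penalty from a gradients tensor returned by create_model, so the interpolation step itself is not visible here. A minimal sketch of how such a tensor is commonly produced for WGAN-GP (alpha, batch_size, real_images, fake_images and discriminator are hypothetical names, not part of the example):

# Critic gradient w.r.t. points interpolated between real and generated samples.
alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * real_images + (1. - alpha) * fake_images
disc_interp = discriminator(interpolates)
gradients = tf.gradients(disc_interp, [interpolates])[0]

The slopes and gradient_penalty lines in the example then penalize deviations of this gradient's norm from 1, which is the WGAN-GP objective.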
Example 8
def _train():
    # Setup global tensorflow state
    sess, _oldwriter = setup_tensorflow()

    # image_size
    if FLAGS.sample_size_y > 0:
        image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
    else:
        image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # Prepare train and test directories (SEPARATE FOLDER)
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    filenames_input_train = get_filenames(dir_file=FLAGS.dataset_train,
                                          shuffle_filename=True)
    filenames_output_train = get_filenames(dir_file=FLAGS.dataset_train,
                                           shuffle_filename=True)
    filenames_input_test = get_filenames(dir_file=FLAGS.dataset_test,
                                         shuffle_filename=False)
    filenames_output_test = get_filenames(dir_file=FLAGS.dataset_test,
                                          shuffle_filename=False)

    ## Prepare directories (SAME FOLDER)
    #prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    #filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    ## if not specify use the same as input
    #if FLAGS.dataset_output == '':
    #FLAGS.dataset_output = FLAGS.dataset_input
    #filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # check input and output sample number matches (SEPARATE FOLDER)
    assert (len(filenames_input_train) == len(filenames_output_train))
    num_filename_train = len(filenames_input_train)
    assert (len(filenames_input_test) == len(filenames_output_test))
    num_filename_test = len(filenames_input_test)

    #print(num_filename_train)
    #print(num_filename_test)
    #print(filenames_output_test)

    # check input and output sample number matches (SAME FOLDER)
    #assert(len(filenames_input)==len(filenames_output))
    #num_filename_all = len(filenames_input)

    # Permutate train and test split (SEPARATE FOLDERS)
    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_train),
                                                num_filename_train)
        filenames_input_train = [
            filenames_input_train[x] for x in index_permutation_split
        ]
        filenames_output_train = [
            filenames_output_train[x] for x in index_permutation_split
        ]
        #print(np.shape(filenames_input_train))

    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_test),
                                                num_filename_test)
        filenames_input_test = [
            filenames_input_test[x] for x in index_permutation_split
        ]
        filenames_output_test = [
            filenames_output_test[x] for x in index_permutation_split
        ]
    #print('filenames_input[:20]',filenames_input[:20])

    # Permutate test split (SAME FOLDERS)
    #if FLAGS.permutation_split:
    #index_permutation_split = random.sample(num_filename_test, num_filename_test)
    #filenames_input_test = [filenames_input_test[x] for x in index_permutation_split]
    #filenames_output_test = [filenames_output_test[x] for x in index_permutation_split]
    #print('filenames_input[:20]',filenames_input[:20])

    # Separate training and test sets (SEPARATE FOLDERS)
    train_filenames_input = filenames_input_train[:FLAGS.sample_train]
    train_filenames_output = filenames_output_train[:FLAGS.sample_train]

    test_filenames_input = filenames_input_test[:FLAGS.sample_test]
    test_filenames_output = filenames_output_test[:FLAGS.sample_test]
    #print('test_filenames_input', test_filenames_input)
    #print('train_filenames_input', train_filenames_input)

    # Separate training and test sets (SAME FOLDERS)
    #train_filenames_input = filenames_input[:-FLAGS.sample_test]
    #train_filenames_output = filenames_output[:-FLAGS.sample_test]
    #test_filenames_input  = filenames_input[-FLAGS.sample_test:]
    #test_filenames_output  = filenames_output[-FLAGS.sample_test:]
    #print('test_filenames_input[:20]',test_filenames_input[:20])

    # randomly subsample for train
    if FLAGS.subsample_train > 0:

        index_sample_train_selected = random.sample(
            range(len(train_filenames_input)), FLAGS.subsample_train)
        if not FLAGS.permutation_train:
            index_sample_train_selected = sorted(index_sample_train_selected)
        train_filenames_input = [
            train_filenames_input[x] for x in index_sample_train_selected
        ]
        train_filenames_output = [
            train_filenames_output[x] for x in index_sample_train_selected
        ]
        print('randomly sampled {0} from {1} train samples'.format(
            len(train_filenames_input),
            len(filenames_input_train[:-FLAGS.sample_test])))

    # randomly sub-sample for test
    if FLAGS.subsample_test > 0:
        index_sample_test_selected = random.sample(
            range(len(test_filenames_input)), FLAGS.subsample_test)
        print(len(test_filenames_input))
        print(FLAGS.subsample_test)
        if not FLAGS.permutation_test:
            index_sample_test_selected = sorted(index_sample_test_selected)
        test_filenames_input = [
            test_filenames_input[x] for x in index_sample_test_selected
        ]
        test_filenames_output = [
            test_filenames_output[x] for x in index_sample_test_selected
        ]
        print('randomly sampled {0} from {1} test samples'.format(
            len(test_filenames_input),
            len(test_filenames_input[:-FLAGS.sample_test])))

    #print('test_filenames_input',test_filenames_input)

    # get undersample mask
    from scipy import io as sio
    try:
        content_mask = sio.loadmat(FLAGS.sampling_pattern)
        key_mask = [x for x in content_mask.keys() if not x.startswith('_')]
        mask = content_mask[key_mask[0]]
    except:
        mask = None

    print(len(train_filenames_input))
    print(len(train_filenames_output))
    print(len(test_filenames_input))
    print(len(test_filenames_output))

    # Setup async input queues
    train_features, train_labels, train_masks = srez_input.setup_inputs_one_sources(
        sess,
        train_filenames_input,
        train_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)
    test_features, test_labels, test_masks = srez_input.setup_inputs_one_sources(
        sess,
        test_filenames_input,
        test_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)

    print('features_size', train_features.get_shape())
    print('labels_size', train_labels.get_shape())
    print('masks_size', train_masks.get_shape())

    # sample train and test
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(
        num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, label_minput, gene_moutput, gene_moutput_list, \
     gene_output, gene_output_list, gene_var_list, gene_layers_list, gene_mlayers_list, gene_mask_list, gene_mask_list_0, \
     disc_real_output, disc_fake_output, disc_var_list, train_phase,disc_layers, eta, nmse, kappa] = \
            srez_model.create_model(sess, noisy_train_features, train_labels, train_masks, architecture=FLAGS.architecture)

    #train_phase = tf.placeholder(tf.bool, [])

    gene_loss, gene_dc_loss, gene_ls_loss, gene_mse_loss, list_gene_losses, gene_mse_factor = srez_model.create_generator_loss(
        disc_fake_output, gene_output, gene_output_list, train_features,
        train_labels, train_masks)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    # tensorboard
    summary_op = tf.summary.merge_all()

    #restore variables from checkpoint
    filename = 'checkpoint_new.txt'
    filename = os.path.join(FLAGS.checkpoint_dir, filename)
    metafile = filename + '.meta'
    """
    if tf.gfile.Exists(metafile):
        saver = tf.train.Saver()
        print("Loading checkpoint from file `%s'" % (filename,))
        saver.restore(sess, filename)
    else:
        print("No checkpoint `%s', train from scratch" % (filename,))
        sess.run(tf.global_variables_initializer())
"""

    print("No checkpoint `%s', train from scratch" % (filename, ))
    print(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))
    sess.run(tf.global_variables_initializer())

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(sess, train_data, num_sample_train, num_sample_test)
Example 9
def _train():
    time_start = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("START. Time is {}".format(time_start))

    # Setup global tensorflow state
    sess, summary_writer = setup_tensorflow()

    # image_size
    if FLAGS.sample_size_y > 0:
        image_size = [FLAGS.sample_size, FLAGS.sample_size_y]
    else:
        image_size = [FLAGS.sample_size, FLAGS.sample_size]

    # Prepare train and test directories (SEPARATE FOLDER)
    prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    if FLAGS.cv_index >= 0:
        # Cross-validation
        filenames_input_train = []
        filenames_output_train = []
        for i in range(FLAGS.cv_groups):
            if i == FLAGS.cv_index:
                continue
            train_dir = os.path.join(FLAGS.dataset, str(i))
            filenames = get_filenames(dir_file=train_dir,
                                      shuffle_filename=True)
            filenames_input_train.extend(filenames)
            filenames_output_train.extend(filenames)
        test_dir = os.path.join(FLAGS.dataset, str(FLAGS.cv_index))
        filenames_input_test = get_filenames(dir_file=test_dir,
                                             shuffle_filename=True)
        filenames_output_test = get_filenames(dir_file=test_dir,
                                              shuffle_filename=True)
    else:
        filenames_input_train = get_filenames(dir_file=FLAGS.dataset_train,
                                              shuffle_filename=True)
        filenames_output_train = get_filenames(dir_file=FLAGS.dataset_train,
                                               shuffle_filename=True)
        filenames_input_test = get_filenames(dir_file=FLAGS.dataset_test,
                                             shuffle_filename=False)
        filenames_output_test = get_filenames(dir_file=FLAGS.dataset_test,
                                              shuffle_filename=False)

    # Record parameters
    parameters = save_parameters(time_start=time_start)

    ## Prepare directories (SAME FOLDER)
    #prepare_dirs(delete_train_dir=True, shuffle_filename=False)
    #filenames_input = get_filenames(dir_file=FLAGS.dataset_input, shuffle_filename=False)
    ## if not specify use the same as input
    #if FLAGS.dataset_output == '':
    #FLAGS.dataset_output = FLAGS.dataset_input
    #filenames_output = get_filenames(dir_file=FLAGS.dataset_output, shuffle_filename=False)

    # check input and output sample number matches (SEPARATE FOLDER)
    assert (len(filenames_input_train) == len(filenames_output_train))
    num_filename_train = len(filenames_input_train)
    assert (len(filenames_input_test) == len(filenames_output_test))
    num_filename_test = len(filenames_input_test)

    print(num_filename_train)
    print(num_filename_test)

    # check input and output sample number matches (SAME FOLDER)
    #assert(len(filenames_input)==len(filenames_output))
    #num_filename_all = len(filenames_input)

    # Permutate train and test split (SEPARATE FOLDERS)
    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_train),
                                                num_filename_train)
        filenames_input_train = [
            filenames_input_train[x] for x in index_permutation_split
        ]
        filenames_output_train = [
            filenames_output_train[x] for x in index_permutation_split
        ]
        #print(np.shape(filenames_input_train))

    if FLAGS.permutation_split:
        index_permutation_split = random.sample(range(num_filename_test),
                                                num_filename_test)
        filenames_input_test = [
            filenames_input_test[x] for x in index_permutation_split
        ]
        filenames_output_test = [
            filenames_output_test[x] for x in index_permutation_split
        ]
    #print('filenames_input[:20]',filenames_input[:20])

    # Permutate test split (SAME FOLDERS)
    #if FLAGS.permutation_split:
    #index_permutation_split = random.sample(num_filename_test, num_filename_test)
    #filenames_input_test = [filenames_input_test[x] for x in index_permutation_split]
    #filenames_output_test = [filenames_output_test[x] for x in index_permutation_split]
    #print('filenames_input[:20]',filenames_input[:20])

    # Separate training and test sets (SEPARATE FOLDERS)
    sample_train = len(filenames_input_train) if FLAGS.sample_train <= 0 else FLAGS.sample_train
    sample_test = len(filenames_input_test) if FLAGS.sample_test <= 0 else FLAGS.sample_test

    train_filenames_input = filenames_input_train[:sample_train]
    train_filenames_output = filenames_output_train[:sample_train]

    # TODO If separate folders, make the index `:sample_test`
    # Using index `-sample_test:` hacks it for a same-folder split.
    test_filenames_input = filenames_input_test[
        -sample_test:]  # filenames_input_test[:sample_test]
    test_filenames_output = filenames_output_test[
        -sample_test:]  # filenames_output_test[:sample_test]
    #print('test_filenames_input', test_filenames_input)
    #print('train_filenames_input', train_filenames_input)

    # Separate training and test sets (SAME FOLDERS)
    #train_filenames_input = filenames_input[:-FLAGS.sample_test]
    #train_filenames_output = filenames_output[:-FLAGS.sample_test]
    #test_filenames_input  = filenames_input[-FLAGS.sample_test:]
    #test_filenames_output  = filenames_output[-FLAGS.sample_test:]
    #print('test_filenames_input[:20]',test_filenames_input[:20])

    # randomly subsample for train
    if FLAGS.subsample_train > 0:
        index_sample_train_selected = random.sample(
            range(len(train_filenames_input)), FLAGS.subsample_train)
        if not FLAGS.permutation_train:
            index_sample_train_selected = sorted(index_sample_train_selected)
        train_filenames_input = [
            train_filenames_input[x] for x in index_sample_train_selected
        ]
        train_filenames_output = [
            train_filenames_output[x] for x in index_sample_train_selected
        ]
        print('randomly sampled {0} from {1} train samples'.format(
            len(train_filenames_input),
            len(train_filenames_input[:FLAGS.sample_train])))

    # randomly sub-sample for test
    if FLAGS.subsample_test > 0:
        index_sample_test_selected = random.sample(
            range(len(test_filenames_input)), FLAGS.subsample_test)
        if not FLAGS.permutation_test:
            index_sample_test_selected = sorted(index_sample_test_selected)
        test_filenames_input = [
            test_filenames_input[x] for x in index_sample_test_selected
        ]
        test_filenames_output = [
            test_filenames_output[x] for x in index_sample_test_selected
        ]
        #print('randomly sampled {0} from {1} test samples'.format(len(test_filenames_input), len(filenames_input[:-FLAGS.sample_test])))

    #print('test_filenames_input',test_filenames_input)

    # get undersample mask
    from scipy import io as sio
    try:
        content_mask = sio.loadmat(FLAGS.sampling_pattern)
        key_mask = [x for x in content_mask.keys() if not x.startswith('_')]
        mask = content_mask[key_mask[0]]
    except:
        mask = None

    # Setup async input queues
    train_features, train_labels, train_masks = srez_input.setup_inputs_one_sources(
        sess,
        train_filenames_input,
        train_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)
    test_features, test_labels, test_masks = srez_input.setup_inputs_one_sources(
        sess,
        test_filenames_input,
        test_filenames_output,
        image_size=image_size,
        # undersampling
        axis_undersample=FLAGS.axis_undersample,
        r_factor=FLAGS.R_factor,
        r_alpha=FLAGS.R_alpha,
        r_seed=FLAGS.R_seed,
        sampling_mask=mask)

    print('train_features_queue', train_features.get_shape())
    print('train_labels_queue', train_labels.get_shape())
    print('train_masks_queue', train_masks.get_shape())

    #train_masks = tf.cast(sess.run(train_masks), tf.float32)
    #test_masks = tf.cast(sess.run(test_masks), tf.float32)

    # sample train and test
    num_sample_train = len(train_filenames_input)
    num_sample_test = len(test_filenames_input)
    print('train on {0} samples and test on {1} samples'.format(
        num_sample_train, num_sample_test))

    # Add some noise during training (think denoising autoencoders)
    noise_level = .00
    noisy_train_features = train_features + \
                           tf.random_normal(train_features.get_shape(), stddev=noise_level)

    # Create and initialize model
    [gene_minput, gene_moutput, gene_moutput_complex, \
     gene_output, gene_output_complex, gene_var_list, gene_layers, gene_mlayers, \
     disc_real_output, disc_fake_output, disc_moutput, disc_var_list, disc_layers, disc_mlayers] = \
            srez_model.create_model(sess, noisy_train_features, train_labels, train_masks, architecture=FLAGS.architecture)

    gene_loss, gene_dc_loss, gene_ls_loss, list_gene_losses, gene_mse_factor = srez_model.create_generator_loss(
        disc_fake_output, gene_output, gene_output_complex, train_features,
        train_labels, train_masks)
    disc_real_loss, disc_fake_loss = \
                     srez_model.create_discriminator_loss(disc_real_output, disc_fake_output)
    disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss')

    # add gradient on disc loss
    disc_gradients = tf.gradients(
        disc_loss, [disc_fake_output, disc_real_output, gene_output])
    print('disc loss gradients:', [x.shape for x in disc_gradients])

    (global_step, learning_rate, gene_minimize, disc_minimize) = \
            srez_model.create_optimizers(gene_loss, gene_var_list,
                                         disc_loss, disc_var_list)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    # Train model
    train_data = TrainData(locals())
    srez_train.train_model(train_data, num_sample_train, num_sample_test)

    time_ended = time.strftime("%Y-%m-%d-%H-%M-%S")
    print("ENDED. Time is {}".format(time_ended))

    # Overwrite log file now that we are complete
    save_parameters(use_flags=False,
                    existing=parameters,
                    time_ended=time_ended)