Пример #1
0
def begin_training(config):
    """Build the training ops, initialise all variables and run training.

    Args:
        config: run configuration, forwarded unchanged to ``train``.
    """
    # Build the ops first so that any variables they create already exist
    # when tf.global_variables_initializer() is constructed below.
    create_training_ops()
    # The ``with`` block closes the session once ``train`` returns or
    # raises; previously the session was never closed and leaked.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ops = TrainOps()
        ops.populate(sess)
        # NOTE(review): ``train`` is called here with (sess, ops, config);
        # confirm this targets a ``train`` whose second parameter is ops.
        train(sess, ops, config)
Пример #2
0
def load_session(config):
    """Restore the most recent checkpoint found in ``config.checkpoint_dir``.

    The graph is imported from the checkpoint's ``.meta`` file into the
    current default graph, and its variables are restored into a new session.

    Args:
        config: object with a ``checkpoint_dir`` attribute.

    Returns:
        tuple: ``(sess, ops)`` -- the session holding the restored variables
        and a ``TrainOps`` populated from it. The caller owns ``sess`` and is
        responsible for closing it.

    Raises:
        IOError: if ``config.checkpoint_dir`` contains no checkpoint.
    """
    # Resolve the checkpoint once, so the graph and the variable values are
    # guaranteed to come from the same checkpoint (a second lookup could see
    # a newer checkpoint written in between).
    checkpoint_path = tf.train.latest_checkpoint(config.checkpoint_dir)
    if checkpoint_path is None:
        # Without this check the path became the string 'None.meta' and
        # import_meta_graph failed with a misleading "file does not exist".
        raise IOError('No checkpoint found in %s' % config.checkpoint_dir)

    # load stored graph into current graph
    saver = tf.train.import_meta_graph(checkpoint_path + '.meta')

    sess = tf.Session()
    try:
        # restore variables into graph
        saver.restore(sess, checkpoint_path)

        # load operations
        ops = TrainOps()
        ops.populate(sess)
    except Exception:
        # Do not leak the session when restoring fails.
        sess.close()
        raise
    return sess, ops
def train(sess, dataset, config):
    """Run the training loop until the graph's epoch counter reaches
    ``config.num_epochs``.

    Training starts from the value currently stored in ``ops.epoch``, so a
    restored session continues from where it stopped.

    Args:
        sess: session whose graph holds the training ops.
        dataset: dataset yielding mini-batches; each batch is split into
            x and y images by ``DataLoader().split_images``.
        config: object with at least ``num_epochs``; also passed to ``Logger``.
    """
    # prepare train ops
    ops = TrainOps(sess.graph)
    logger = Logger(config, sess, ops)

    # Build every op the loop needs exactly once. Calling tf.assign_add or
    # creating an iterator inside the loop adds new nodes to the graph on
    # every call, so the graph (and memory use) grew with every batch/epoch.
    increment_global_step = tf.assign_add(ops.global_step, 1)
    increment_epoch = tf.assign_add(ops.epoch, 1)
    iterator = dataset.make_initializable_iterator()
    batch_var = iterator.get_next()

    epoch = sess.run(ops.epoch)

    # loop through epochs
    while epoch < config.num_epochs:
        # Restart the dataset from its first batch for this epoch.
        sess.run(iterator.initializer)

        # loop through batches
        while True:
            # Guard only the batch fetch, so an OutOfRangeError raised while
            # training or logging is not mistaken for the end of the epoch.
            try:
                batch = sess.run(batch_var)
            except tf.errors.OutOfRangeError:
                break

            # get mini-batch
            x_images, y_images = DataLoader().split_images(batch)

            # train
            feed_dict = {ops.x_images_holder: x_images,
                         ops.y_images_holder: y_images}
            sess.run([ops.train_g, ops.train_d], feed_dict=feed_dict)
            logger.log(feed_dict)
            logger.checkpoint(feed_dict)
            sess.run(increment_global_step)

        # increment epoch
        sess.run(increment_epoch)
        epoch = sess.run(ops.epoch)
Пример #4
0
def main(_):
    """Entry point: select the GPU, build the model, then train or test it.

    Reads ``gpu``, ``run``, ``exp_dir``, ``seed``, ``arch`` and ``mode``
    from ``FLAGS``.

    Raises:
        ValueError: if ``FLAGS.mode`` is neither ``'train'`` nor ``'test'``.
    """
    gpu_id = FLAGS.gpu
    # Enumerate CUDA devices in PCI bus ID order (CUDA's default is
    # fastest-first) so the index given in --gpu refers to the same device
    # that nvidia-smi lists under that index.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    run = FLAGS.run
    exp_dir = FLAGS.exp_dir
    seed = int(FLAGS.seed)
    arch = FLAGS.arch

    model = Model(arch=arch)
    train_ops = TrainOps(model, exp_dir, run, arch)
    train_ops.load_exp_config()

    # Single-argument print(...) prints the same text on Python 2 and 3.
    if FLAGS.mode == 'train':
        print('Training')
        train_ops.train(seed=seed)
    elif FLAGS.mode == 'test':
        print('Testing')
        train_ops.test()
    else:
        # Previously an unrecognised mode made the script exit silently
        # without doing anything.
        raise ValueError(
            "Unknown mode %r; expected 'train' or 'test'" % (FLAGS.mode,))
Пример #5
0
    def sample(self, input_images_dir, output_dir, sess):
        """Run the generator on one batch of images and save the results.

        Only the first batch (at most 100 images) produced by
        ``DataLoader().load_images`` is processed.

        Args:
            input_images_dir: directory passed to ``DataLoader().load_images``.
            output_dir: directory passed to ``self.save_images``.
            sess: session whose graph holds the generator ops.

        Raises:
            ValueError: if the images have more than one but fewer than
                ``A.input_channels`` channels; those cannot be expanded by
                repeating channels.
        """
        ops = TrainOps(sess.graph)

        # input_images dataset
        dataset = DataLoader().load_images(input_images_dir, batch_size=100)
        iterator = dataset.make_one_shot_iterator()
        batch = iterator.get_next()
        x_images = sess.run(batch)

        # correct num channels if needed
        # Assumes channels-last layout (N, H, W, C) -- TODO confirm against
        # DataLoader.load_images.
        channel_axis = 3
        num_channels = x_images.shape[channel_axis]
        if num_channels < A.input_channels:
            if num_channels != 1:
                # np.repeat copies every channel input_channels times, which
                # only yields input_channels channels for 1-channel input.
                raise ValueError(
                    'Cannot expand %d-channel images to %d channels'
                    % (num_channels, A.input_channels))
            # Duplicate the single channel, e.g. grayscale -> RGB.
            x_images = np.repeat(x_images, A.input_channels, axis=channel_axis)

        # create feed-dict
        feed_dict = {ops.x_images_holder: x_images}

        # run generated_image
        generated_images = sess.run(ops.generated_images, feed_dict=feed_dict)
        # Map generator output back to pixel values. The (x + 1) * 128 scale
        # is kept unchanged; it presumably inverts a [-1, 1] normalisation
        # (TODO confirm against DataLoader). An output of 1.0 maps to 256,
        # which is outside the 8-bit range, so clip to [0, 255].
        generated_images = np.clip((generated_images + 1.) * 128., 0., 255.)

        # save resulting image to sample_dir
        self.save_images(generated_images, output_dir)
Пример #6
0
def main(_):
    """Entry point: select the GPU, build the model, then run ``FLAGS.mode``.

    Supported modes and the ``TrainOps`` method each one calls:

    * ``train_ERM``  -> ``train()``
    * ``train_RDA``  -> ``train(random_transf=True)``
    * ``train_RSDA`` -> ``train_search(search_algorithm='random_search')``
    * ``train_ESDA`` -> ``train_search(search_algorithm='evolution_search')``
    * ``test_all``   -> ``test_all()``
    * ``test_RS``    -> ``test_random_search(...)``
    * ``test_ES``    -> ``test_evolution_search(...)``

    For any mode containing ``'train'``, ``npr`` (presumably
    ``numpy.random``) is seeded with ``FLAGS.seed`` first.
    ``load_exp_config()`` is called before the selected method.

    Raises:
        ValueError: if ``FLAGS.mode`` is not one of the modes above.
    """
    gpu_id = FLAGS.gpu
    # Enumerate CUDA devices in PCI bus ID order (CUDA's default is
    # fastest-first) so the index given in --gpu refers to the same device
    # that nvidia-smi lists under that index.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    exp_dir = FLAGS.exp_dir

    model = Model()
    tr_ops = TrainOps(model, exp_dir)

    mode = FLAGS.mode

    if 'train' in mode:
        npr.seed(int(FLAGS.seed))

    # Single-argument print(...) prints the same text on Python 2 and 3.
    if mode == 'train_ERM':
        print('Training model with standard ERM')
        tr_ops.load_exp_config()
        tr_ops.train()

    elif mode == 'train_RDA':
        print('Training model with RDA')
        tr_ops.load_exp_config()
        tr_ops.train(random_transf=True)

    elif mode == 'train_RSDA':
        print('Training model with RSDA')
        tr_ops.load_exp_config()
        tr_ops.train_search(search_algorithm='random_search')

    elif mode == 'train_ESDA':
        print('Training model with ESDA')
        tr_ops.load_exp_config()
        tr_ops.train_search(search_algorithm='evolution_search')

    elif mode == 'test_all':
        print('Testing all')
        tr_ops.load_exp_config()
        tr_ops.test_all()

    elif mode == 'test_RS':
        print('Random search')
        tr_ops.load_exp_config()
        tr_ops.test_random_search(run=str(FLAGS.run),
                                  seed=int(FLAGS.seed),
                                  no_iters=int(FLAGS.search_no_iters),
                                  string_length=int(
                                      FLAGS.transf_string_length))

    elif mode == 'test_ES':
        print('Evolution search')
        tr_ops.load_exp_config()
        tr_ops.test_evolution_search(run=str(FLAGS.run),
                                     seed=int(FLAGS.seed),
                                     no_iters=int(FLAGS.search_no_iters),
                                     string_length=int(
                                         FLAGS.transf_string_length),
                                     pop_size=int(FLAGS.pop_size),
                                     mutation_rate=float(FLAGS.mutation_rate))

    else:
        # Previously an unrecognised mode made the script exit silently
        # without doing anything.
        raise ValueError('Unknown mode %r' % (mode,))