Exemple #1
0
def run(mode, run_config):
    """Train the model with per-epoch validation, or run a single evaluation.

    Args:
        mode: ``'train'`` alternates one ``estimator.train`` call on the
            'train' tfrecord with one ``estimator.evaluate`` call on the
            'test' tfrecord until ``Config.train.epoch`` reaches
            ``Config.train.max_epoch``. ``'eval'`` runs one evaluation on
            the 'test' tfrecord.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    model = Model()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        config=run_config)

    if mode == 'train':
        train_data = data_loader.get_tfrecord('train')
        val_data = data_loader.get_tfrecord('test')
        # NOTE(review): the training pipeline is built under scope "val", the
        # same scope as the validation pipeline. Likely a copy-paste slip;
        # confirm how data_loader uses `scope` before renaming it.
        train_input_fn, train_input_hook = data_loader.get_dataset_batch(
            train_data, buffer_size=5000,
            batch_size=Config.train.batch_size, scope="val")
        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=512, scope="val")

        # Use `<` rather than `while True` + `== max_epoch` so the loop also
        # terminates when training resumes with the epoch counter already at
        # or past max_epoch.
        while Config.train.epoch < Config.train.max_epoch:
            print('*' * 40)
            print("epoch", Config.train.epoch + 1, 'start')
            print('*' * 40)

            estimator.train(input_fn=train_input_fn, hooks=[train_input_hook])
            estimator.evaluate(input_fn=val_input_fn, hooks=[val_input_hook])

            # The epoch counter lives on Config and is updated in place.
            Config.train.epoch += 1

    elif mode == 'eval':
        val_data = data_loader.get_tfrecord('test')

        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=512, scope="val")

        estimator.evaluate(input_fn=val_input_fn, hooks=[val_input_hook])

    else:
        raise ValueError('no %s mode' % (mode,))
Exemple #2
0
def run(mode, run_config):
    """Train with validation-driven optimizer switching and early stopping, or evaluate.

    In ``'train'`` mode, each epoch runs one ``estimator.train`` call on the
    'train' tfrecord, then ``estimator.evaluate`` on the 'test' tfrecord. The
    lowest validation loss and the epoch it occurred in are tracked.

    Once ``patience`` (3) epochs pass without a new minimum:
    - if ``Config.train.switch_optimizer`` is 0, it is set to 1 and
      "switch optimizer to SGD" is printed (presumably ``model_fn`` reads
      this flag; confirm);
    - if it is already 1, training stops.

    Training also stops when ``Config.train.epoch`` reaches
    ``Config.train.max_epoch``.

    Args:
        mode: ``'train'`` or ``'eval'`` (a single evaluation on 'test').
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    # Epochs without a new minimum validation loss tolerated before acting.
    patience = 3

    model = Model()
    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       model_dir=Config.train.model_dir,
                                       config=run_config)

    if mode == 'train':
        train_data = data_loader.get_tfrecord('train')
        val_data = data_loader.get_tfrecord('test')
        # NOTE(review): the training pipeline is built under scope "val", the
        # same scope as the validation pipeline. Confirm how data_loader uses
        # `scope` before renaming it.
        train_input_fn, train_input_hook = data_loader.get_dataset_batch(
            train_data,
            buffer_size=1000,
            batch_size=Config.train.batch_size,
            scope="val")

        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=20, scope="val", shuffle=False)

        # Start from +inf so the first evaluation always records a minimum;
        # a fixed sentinel (previously 100) was never beaten when every loss
        # exceeded it.
        min_loss = float('inf')
        min_loss_epoch = 0

        while Config.train.epoch < Config.train.max_epoch:
            print('*' * 40)
            print("epoch", Config.train.epoch + 1, 'start')
            print('*' * 40)

            estimator.train(input_fn=train_input_fn, hooks=[train_input_hook])
            eval_results = estimator.evaluate(input_fn=val_input_fn,
                                              hooks=[val_input_hook])

            Config.train.epoch += 1

            if eval_results['loss'] < min_loss:
                min_loss = eval_results['loss']
                min_loss_epoch = Config.train.epoch
            print('min loss:', min_loss, '  min_loss_epoch:', min_loss_epoch)

            # No optimizer switch or early stop is decided after the final epoch.
            if Config.train.epoch >= Config.train.max_epoch:
                break

            # NOTE(review): min_loss_epoch is not reset when the optimizer is
            # switched. Unless the very next epoch sets a new minimum, training
            # stops one epoch after the switch. Confirm this is intended.
            if Config.train.epoch - min_loss_epoch >= patience:
                if Config.train.switch_optimizer == 1:
                    break
                if Config.train.switch_optimizer == 0:
                    print('switch optimizer to SGD')
                    Config.train.switch_optimizer = 1

    elif mode == 'eval':
        val_data = data_loader.get_tfrecord('test')
        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=20, scope="val", shuffle=False)
        estimator.evaluate(input_fn=val_input_fn, hooks=[val_input_hook])

    else:
        raise ValueError('no %s mode' % (mode,))
Exemple #3
0
def run(mode, run_config, params):
    """Train the model until ``Config.train.max_steps`` global steps.

    Args:
        mode: accepted for call-site compatibility but never read; this
            variant always trains.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.
        params: hyperparameters forwarded to the Estimator (and from there
            to ``model_fn``).
    """
    estimator = tf.estimator.Estimator(model_fn=Model().model_fn,
                                       model_dir=Config.train.model_dir,
                                       params=params,
                                       config=run_config)

    # Interactive tfdbg CLI, only when Config.train.debug is truthy.
    hooks = [tf_debug.LocalCLIDebugHook()] if Config.train.debug else []

    loss_logger = tf.train.LoggingTensorHook(
        {'loss': 'loss/total_loss:0', 'step': 'global_step:0'},
        every_n_iter=Config.train.check_hook_n_iter)

    dataset = data_loader.get_tfrecord(shuffle=True)
    input_fn, input_hook = data_loader.get_dataset_batch(
        dataset, batch_size=Config.model.batch_size, scope="train")

    estimator.train(input_fn=input_fn,
                    hooks=hooks + [input_hook, loss_logger],
                    max_steps=Config.train.max_steps)
Exemple #4
0
def run(mode, run_config, params):
    """Train the Estimator model; only ``mode == 'train'`` is supported.

    Args:
        mode: must be ``'train'``; it is also passed to
            ``data_loader.get_tfrecord`` to select the shuffled input records.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.
        params: hyperparameters forwarded to the Estimator.

    Raises:
        ValueError: for any other mode. This is raised after the estimator
            and hooks have been constructed.
    """
    model = Model()
    # Warm start from a pretrained VGG-16 checkpoint is currently disabled.
    # To enable it, build
    # ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from='logs/pretrained/vgg_16.ckpt',vars_to_warm_start='vgg_16.*')
    # and pass warm_start_from=ws to the Estimator below.
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        params=params,
        config=run_config)

    hooks = [tf_debug.LocalCLIDebugHook()] if Config.train.debug else []

    # Names without an output suffix (e.g. 'loss/add_7') are resolved by
    # LoggingTensorHook to the op's ':0' output.
    # NOTE(review): auto-generated op names like 'loss/add_7' break whenever
    # the loss graph changes; naming the tensors in model_fn would be sturdier.
    loss_logger = tf.train.LoggingTensorHook(
        {
            'total_loss': 'loss/add_7',
            'content_loss': 'loss/mul_1',
            'style_loss': 'loss/mul_6:0',
            'step': 'global_step:0',
        },
        every_n_iter=Config.train.check_hook_n_iter)

    if mode != 'train':
        raise ValueError('no %s mode' % mode)

    dataset = data_loader.get_tfrecord(mode, shuffle=True)
    input_fn, input_hook = data_loader.get_dataset_batch(
        dataset,
        buffer_size=1000,
        batch_size=Config.model.batch_size,
        scope="train")
    estimator.train(input_fn=input_fn,
                    hooks=hooks + [input_hook, loss_logger],
                    max_steps=Config.train.max_steps)
Exemple #5
0
def run(run_config):
    """Train on the 'train' tfrecord for up to ``Config.train.max_steps`` steps.

    Args:
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.
    """
    estimator = tf.estimator.Estimator(
        model_fn=Model().model_fn,
        model_dir=Config.train.model_dir,
        config=run_config)

    input_fn, input_hook = data_loader.get_dataset_batch(
        data_loader.get_tfrecord('train'),
        buffer_size=5000,
        batch_size=Config.train.batch_size,
        scope="train")

    estimator.train(input_fn=input_fn,
                    max_steps=Config.train.max_steps,
                    hooks=[input_hook])
Exemple #6
0
def run(mode, run_config):
    """Train with per-epoch validation (including ``MSTHook``), or evaluate once.

    Args:
        mode: ``'train'`` alternates one ``estimator.train`` call on the
            'train' tfrecord with one ``estimator.evaluate`` call on the
            'test' tfrecord until ``Config.train.epoch`` reaches
            ``Config.train.max_epoch``. ``'eval'`` evaluates the 'test'
            tfrecord once.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    model = Model()
    estimator = tf.estimator.Estimator(
        model_fn=model.model_fn,
        model_dir=Config.train.model_dir,
        config=run_config)

    if mode == 'train':
        train_data = data_loader.get_tfrecord('train')
        val_data = data_loader.get_tfrecord('test')
        # NOTE(review): the training pipeline is built under scope "val", the
        # same scope as the validation pipeline. Confirm how data_loader uses
        # `scope` before renaming it.
        train_input_fn, train_input_hook = data_loader.get_dataset_batch(
            train_data, buffer_size=5000,
            batch_size=Config.train.batch_size, scope="val")

        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=128, scope="val")

        # Logs the two loss tensors named below and the global step every
        # 100 training steps.
        logginghook = tf.train.LoggingTensorHook(
            {'arc_loss': "sparse_softmax_cross_entropy_loss/value:0",
             'label_loss': "sparse_softmax_cross_entropy_loss_1/value:0",
             'step': 'global_step:0'},
            every_n_iter=100)

        # Use `<` rather than `while True` + `== max_epoch` so the loop also
        # terminates when training resumes with the epoch counter already at
        # or past max_epoch.
        while Config.train.epoch < Config.train.max_epoch:
            print('*' * 40)
            print("epoch", Config.train.epoch + 1, 'start')
            print('*' * 40)

            estimator.train(input_fn=train_input_fn,
                            hooks=[logginghook, train_input_hook])
            estimator.evaluate(input_fn=val_input_fn,
                               hooks=[val_input_hook, MSTHook()])

            # The epoch counter lives on Config and is updated in place.
            Config.train.epoch += 1

    elif mode == 'eval':
        val_data = data_loader.get_tfrecord('test')
        val_input_fn, val_input_hook = data_loader.get_dataset_batch(
            val_data, batch_size=128, scope="val")
        estimator.evaluate(input_fn=val_input_fn,
                           hooks=[val_input_hook, MSTHook()])

    else:
        raise ValueError('no %s mode' % (mode,))
Exemple #7
0
def run(mode, run_config, params):
    """Train on the 'trainA' and 'trainB' tfrecords when ``mode == 'train'``.

    The two datasets are batched together by ``data_loader.get_both_batch``,
    and training runs up to ``Config.train.max_steps``. The G, F, D_X and D_Y
    losses plus the global step are logged every
    ``Config.train.check_hook_n_iter`` steps. Any other mode builds the
    estimator and hooks, then returns without training.

    Args:
        mode: ``'train'`` to train; anything else is a no-op.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.
        params: hyperparameters forwarded to the Estimator.
    """
    model = Model()
    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       model_dir=Config.train.model_dir,
                                       params=params,
                                       config=run_config)

    hooks = [tf_debug.LocalCLIDebugHook()] if Config.train.debug else []

    loss_logger = tf.train.LoggingTensorHook(
        {
            'G_loss': 'loss/G_loss/add:0',
            'F_loss': 'loss/F_loss/add:0',
            'D_X_loss': 'loss/D_X_loss/Mean:0',
            'D_Y_loss': 'loss/D_Y_loss/Mean:0',
            'step': 'global_step:0',
        },
        every_n_iter=Config.train.check_hook_n_iter)

    if mode != 'train':
        return

    dataset_a = data_loader.get_tfrecord('trainA')
    dataset_b = data_loader.get_tfrecord('trainB')
    input_fn, input_hooks = data_loader.get_both_batch(
        dataset_a,
        dataset_b,
        buffer_size=1000,
        batch_size=Config.train.batch_size,
        scope="train")
    # get_both_batch presumably returns a list of hooks; it is concatenated
    # with a list here.
    hooks.extend(input_hooks + [loss_logger])
    estimator.train(input_fn=input_fn,
                    hooks=hooks,
                    max_steps=Config.train.max_steps)
Exemple #8
0
def run(mode, run_config):
    """Train on the 'for'/'back' tfrecord datasets with per-epoch validation, or evaluate.

    Records from ``'for-tfrecord'`` and ``'back-tfrecord'`` are combined by
    ``data_loader.get_both_batch``.

    Args:
        mode: ``'train'`` alternates one ``estimator.train`` call on the
            'train' split with one ``estimator.evaluate`` call on the 'test'
            split until ``Config.train.epoch`` reaches
            ``Config.train.max_epoch``. ``'eval'`` evaluates the 'test' split
            once.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    model = Model()
    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       model_dir=Config.train.model_dir,
                                       config=run_config)

    # Logs the 'Mean:0' and 'Mean_1:0' tensors (as 'for_loss'/'back_loss')
    # and the global step every 100 training steps.
    logginghook = tf.train.LoggingTensorHook(
        {
            'for_loss': "Mean:0",
            'back_loss': "Mean_1:0",
            'step': 'global_step:0'
        },
        every_n_iter=100)

    if mode == 'train':
        # Training data comes from the 'train' split. It was previously read
        # from 'test', so the model was trained on the same data it was
        # validated on.
        # NOTE(review): assumes get_tfrecord(name, split) and that a 'train'
        # split exists for both datasets; confirm.
        for_train_data = data_loader.get_tfrecord('for-tfrecord', 'train')
        for_val_data = data_loader.get_tfrecord('for-tfrecord', 'test')
        back_train_data = data_loader.get_tfrecord('back-tfrecord', 'train')
        back_val_data = data_loader.get_tfrecord('back-tfrecord', 'test')

        # NOTE(review): the training pipeline is built under scope 'val', the
        # same scope as the validation pipeline; confirm before renaming.
        train_input_fn, train_input_hook = data_loader.get_both_batch(
            for_train_data,
            back_train_data,
            buffer_size=5000,
            batch_size=Config.train.batch_size,
            scope='val')
        val_input_fn, val_input_hook = data_loader.get_both_batch(
            for_val_data, back_val_data, batch_size=20, scope="val")

        # Use `<` rather than `while True` + `== max_epoch` so the loop also
        # terminates when training resumes with the epoch counter already at
        # or past max_epoch.
        while Config.train.epoch < Config.train.max_epoch:
            print('*' * 40)
            print("epoch", Config.train.epoch + 1, 'start')
            print('*' * 40)

            # get_both_batch's hooks are presumably a list (concatenated here).
            estimator.train(input_fn=train_input_fn,
                            hooks=[logginghook] + train_input_hook)
            estimator.evaluate(input_fn=val_input_fn, hooks=val_input_hook)

            # The epoch counter lives on Config and is updated in place.
            Config.train.epoch += 1

    elif mode == 'eval':
        for_val_data = data_loader.get_tfrecord('for-tfrecord', 'test')
        back_val_data = data_loader.get_tfrecord('back-tfrecord', 'test')
        val_input_fn, val_input_hook = data_loader.get_both_batch(
            for_val_data, back_val_data, batch_size=20, scope="val")
        estimator.evaluate(input_fn=val_input_fn, hooks=val_input_hook)

    else:
        raise ValueError('no %s mode' % (mode,))
Exemple #9
0
def run(mode, run_config):
    """Build a TF-GAN ``GANEstimator`` from ``Model`` and train it in 'train' mode.

    Generator and discriminator updates are scheduled sequentially with
    ``tfgan.GANTrainSteps(Config.train.G_step, 1)``. Any mode other than
    ``'train'`` builds the estimator and hooks, then returns without training.

    Args:
        mode: ``'train'`` to train; anything else is a no-op.
        run_config: run configuration forwarded to the GANEstimator.
    """
    model = Model()
    train_steps = tfgan.GANTrainSteps(Config.train.G_step, 1)

    estimator = tfgan.estimator.GANEstimator(
        generator_fn=model.generator_fn,
        discriminator_fn=model.discriminator_fn,
        generator_loss_fn=model.generator_loss_fn,
        discriminator_loss_fn=model.discriminator_loss_fn,
        generator_optimizer=model.generator_optimizer,
        discriminator_optimizer=model.discriminator_optimizer,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        config=run_config)

    hooks = [tf_debug.LocalCLIDebugHook()] if Config.train.debug else []

    loss_logger = tf.train.LoggingTensorHook(
        {
            'G_loss': 'GANHead/G_loss:0',
            'D_loss': 'GANHead/D_loss:0',
            'D_real_loss': 'GANHead/D_real_loss:0',
            'D_gen_loss': 'GANHead/D_gen_loss:0',
            'step': 'global_step:0',
        },
        every_n_iter=Config.train.check_hook_n_iter)

    if mode != 'train':
        return

    dataset = data_loader.get_tfrecord('train')
    input_fn, input_hook = data_loader.get_dataset_batch(
        dataset,
        buffer_size=2000,
        batch_size=Config.train.batch_size,
        scope="train")
    estimator.train(input_fn=input_fn, hooks=hooks + [input_hook, loss_logger])
Exemple #10
0
def run(mode, run_config):
    """Train with per-epoch evaluation on three test corpora, or evaluate them.

    The test corpora are 'pku', 'msr' and 'ctb'. Each is read from the
    '<name>_test' tfrecord and evaluated under its own ``name`` with a
    ``PRFScoreHook(name)``.

    Args:
        mode: ``'train'`` alternates one ``estimator.train`` call on the
            'train' tfrecord with an evaluation of every test corpus, until
            ``Config.train.epoch`` reaches ``Config.train.max_epoch``.
            ``'eval'`` evaluates every test corpus once.
        run_config: run configuration forwarded to ``tf.estimator.Estimator``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    model = Model()
    estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                       model_dir=Config.train.model_dir,
                                       config=run_config)

    test_corpora = ('pku', 'msr', 'ctb')

    def build_val_inputs():
        # One (name, input_fn, input_hook) triple per test corpus.
        val_inputs = []
        for name in test_corpora:
            val_data = data_loader.get_tfrecord(name + '_test')
            input_fn, input_hook = data_loader.get_dataset_batch(
                val_data, batch_size=1024, scope="val")
            val_inputs.append((name, input_fn, input_hook))
        return val_inputs

    def evaluate_all(val_inputs):
        # A new PRFScoreHook is created for every evaluate() call.
        for name, input_fn, input_hook in val_inputs:
            estimator.evaluate(input_fn=input_fn,
                               hooks=[input_hook, PRFScoreHook(name)],
                               name=name)

    if mode == 'train':
        train_data = data_loader.get_tfrecord('train')
        # NOTE(review): the training pipeline is built under scope "val", the
        # same scope as the validation pipelines. Confirm how data_loader uses
        # `scope` before renaming it.
        train_input_fn, train_input_hook = data_loader.get_dataset_batch(
            train_data,
            buffer_size=5000,
            batch_size=Config.train.batch_size,
            scope="val")
        val_inputs = build_val_inputs()

        # Use `<` rather than `while True` + `== max_epoch` so the loop also
        # terminates when training resumes with the epoch counter already at
        # or past max_epoch.
        while Config.train.epoch < Config.train.max_epoch:
            print('*' * 40)
            print("epoch", Config.train.epoch + 1, 'start')
            print('*' * 40)

            estimator.train(input_fn=train_input_fn, hooks=[train_input_hook])
            evaluate_all(val_inputs)

            # The epoch counter lives on Config and is updated in place.
            Config.train.epoch += 1

    elif mode == 'eval':
        evaluate_all(build_val_inputs())

    else:
        raise ValueError('no %s mode' % (mode,))