Example #1
def main(_):
    model_class = models.get_model_class(FLAGS.model)

    # Look up the model configuration.
    assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
        "Exactly one of --config_name or --config_json is required.")
    config = (models.get_model_config(FLAGS.model, FLAGS.config_name) if
              FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

    config = configdict.ConfigDict(config)
    config_util.log_and_save_config(config, FLAGS.model_dir)

    # Create the estimator.
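    # Keep only the most recent checkpoint on disk.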
    run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
    estimator = estimator_util.create_estimator(model_class, config.hparams,
                                                run_config, FLAGS.model_dir)

    # Create an input function that reads the training dataset. We iterate through
    # the dataset one epoch at a time if we are alternating with evaluation;
    # otherwise we iterate indefinitely.
    train_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.train_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.TRAIN,
        shuffle_values_buffer=FLAGS.shuffle_buffer_size,
        repeat=1 if FLAGS.eval_files else None)

    if not FLAGS.eval_files:
        estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
    else:
        eval_input_fn = estimator_util.create_input_fn(
            file_pattern=FLAGS.eval_files,
            input_config=config.inputs,
            mode=tf.estimator.ModeKeys.EVAL)
        eval_args = {
            "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
        }

        for _ in estimator_runner.continuous_train_and_eval(
                estimator=estimator,
                train_input_fn=train_input_fn,
                eval_args=eval_args,
                train_steps=FLAGS.train_steps):
            # continuous_train_and_eval() yields evaluation metrics after each
            # training epoch. We don't do anything here.
            pass
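Example #1 references absl-style FLAGS defined elsewhere in the module. A minimal sketch of definitions consistent with the names used above (the flag names come from the code; defaults and help strings are assumptions) might be:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Flag names inferred from the FLAGS.* references in main(); defaults are illustrative only.
flags.DEFINE_string("model", None, "Name of the model class to train.")
flags.DEFINE_string("config_name", None, "Name of a predefined model configuration.")
flags.DEFINE_string("config_json", None, "JSON string or file containing the model configuration.")
flags.DEFINE_string("model_dir", None, "Directory for checkpoints and the saved config.")
flags.DEFINE_string("train_files", None, "File pattern of the training data files.")
flags.DEFINE_string("eval_files", None, "Optional file pattern of the evaluation data files.")
flags.DEFINE_integer("train_steps", 10000, "Total number of training steps.")
flags.DEFINE_integer("shuffle_buffer_size", 1000, "Size of the shuffle buffer for training examples.")

if __name__ == "__main__":
    app.run(main)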
Example #2
def train(model, config):
    if FLAGS.model_dir:
        dir_name = "{}/{}_{}_{}".format(
            FLAGS.model_dir, FLAGS.model, FLAGS.config_name,
            datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
        config_util.log_and_save_config(config, dir_name)

    # Build the training dataset; repeat=None repeats it indefinitely.
    ds = input_ds.build_dataset(
        file_pattern=FLAGS.train_files,
        input_config=config.inputs,
        batch_size=config.hparams.batch_size,
        include_labels=True,
        shuffle_filenames=True,
        shuffle_values_buffer=FLAGS.shuffle_buffer_size,
        repeat=None)

    if FLAGS.eval_files:
        # Build the evaluation dataset; repeat=1 makes each evaluation a single pass.
        eval_ds = input_ds.build_dataset(file_pattern=FLAGS.eval_files,
                                         input_config=config.inputs,
                                         batch_size=config.hparams.batch_size,
                                         include_labels=True,
                                         shuffle_filenames=False,
                                         repeat=1)
    else:
        eval_ds = None

    assert config.hparams.optimizer == 'adam'
    lr = config.hparams.learning_rate
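    # The hparams store 1 - beta for each Adam decay rate; recover beta_1 and beta_2 here.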
    beta_1 = 1.0 - config.hparams.one_minus_adam_beta_1
    beta_2 = 1.0 - config.hparams.one_minus_adam_beta_2
    epsilon = config.hparams.adam_epsilon
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr,
                                         beta_1=beta_1,
                                         beta_2=beta_2,
                                         epsilon=epsilon)

    loss = tf.keras.losses.BinaryCrossentropy()

    # Track precision and recall for the primary class at the prediction threshold.
    metrics = [
        tf.keras.metrics.Recall(
            name='r',
            class_id=config.inputs.primary_class,
            thresholds=config.hparams.prediction_threshold,
        ),
        tf.keras.metrics.Precision(
            name='p',
            class_id=config.inputs.primary_class,
            thresholds=config.hparams.prediction_threshold,
        ),
    ]

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    if getattr(config.hparams, 'decreasing_lr', False):

        def scheduler(epoch, lr):
            # Keep the initial rate for the first two epochs, then divide the
            # current rate by 10 at the start of each later epoch.
            if epoch > 1:
                return lr / 10
            else:
                return lr

        callbacks = [tf.keras.callbacks.LearningRateScheduler(scheduler)]
    else:
        callbacks = []

    # Pass the callbacks so the optional learning-rate schedule actually takes effect.
    history = model.fit(ds,
                        epochs=FLAGS.train_epochs,
                        steps_per_epoch=FLAGS.train_steps,
                        validation_data=eval_ds,
                        callbacks=callbacks)

    if FLAGS.model_dir:
        model.save(dir_name)

    return history
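For context, a possible driver around Example #2 is sketched below. It reuses the models/configdict helpers and flag names from Example #1; how the Keras model instance is constructed from the model class is an assumption here, not part of the original code.

def main(_):
    # Resolve the model class and its named configuration, as in Example #1.
    model_class = models.get_model_class(FLAGS.model)
    config = configdict.ConfigDict(
        models.get_model_config(FLAGS.model, FLAGS.config_name))

    # Assumption: the model class builds a tf.keras.Model directly from the config.
    model = model_class(config)

    # train() compiles and fits the model, returning the Keras History object.
    history = train(model, config)
    print(history.history)


if __name__ == "__main__":
    app.run(main)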