Exemplo n.º 1
0
def estimator_train_and_save(estimator, model_params, save, train_dataset_fn,
                             val_dataset_fn, log_every_n_iter, train_max_steps,
                             eval_start_delay_secs, eval_throttle_secs,
                             save_checkpoints_steps, metric_names,
                             load_pretrained_model, model_meta):
    """Train an estimator model, then save it via ``estimator_save``.

    Args:
        estimator: model class/constructor passed to ``init_model``.
        model_params: keyword arguments used to build the model. This dict is
            mutated in place: ``"model_dir"`` and ``"config"`` are set here.
        save: directory used as the model dir and passed to ``estimator_save``.
        train_dataset_fn: training dataset factory, forwarded to
            ``estimator_train_compiled``.
        val_dataset_fn: validation dataset factory, forwarded likewise.
        log_every_n_iter: forwarded to ``estimator_train_compiled``.
        train_max_steps: forwarded to ``estimator_train_compiled``.
        eval_start_delay_secs: forwarded to ``estimator_train_compiled``.
        eval_throttle_secs: forwarded to ``estimator_train_compiled``.
        save_checkpoints_steps: checkpoint frequency for the estimator's
            ``tf.estimator.RunConfig``.
        metric_names: metric names; extra metrics are attached only on TF 2.x
            and only when this differs from the default ``["Accuracy"]``.
        load_pretrained_model: if true, ``load_pretrained_model_estimator`` is
            called (with ``model_dir`` already pointing at ``save``) before the
            model is built.
        model_meta: metadata forwarded to ``estimator_save``.
    """
    print("Start training using estimator model...")
    model_params["model_dir"] = save
    # save_checkpoints_steps used to be accepted but never applied; build the
    # RunConfig so the caller-supplied checkpoint frequency takes effect.
    model_params["config"] = tf.estimator.RunConfig(
        tf_random_seed=get_tf_random_seed(),
        save_checkpoints_steps=save_checkpoints_steps)

    if load_pretrained_model:
        load_pretrained_model_estimator(estimator, model_params)
    classifier = init_model(estimator, model_params)

    # Do not add the default Accuracy metric when training an estimator: it
    # fails when the estimator is a regressor, and estimators appear to add
    # some metrics automatically. Only attach additional metrics when the
    # user specified them with `WITH`.
    if tf_is_version2() and metric_names != ["Accuracy"]:
        classifier = tf.estimator.add_metrics(classifier,
                                              get_tf_metrics(metric_names))

    estimator_train_compiled(classifier, train_dataset_fn, val_dataset_fn,
                             log_every_n_iter, train_max_steps,
                             eval_start_delay_secs, eval_throttle_secs)
    estimator_save(classifier, save, model_params, model_meta)
Exemplo n.º 2
0
def estimator_train_and_save_legacy(estimator, model_params, save, FLAGS,
                                    train_dataset_fn, val_dataset_fn,
                                    train_max_steps, eval_start_delay_secs,
                                    eval_throttle_secs, save_checkpoints_steps,
                                    metric_names, load_pretrained_model,
                                    model_meta):
    """Train an estimator model (possibly distributed) and save it on chief.

    Args:
        estimator: model class/constructor passed to ``init_model``.
        model_params: keyword arguments used to build the model. This dict is
            mutated in place: ``"config"`` and ``"model_dir"`` are set here.
        save: fallback checkpoint directory when ``FLAGS.checkpointDir`` is
            empty; also passed to ``estimator_save``.
        FLAGS: parsed flags; ``worker_hosts``, ``checkpointDir`` and
            ``task_index`` are read here, and the whole object is passed to
            ``make_estimator_distributed_runconfig``.
        train_dataset_fn: training dataset factory, forwarded to
            ``estimator_train_compiled``.
        val_dataset_fn: validation dataset factory, forwarded likewise.
        train_max_steps: forwarded to ``estimator_train_compiled``.
        eval_start_delay_secs: forwarded to ``estimator_train_compiled``.
        eval_throttle_secs: forwarded to ``estimator_train_compiled``.
        save_checkpoints_steps: forwarded to
            ``make_estimator_distributed_runconfig``.
        metric_names: metric names; extra metrics are attached only on TF 2.x
            and only when this differs from the default ``["Accuracy"]``.
        load_pretrained_model: if true, ``load_pretrained_model_estimator`` is
            called before the model is built.
        model_meta: metadata forwarded to ``estimator_save``.

    Returns:
        None. Workers whose ``FLAGS.task_index`` is not 0 return right after
        training without saving.
    """
    print("Start training using estimator model...")
    # More than one comma-separated worker host means distributed training.
    is_distributed = len(FLAGS.worker_hosts.split(",")) > 1
    # Keep this config: it used to be overwritten right away by a plain
    # tf.estimator.RunConfig, which discarded whatever distribution settings
    # make_estimator_distributed_runconfig builds.
    # NOTE(review): the plain RunConfig also set
    # tf_random_seed=get_tf_random_seed(); confirm the distributed helper sets
    # a seed too if deterministic training is required.
    model_params["config"] = make_estimator_distributed_runconfig(
        FLAGS,
        estimator,
        is_distributed,
        save_checkpoints_steps=save_checkpoints_steps)
    ckpt_dir = FLAGS.checkpointDir if FLAGS.checkpointDir else save
    print("Using checkpoint path: %s" % ckpt_dir)
    model_params["model_dir"] = ckpt_dir

    if load_pretrained_model:
        load_pretrained_model_estimator(estimator, model_params)
    classifier = init_model(estimator, model_params)

    # Do not add the default Accuracy metric when training an estimator: it
    # fails when the estimator is a regressor, and estimators appear to add
    # some metrics automatically. Only attach additional metrics when the
    # user specified them with `WITH`.
    if tf_is_version2() and metric_names != ["Accuracy"]:
        classifier = tf.estimator.add_metrics(classifier,
                                              get_tf_metrics(metric_names))

    # NOTE(review): other call sites pass log_every_n_iter before
    # train_max_steps; confirm this positional call matches the current
    # signature of estimator_train_compiled.
    estimator_train_compiled(classifier, train_dataset_fn, val_dataset_fn,
                             train_max_steps, eval_start_delay_secs,
                             eval_throttle_secs)
    # Only the worker with task_index 0 exports the trained model.
    if FLAGS.task_index != 0:
        print("skip exporting model on worker != 0")
        return
    estimator_save(classifier, save, model_params, model_meta)
def init_model_with_feature_column(estimator,
                                   model_params,
                                   has_none_optimizer=False):
    """Initialize a model, wrapping Keras models that lack ``feature_columns``.

    The parameters of ``estimator.__init__`` (when ``estimator`` is a class)
    or of ``estimator`` itself (otherwise) are inspected. If none of them is
    named ``feature_columns`` and ``has_none_optimizer`` is false, the
    ``"feature_columns"`` entry is removed from ``model_params`` and the model
    is wrapped in ``WrappedKerasModel``. Otherwise the model is built with
    ``init_model(estimator, model_params)``.

    NOTE: estimator models can be initialized with this function as well,
    since estimators take a ``feature_columns`` argument.

    Args:
        estimator: a model class, or a callable that builds a model.
        model_params: keyword arguments for the model. Mutated in place: the
            ``"feature_columns"`` entry is deleted when wrapping.
        has_none_optimizer: if true, never wrap; always use ``init_model``.

    Returns:
        A ``WrappedKerasModel`` or whatever ``init_model`` returns.

    Raises:
        KeyError: if wrapping is needed but ``model_params`` has no
            ``"feature_columns"`` entry.
    """
    target = estimator.__init__ if inspect.isclass(estimator) else estimator
    # inspect.getargspec was removed in Python 3.11 and raised ValueError on
    # keyword-only parameters or annotations; getfullargspec handles both.
    spec = inspect.getfullargspec(target)
    accepted_params = set(spec.args) | set(spec.kwonlyargs)
    if "feature_columns" not in accepted_params and not has_none_optimizer:
        feature_columns = model_params.pop("feature_columns")
        return WrappedKerasModel(estimator, model_params, feature_columns)
    return init_model(estimator, model_params)