Beispiel #1
0
def estimator_train_and_save(estimator, model_params, save, train_dataset_fn,
                             val_dataset_fn, log_every_n_iter, train_max_steps,
                             eval_start_delay_secs, eval_throttle_secs,
                             save_checkpoints_steps, metric_names,
                             load_pretrained_model, model_meta):
    """Train an estimator model and save it.

    Args:
        estimator: the estimator passed to ``init_model`` (and to
            ``load_pretrained_model_estimator`` when loading a pretrained
            model).
        model_params: dict handed to ``init_model``. It is mutated in place:
            ``model_dir`` is set to ``save``, and ``config`` is set to a
            ``tf.estimator.RunConfig`` if the caller did not provide one.
        save: model directory; also the target passed to ``estimator_save``.
        train_dataset_fn: training input function, forwarded to
            ``estimator_train_compiled``.
        val_dataset_fn: validation input function, forwarded to
            ``estimator_train_compiled``.
        log_every_n_iter: forwarded to ``estimator_train_compiled``.
        train_max_steps: forwarded to ``estimator_train_compiled``.
        eval_start_delay_secs: forwarded to ``estimator_train_compiled``.
        eval_throttle_secs: forwarded to ``estimator_train_compiled``.
        save_checkpoints_steps: checkpoint interval used for the default
            ``RunConfig``.
        metric_names: list of metric names. Extra metrics are attached only
            on TF 2.x and only when this is not exactly ``["Accuracy"]``.
        load_pretrained_model: when truthy (and ``save`` is non-empty),
            ``load_pretrained_model_estimator`` is called before building
            the model.
        model_meta: forwarded to ``estimator_save``.
    """
    print("Start training using estimator model...")
    model_params["model_dir"] = save
    # Build a RunConfig so that save_checkpoints_steps actually takes effect.
    # A config already supplied by the caller is left untouched.
    if "config" not in model_params:
        model_params["config"] = tf.estimator.RunConfig(
            tf_random_seed=get_tf_random_seed(),
            save_checkpoints_steps=save_checkpoints_steps)

    # Requires both the flag and a non-empty `save` path.
    if load_pretrained_model and save:
        load_pretrained_model_estimator(estimator, model_params)
    classifier = init_model(estimator, model_params)

    # do not add default Accuracy metric when using estimator to train, it will
    # fail when the estimator is a regressor, and estimator seems automatically
    # add some metrics. Only add additional metrics when user specified with
    # `WITH`.
    if tf_is_version2() and metric_names != ["Accuracy"]:
        classifier = tf.estimator.add_metrics(classifier,
                                              get_tf_metrics(metric_names))

    estimator_train_compiled(classifier, train_dataset_fn, val_dataset_fn,
                             log_every_n_iter, train_max_steps,
                             eval_start_delay_secs, eval_throttle_secs)
    estimator_save(classifier, save, model_params, model_meta)
Beispiel #2
0
def estimator_train_and_save_legacy(estimator, model_params, save, FLAGS,
                                    train_dataset_fn, val_dataset_fn,
                                    train_max_steps, eval_start_delay_secs,
                                    eval_throttle_secs, save_checkpoints_steps,
                                    metric_names, load_pretrained_model,
                                    model_meta):
    """Train an estimator model (optionally distributed) and save it.

    Args:
        estimator: the estimator passed to ``init_model`` and
            ``make_estimator_distributed_runconfig``.
        model_params: dict handed to ``init_model``. Its ``config`` and
            ``model_dir`` entries are set in place.
        save: export target for ``estimator_save``; it is also the checkpoint
            dir when ``FLAGS.checkpointDir`` is empty.
        FLAGS: job flags. This function reads ``worker_hosts``
            (comma-separated), ``checkpointDir`` and ``task_index``.
        train_dataset_fn: training input function, forwarded to
            ``estimator_train_compiled``.
        val_dataset_fn: validation input function, forwarded to
            ``estimator_train_compiled``.
        train_max_steps: forwarded to ``estimator_train_compiled``.
        eval_start_delay_secs: forwarded to ``estimator_train_compiled``.
        eval_throttle_secs: forwarded to ``estimator_train_compiled``.
        save_checkpoints_steps: checkpoint interval passed to the run-config
            helper.
        metric_names: list of metric names. Extra metrics are attached only
            on TF 2.x and only when this is not exactly ``["Accuracy"]``.
        load_pretrained_model: when truthy (and ``save`` is non-empty),
            ``load_pretrained_model_estimator`` is called before building
            the model.
        model_meta: forwarded to ``estimator_save``.

    Only the worker with ``FLAGS.task_index == 0`` exports the model.
    """
    print("Start training using estimator model...")
    # More than one comma-separated worker host means a distributed job.
    is_distributed = len(FLAGS.worker_hosts.split(",")) > 1
    run_config = make_estimator_distributed_runconfig(
        FLAGS,
        estimator,
        is_distributed,
        save_checkpoints_steps=save_checkpoints_steps)
    ckpt_dir = FLAGS.checkpointDir or save
    print("Using checkpoint path: %s" % ckpt_dir)
    model_params["model_dir"] = ckpt_dir
    # Keep the config built by the distributed helper. Replacing it with a
    # plain RunConfig would drop any distributed settings, so only inject
    # the random seed.
    # NOTE(review): assumes the helper returns a tf.estimator.RunConfig
    # (which provides `replace`). Confirm against its definition.
    model_params["config"] = run_config.replace(
        tf_random_seed=get_tf_random_seed())

    # Requires both the flag and a non-empty `save` path.
    if load_pretrained_model and save:
        load_pretrained_model_estimator(estimator, model_params)
    classifier = init_model(estimator, model_params)

    # do not add default Accuracy metric when using estimator to train, it will
    # fail when the estimator is a regressor, and estimator seems automatically
    # add some metrics. Only add additional metrics when user specified with
    # `WITH`.
    if tf_is_version2() and metric_names != ["Accuracy"]:
        classifier = tf.estimator.add_metrics(classifier,
                                              get_tf_metrics(metric_names))

    estimator_train_compiled(classifier, train_dataset_fn, val_dataset_fn,
                             train_max_steps, eval_start_delay_secs,
                             eval_throttle_secs)
    # Only the chief worker (task 0) exports the trained model.
    if FLAGS.task_index != 0:
        print("skip exporting model on worker != 0")
        return
    estimator_save(classifier, save, model_params, model_meta)