示例#1
0
def run_census(flags_obj):
    """Construct all necessary functions and call run_loop.

    Optionally downloads the census data, builds the training and evaluation
    input functions over the files found in ``flags_obj.data_dir``, and passes
    them to ``wide_deep_run_loop.run_loop``.

    Args:
        flags_obj: Object containing user specified flags. This function reads
            ``download_if_missing``, ``data_dir``, ``epochs_between_evals``
            and ``batch_size`` from it; the object itself is also forwarded
            to ``run_loop``.

    Returns:
        None. The function is run for the side effects of ``run_loop``.
    """
    if flags_obj.download_if_missing:
        census_dataset.download(flags_obj.data_dir)

    train_file = os.path.join(flags_obj.data_dir, census_dataset.TRAINING_FILE)
    test_file = os.path.join(flags_obj.data_dir, census_dataset.EVAL_FILE)

    # Train and evaluate the model every `flags.epochs_between_evals` epochs.
    def train_input_fn():
        # NOTE(review): the positional `True` is presumably the shuffle flag
        # of census_dataset.input_fn -- confirm against its signature.
        return census_dataset.input_fn(train_file,
                                       flags_obj.epochs_between_evals, True,
                                       flags_obj.batch_size)

    def eval_input_fn():
        # Evaluation reads the test file with the literal arguments 1 and
        # False (presumably one epoch, no shuffling -- confirm).
        return census_dataset.input_fn(test_file, 1, False,
                                       flags_obj.batch_size)

    # The '{loss_prefix}' placeholder is left unformatted here; presumably
    # run_loop fills it in before the tensor names are used -- confirm.
    tensors_to_log = {
        'average_loss': '{loss_prefix}head/truediv',
        'loss': '{loss_prefix}head/weighted_loss/Sum'
    }

    wide_deep_run_loop.run_loop(
        name="Census Income",
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        model_column_fn=census_dataset.build_model_columns,
        build_estimator_fn=build_estimator,
        flags_obj=flags_obj,
        tensors_to_log=tensors_to_log,
        early_stop=True)


def easy_input_function(df, label_key, num_epochs, shuffle, batch_size):
    """Turn ``df`` into a batched ``tf.data.Dataset`` of (features, label) pairs.

    Args:
        df: Column-indexable table that can be passed to ``dict()``
            (presumably a pandas DataFrame -- confirm against callers).
        label_key: Key of the column in ``df`` used as the label.
        num_epochs: Value passed to ``Dataset.repeat``.
        shuffle: When truthy, shuffle the elements with a buffer of 10000.
        batch_size: Value passed to ``Dataset.batch``.

    Returns:
        The dataset after optional shuffling, batching and then repeating.
    """
    targets = df[label_key]
    # dict(df) covers every column of df, so the label column is also
    # present among the features.
    features = dict(df)
    dataset = tf.data.Dataset.from_tensor_slices((features, targets))

    if shuffle:
        dataset = dataset.shuffle(10000)

    # Batch first, then repeat.
    batched = dataset.batch(batch_size)
    return batched.repeat(num_epochs)


# Fetch the census data into the fixed local directory used below.
census_dataset.download("/tmp/census_data/")

# Append models_path to an existing PYTHONPATH, or create the variable with
# models_path as its only entry when it is not set yet.
os.environ['PYTHONPATH'] = (
    os.environ['PYTHONPATH'] + os.pathsep + models_path
    if "PYTHONPATH" in os.environ
    else models_path
)

train_file, test_file = (
    "/tmp/census_data/adult.data",
    "/tmp/census_data/adult.test",
)

# Load the training file first, then the test file. The files are read with
# header=None, so column names come from census_dataset._CSV_COLUMNS.
train_df, test_df = (
    pandas.read_csv(csv_path, header=None, names=census_dataset._CSV_COLUMNS)
    for csv_path in (train_file, test_file)
)