Ejemplo n.º 1
0
def run(method, output_dir, num_epochs, fake_data=False, fake_training=False):
  """Builds and trains a model for `method`, then checkpoints its weights.

  The output directory is created up front. The resolved model options are
  recorded (via `experiment_utils.record_config`) at
  `output_dir + '/model_options.json'`, and the trained weights are saved at
  `output_dir + '/model.ckpt'`.

  Args:
    method: Modeling method to experiment with.
    output_dir: Directory that receives the recorded model options, anything
      the training routine writes there, and the saved weights.
    num_epochs: Number of training epochs.
    fake_data: If true, both the 'train' and 'valid' data configs use fake
      data.
    fake_training: If true, train for a trivial number of steps.

  Returns:
    The trained Keras model.
  """
  tf.io.gfile.makedirs(output_dir)

  # One config per split; both share the same fake_data setting. The 'train'
  # config is constructed first, then 'valid'.
  train_config, valid_config = [
      data_lib.DataConfig(split=split_name, fake_data=fake_data)
      for split_name in ('train', 'valid')
  ]

  param_style = 'C'
  tuned_hparams = hparams_lib.get_tuned_hparams(
      method, parameterization=param_style)
  opts = hparams_lib.model_opts_from_hparams(
      tuned_hparams,
      method,
      parameterization=param_style,
      fake_training=fake_training)

  options_path = output_dir + '/model_options.json'
  experiment_utils.record_config(opts, options_path)

  trained = models_lib.build_and_train_model(
      opts,
      train_config,
      valid_config,
      output_dir=output_dir,
      num_epochs=num_epochs,
      fake_training=fake_training,
  )

  logging.info('Saving model to output_dir.')
  # TODO(yovadia): Keras save_model appeared not to work under Python 3
  # (see b/129323565), so only the weights are checkpointed here rather than
  # calling experiment_utils.save_model(model, output_dir).
  weights_path = output_dir + '/model.ckpt'
  trained.save_weights(weights_path)
  return trained
Ejemplo n.º 2
0
    def test_build_dataset(self):
        """Checks the element structure of a fake, non-training dataset.

        Builds the 'train' split from fake data with batch_size=8 and
        verifies that:
          * the label has shape [None] and dtype tf.float32;
          * the feature dict has exactly one key per feature index in
            1..data_lib.NUM_TOTAL_FEATURES, each with shape [None];
          * features in data_lib.INT_FEATURE_INDICES are tf.float32 and
            features in data_lib.CAT_FEATURE_INDICES are tf.string.
        """
        config = data_lib.DataConfig(split='train', fake_data=True)
        dataset = data_lib.build_dataset(config,
                                         batch_size=8,
                                         is_training=False,
                                         fake_training=False)

        # Check output_shapes. Use the tf.compat.v1 accessor (matching the
        # output_types check below) instead of the `output_shapes` attribute,
        # which only exists on TF1-style (DatasetV1) datasets and raises
        # AttributeError on a TF2 tf.data.Dataset.
        features_shapes, label_shape = tf.compat.v1.data.get_output_shapes(
            dataset)
        self.assertEqual([None], label_shape.as_list())
        expected_keys = [
            data_lib.feature_name(i)
            for i in range(1, data_lib.NUM_TOTAL_FEATURES + 1)
        ]
        self.assertSameElements(expected_keys, list(features_shapes))
        for key, shape in features_shapes.items():
            self.assertEqual([None], shape.as_list(),
                             'Unexpected shape at key=' + key)

        # Check output_types.
        features_types, label_type = tf.compat.v1.data.get_output_types(
            dataset)
        self.assertEqual(tf.float32, label_type)
        for idx in data_lib.INT_FEATURE_INDICES:
            self.assertEqual(tf.float32,
                             features_types[data_lib.feature_name(idx)])
        for idx in data_lib.CAT_FEATURE_INDICES:
            self.assertEqual(tf.string,
                             features_types[data_lib.feature_name(idx)])