def run(method, output_dir, num_epochs, fake_data=False, fake_training=False):
    """Builds, trains and checkpoints a model for the given method.

    The output directory is created first. The model options derived from
    the tuned hyperparameters are written to `model_options.json` inside it,
    and after training the model weights are written to `model.ckpt` in the
    same directory.

    Args:
      method: Modeling method to experiment with.
      output_dir: Directory receiving the recorded options and the weights.
      num_epochs: Number of training epochs.
      fake_data: If true, the train/valid data configs request fake data.
      fake_training: If true, train for a trivial number of steps.

    Returns:
      Trained Keras model.
    """
    # Both hyperparameter helpers below are called with this parameterization.
    parameterization = 'C'

    tf.io.gfile.makedirs(output_dir)

    train_config = data_lib.DataConfig(split='train', fake_data=fake_data)
    valid_config = data_lib.DataConfig(split='valid', fake_data=fake_data)

    tuned_hparams = hparams_lib.get_tuned_hparams(
        method, parameterization=parameterization)
    options = hparams_lib.model_opts_from_hparams(
        tuned_hparams,
        method,
        parameterization=parameterization,
        fake_training=fake_training)

    # Record the options before training starts.
    options_path = output_dir + '/model_options.json'
    experiment_utils.record_config(options, options_path)

    trained_model = models_lib.build_and_train_model(
        options,
        train_config,
        valid_config,
        output_dir=output_dir,
        num_epochs=num_epochs,
        fake_training=fake_training)

    logging.info('Saving model to output_dir.')
    checkpoint_path = output_dir + '/model.ckpt'
    trained_model.save_weights(checkpoint_path)
    # TODO(yovadia): Looks like Keras save_model does not work with Python3.
    # (e.g. see b/129323565).
    # experiment_utils.save_model(model, output_dir)
    return trained_model
def test_build_dataset(self):
    """Checks the structure of the dataset built from fake training data.

    Verifies that the dataset yields a (features, label) pair where:
      * the label and every feature have shape `[None]`;
      * the feature keys are exactly `data_lib.feature_name(i)` for
        `i` in `1..data_lib.NUM_TOTAL_FEATURES`;
      * the label and the features at `data_lib.INT_FEATURE_INDICES` are
        `tf.float32`, and the features at `data_lib.CAT_FEATURE_INDICES`
        are `tf.string`.
    """
    config = data_lib.DataConfig(split='train', fake_data=True)
    dataset = data_lib.build_dataset(
        config, batch_size=8, is_training=False, fake_training=False)

    # Check output_shapes.
    # Use the compat accessor (the same family used for the types below)
    # instead of the deprecated `Dataset.output_shapes` property, which is
    # not available on TF2 dataset objects.
    features_shapes, label_shape = tf.compat.v1.data.get_output_shapes(
        dataset)
    self.assertEqual([None], label_shape.as_list())
    expected_keys = [
        data_lib.feature_name(i)
        for i in range(1, data_lib.NUM_TOTAL_FEATURES + 1)
    ]
    self.assertSameElements(expected_keys, list(features_shapes.keys()))
    for key, shape in six.iteritems(features_shapes):
        self.assertEqual([None], shape.as_list(),
                         'Unexpected shape at key=' + key)

    # Check output_types.
    features_types, label_type = tf.compat.v1.data.get_output_types(
        dataset)
    self.assertEqual(tf.float32, label_type)
    # Integer-indexed features are expected as float32 in the dataset.
    for idx in data_lib.INT_FEATURE_INDICES:
        self.assertEqual(tf.float32,
                         features_types[data_lib.feature_name(idx)])
    for idx in data_lib.CAT_FEATURE_INDICES:
        self.assertEqual(tf.string,
                         features_types[data_lib.feature_name(idx)])