def test_build_graph(self, dataset_name, regression):
   """Test whether build_graph works as expected.

   Builds a shallow model with `activation='exu'` on the first
   train/validation split of `dataset_name`, runs two training steps, and
   checks that the training metric callable returns a Python float.

   Args:
     dataset_name: Name of the dataset passed to `data_utils.load_dataset`.
     regression: Whether the task is regression; forwarded to `build_graph`
       and also used to disable stratified splitting (stratification is only
       requested for classification).
   """
   data_x, data_y, _ = data_utils.load_dataset(dataset_name)
   data_gen = data_utils.split_training_dataset(
       data_x, data_y, n_splits=5, stratified=not regression)
   # Only the first of the 5 splits is needed for this smoke test.
   (x_train, y_train), (x_validation, y_validation) = next(data_gen)
   sess = tf.InteractiveSession()
   # Close the session even if graph building, a training step, or the
   # assertion fails; otherwise the still-open InteractiveSession leaks into
   # subsequent test cases.
   try:
     graph_tensors_and_ops, metric_scores = graph_builder.build_graph(
         x_train=x_train,
         y_train=y_train,
         x_test=x_validation,
         y_test=y_validation,
         activation='exu',
         learning_rate=1e-3,
         batch_size=256,
         shallow=True,
         regression=regression,
         output_regularization=0.1,
         dropout=0.1,
         decay_rate=0.999,
         name_scope='model',
         l2_regularization=0.1)
     # Run initializer ops: global variables first, then the initializer ops
     # that build_graph exposes.
     sess.run(tf.global_variables_initializer())
     sess.run([
         graph_tensors_and_ops['iterator_initializer'],
         graph_tensors_and_ops['running_vars_initializer']
     ])
     for _ in range(2):
       sess.run(graph_tensors_and_ops['train_op'])
     self.assertIsInstance(metric_scores['train'](sess), float)
   finally:
     sess.close()
def create_test_train_fold(fold_num):
    """Splits the dataset into training and held-out test set.

    Loads `FLAGS.dataset_name`, holds out fold `fold_num` (out of `_N_FOLDS`)
    as the test set, and builds a generator of train/validation splits over
    the remaining training data.

    Args:
      fold_num: Index of the cross-validation fold to hold out as the test
        set; forwarded to `data_utils.get_train_test_fold`. Whether it is 0-
        or 1-based is defined by that helper (NOTE(review): confirm).

    Returns:
      A `(data_gen, test_dataset)` tuple. `data_gen` is the generator from
      `data_utils.split_training_dataset` with `FLAGS.num_splits` splits over
      the non-held-out data. `test_dataset` is the held-out portion returned
      by `data_utils.get_train_test_fold` (presumably an `(x, y)` pair —
      verify against that helper).
    """
    data_x, data_y, _ = data_utils.load_dataset(FLAGS.dataset_name)
    tf.logging.info('Dataset: %s, Size: %d', FLAGS.dataset_name, data_x.shape[0])
    # Log the fold actually being built (the argument), not FLAGS.fold_num:
    # a caller looping over folds would otherwise see the flag value repeated.
    tf.logging.info('Cross-val fold: %d/%d', fold_num, _N_FOLDS)
    # Stratified splits are only requested for classification targets.
    stratified = not FLAGS.regression
    # Get the training and test set based on the (Stratified)KFold split.
    (x_train_all, y_train_all), test_dataset = data_utils.get_train_test_fold(
        data_x,
        data_y,
        fold_num=fold_num,
        num_folds=_N_FOLDS,
        stratified=stratified)
    data_gen = data_utils.split_training_dataset(
        x_train_all,
        y_train_all,
        FLAGS.num_splits,
        stratified=stratified)
    return data_gen, test_dataset