Example #1
def main():
    tf.logging.set_verbosity(FLAGS.verbosity)

    train_spec = tf.estimator.TrainSpec(
        input_fn=data.get_input_fn(
            FLAGS.train, shuffle=True, batch_size=FLAGS.batch_size
        ),
        max_steps=FLAGS.max_steps,
    )
    exporter = tf.estimator.FinalExporter(
        "estimator",
        data.json_serving_input_fn,
        as_text=False,  # change to True to export the model as readable text
    )
    eval_spec = tf.estimator.EvalSpec(
        input_fn=data.get_input_fn(FLAGS.eval),
        steps=None,
        throttle_secs=FLAGS.throttle_secs,
        exporters=[exporter],
    )

    estimator = model.create_estimator()
    tf.logging.log(
        tf.logging.INFO,
        "About to start training and evaulating. To see results type\ntensorboard --logdir={}\nin another window.".format(
            FLAGS.job_dir
        ),
    )
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    tf.logging.log(
        tf.logging.INFO,
        "Finished training and evaulating. To see results type\ntensorboard --logdir={}\nin another window.".format(
            FLAGS.job_dir
        ),
    )
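
The FinalExporter above needs data.json_serving_input_fn, a serving input receiver that defines the tensors the exported SavedModel will accept. A minimal sketch of such a function, assuming a single flat float feature named "x" (the feature name and shape are hypothetical, not taken from the project):

import tensorflow as tf

def json_serving_input_fn():
    """Hypothetical serving input_fn: accepts batches of flat float vectors."""
    inputs = {"x": tf.placeholder(tf.float32, shape=[None, 784], name="x")}
    # Features and receiver tensors coincide here: raw tensors in, raw tensors to the model.
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
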
Example #2
def predict(override_cfg, model_dir):
  """Run model over a dataset and dump predictions to json file."""
  assert FLAGS.predict_path
  cfg = _load_config(model_dir)
  cfg = utils.merge(cfg, override_cfg)
  input_fn = data.get_input_fn(
      split=cfg.dataset.eval_split,
      max_length=None,
      repeat=False,
      shuffle=False,
      cache=False,
      limit=None,
      data_path=cfg.dataset.data_path,
      vocab_path=cfg.dataset.vocab_path,
      is_tpu=False,
      use_generator=True,
      is_training=False)
  estimator = model.get_estimator(**cfg)
  predictions = dict()
  for i, prediction in enumerate(estimator.predict(input_fn)):
    predictions[prediction["id"]] = prediction["answer"]
    if i % 100 == 0:
      tf.logging.info("Prediction %s | %s: %s" % (i, prediction["id"],
                                                  prediction["answer"]))

  # Dump results to a file
  with tf.gfile.GFile(FLAGS.predict_path, "w") as f:
    json.dump(predictions, f)
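
estimator.predict yields NumPy values, so depending on the model_fn (not shown here) prediction["id"] and prediction["answer"] may be NumPy scalars or byte strings, which json.dump rejects. A small defensive helper, purely as a sketch of one way to guard the dump above:

def _to_jsonable(value):
    """Best-effort conversion of NumPy scalars and byte strings for json.dump."""
    if isinstance(value, bytes):          # covers np.bytes_ as well
        return value.decode("utf-8")
    if hasattr(value, "item"):            # NumPy scalars expose .item()
        return value.item()
    return value

# e.g. predictions[_to_jsonable(prediction["id"])] = _to_jsonable(prediction["answer"])
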
Example #3
def make_experiment(data_dir,
                    batch_size,
                    model_dir,
                    learning_rate=1e-4,
                    save_checkpoints_secs=60,
                    repeat_and_shuffle_training_data=False):
    """
  Creates the experiment for training MNIST.

  Args:
    data_dir: Directory to save the MNIST data.
    batch_size: Size of the batches used for training and evaluation.
    model_dir: Directory to save the model data.
    learning_rate: Learning rate of the Adam optimizer.
    save_checkpoints_secs: How often to save the model parameters.
    repeat_and_shuffle_training_data: Make the data generation infinite and
        shuffle the data. Use for training only.

  Returns:
    The created tf.contrib.learn.Experiment.
  """
    config = tf.contrib.learn.RunConfig().replace(
        save_checkpoints_secs=save_checkpoints_secs, keep_checkpoint_max=None)
    params = {LEARNING_RATE_KEY: learning_rate}
    mnist_data = mnist.input_data.read_data_sets(data_dir, one_hot=True)
    estimator = tf.estimator.Estimator(_model_fn, model_dir, config, params)
    train_input_fn = data.get_input_fn(
        mnist_data.train.images,
        mnist_data.train.labels,
        batch_size,
        repeat=repeat_and_shuffle_training_data,
        shuffle=repeat_and_shuffle_training_data)
    eval_input_fn = data.get_input_fn(mnist_data.validation.images,
                                      mnist_data.validation.labels,
                                      batch_size,
                                      shuffle=False)
    return tf.contrib.learn.Experiment(
        estimator,
        train_input_fn,
        eval_input_fn,
        eval_steps=None,
    )
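
tf.contrib.learn.Experiment is the pre-train_and_evaluate API and is deprecated in later TF 1.x releases; the returned object is typically driven like this (the paths and batch size below are placeholders, not values from the project):

experiment = make_experiment(
    data_dir="/tmp/mnist_data",    # hypothetical paths
    batch_size=64,
    model_dir="/tmp/mnist_model",
    repeat_and_shuffle_training_data=True)
experiment.train_and_evaluate()    # alternates training with periodic evaluation
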
Example #4
def get_input_fns(params):
    """Gets the train, eval, and eval on train data input fns."""
    dataset_dir = params.dataset_dir
    tf.logging.info('Using dataset_dir %r', dataset_dir)

    with tf.name_scope('data'):
        dataset_train_fn = functools.partial(data.get_holparam_dataset, TRAIN)
        dataset_eval_fn = functools.partial(data.get_holparam_dataset, EVAL)

        train_input_fn = data.get_input_fn(dataset_train_fn,
                                           TRAIN,
                                           params,
                                           shuffle_queue=params.shuffle_queue,
                                           parser=params.train_parser,
                                           filt=params.train_filter)
        eval_input_fn = data.get_input_fn(dataset_eval_fn,
                                          EVAL,
                                          params,
                                          parser=params.eval_parser,
                                          filt=params.eval_filter)

        return train_input_fn, eval_input_fn
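
functools.partial here pins the split argument so that data.get_input_fn receives a one-argument dataset factory it can call lazily. A small self-contained illustration of the pattern (the function and names below are generic stand-ins, not the project's code):

import functools

def build_dataset(split, params):
    # Stand-in for data.get_holparam_dataset; a real version would return a tf.data.Dataset.
    return "dataset(split=%s, batch_size=%d)" % (split, params["batch_size"])

# Pin `split`, leaving a factory that only needs `params` when the Estimator asks for data.
dataset_train_fn = functools.partial(build_dataset, "train")
print(dataset_train_fn({"batch_size": 32}))  # -> dataset(split=train, batch_size=32)
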
Example #5
def train(hparams):
    ##### Create input functions
    if FLAGS.do_train:
        train_input_fn = data.get_input_fn(
            data_dir=FLAGS.data_dir,
            split="train",
            task_name=FLAGS.task_name,
            sup_size=FLAGS.sup_size,
            unsup_ratio=FLAGS.unsup_ratio,
            aug_copy=FLAGS.aug_copy,
        )

    if FLAGS.do_eval:
        eval_input_fn = data.get_input_fn(data_dir=FLAGS.data_dir,
                                          split="test",
                                          task_name=FLAGS.task_name,
                                          sup_size=-1,
                                          unsup_ratio=0,
                                          aug_copy=0)
        if FLAGS.task_name == "cifar10":
            eval_size = 10000
        elif FLAGS.task_name == "svhn":
            eval_size = 26032
        else:
            raise ValueError, "You need to specify the size of your test set."
        eval_steps = eval_size // FLAGS.eval_batch_size

    ##### Get model function
    model_fn = get_model_fn(hparams)
    estimator = utils.get_TPU_estimator(FLAGS, model_fn)

    #### Training
    if FLAGS.do_eval_along_training:
        tf.logging.info("***** Running training & evaluation *****")
        tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Unsupervised batch size = %d",
                        FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.logging.info("  Num train steps = %d", FLAGS.train_steps)

        while True:
            if FLAGS.curr_step >= FLAGS.train_steps:
                break
            tf.logging.info("Current step {}".format(FLAGS.curr_step))
            train_step = min(FLAGS.save_steps,
                             FLAGS.train_steps - FLAGS.curr_step)
            estimator.train(input_fn=train_input_fn, steps=train_step)
            estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
            FLAGS.curr_step += FLAGS.save_steps
    else:
        if FLAGS.do_train:
            tf.logging.info("***** Running training *****")
            tf.logging.info("  Supervised batch size = %d",
                            FLAGS.train_batch_size)
            tf.logging.info("  Unsupervised batch size = %d",
                            FLAGS.train_batch_size * FLAGS.unsup_ratio)
            estimator.train(input_fn=train_input_fn,
                            max_steps=FLAGS.train_steps)
        if FLAGS.do_eval:
            tf.logging.info("***** Running evaluation *****")
            results = estimator.evaluate(input_fn=eval_input_fn,
                                         steps=eval_steps)
            tf.logging.info(">> Results:")
            for key in results.keys():
                tf.logging.info("  %s = %s", key, str(results[key]))
                results[key] = results[key].item()
            acc = results["eval/classify_accuracy"]
            with tf.gfile.Open("{}/results.txt".format(FLAGS.model_dir),
                               "w") as ouf:
                ouf.write(str(acc))
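
The while loop above alternates estimator.train and estimator.evaluate every FLAGS.save_steps steps by hand. A roughly equivalent sketch using the standard loop, assuming the estimator's RunConfig writes a checkpoint every FLAGS.save_steps steps (an alternative pattern, not the original code):

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                    max_steps=FLAGS.train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                  steps=eval_steps,
                                  throttle_secs=0)  # evaluate at every new checkpoint
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
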
Example #6
def train():
    """
    Training and evaluation.
    """
    with tf.io.gfile.GFile(os.path.join(FLAGS.data_dir, 'data_info.json'),
                           'r') as fp:
        data_info = json.load(fp)
    if FLAGS.unsup_ratio == 0:
        FLAGS.unsup_cut = 0.0

    # Create input functions
    train_input_fn = data.get_input_fn(data_dir=FLAGS.data_dir,
                                       split="train",
                                       data_size=data_info['train']['size'],
                                       batch_size=FLAGS.train_batch_size,
                                       sup_cut=FLAGS.sup_cut,
                                       unsup_cut=FLAGS.unsup_cut,
                                       unsup_ratio=FLAGS.unsup_ratio,
                                       shuffle_seed=FLAGS.shuffle_seed)

    eval_input_fn = data.get_input_fn(data_dir=FLAGS.data_dir,
                                      split="val",
                                      data_size=data_info['val']['size'],
                                      batch_size=FLAGS.eval_batch_size,
                                      sup_cut=1.0,
                                      unsup_cut=0.0,
                                      unsup_ratio=0)

    eval_size = data_info['val']['size']
    eval_steps = eval_size // FLAGS.eval_batch_size

    epoch_steps = int(
        (data_info['train']['size'] * FLAGS.sup_cut) / FLAGS.train_batch_size)

    # Get model function
    model_fn = get_model_fn()
    estimator = utils.get_estimator(FLAGS, model_fn, epoch_steps)

    # Training
    if FLAGS.do_eval_along_training:

        tf.compat.v1.logging.info("***** Running training & evaluation *****")
        tf.compat.v1.logging.info("  Supervised batch size = %d",
                                  FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Unsupervised batch size = %d",
                                  FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.compat.v1.logging.info("  Num train steps = %d", FLAGS.train_steps)

        # 7 is arbitrary and will be eliminated
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            {'image': tf.zeros([7] + [224, 224, 4], tf.float32)})

        best_exporter = tf.estimator.BestExporter(
            name="best_exporter",
            serving_input_receiver_fn=serving_input_receiver_fn,
            exports_to_keep=1)

        best_ckpt = best_ckpt_copier.BestCheckpointCopier(
            name="best_checkpoint", checkpoints_to_keep=1, score_metric='loss')

        latest_exporter = tf.estimator.LatestExporter(
            name="latest_exporter",
            serving_input_receiver_fn=serving_input_receiver_fn,
            exports_to_keep=None)

        hooks = []

        if FLAGS.early_stop_steps != -1:
            # Hook to stop training if the loss has not decreased within FLAGS.early_stop_steps steps.
            hooks.append(
                tf.estimator.experimental.stop_if_no_decrease_hook(
                    estimator,
                    "loss",
                    FLAGS.early_stop_steps,
                    min_steps=FLAGS.min_step))

        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=FLAGS.train_steps,
                                            hooks=hooks)
        eval_spec = tf.estimator.EvalSpec(
            input_fn=eval_input_fn,
            steps=eval_steps,
            exporters=[best_exporter, latest_exporter, best_ckpt],
            start_delay_secs=0,
            throttle_secs=10)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    if FLAGS.do_predict:
        tf.compat.v1.logging.info('***** Running prediction *****')

        with open('data/processed_data/aug_seed.json', 'r') as fp:
            aug_seeds_dict = json.load(fp)
        aug_seeds_list = list(
            itertools.chain.from_iterable(
                list(aug_seeds_dict[FLAGS.pred_dataset].values())))

        if FLAGS.pred_ckpt == 'best':
            out_path = os.path.join(
                FLAGS.model_dir,
                'best_{}_prediction'.format(FLAGS.pred_dataset))
            export_dir = sorted(
                glob.glob(
                    os.path.join(FLAGS.model_dir,
                                 'export/best_exporter/*')))[-1]
        else:
            out_path = os.path.join(
                FLAGS.model_dir,
                '{}_{}_prediction'.format(FLAGS.pred_ckpt, FLAGS.pred_dataset))
            export_dir = os.path.join(
                FLAGS.model_dir,
                'export/latest_exporter/{}'.format(FLAGS.pred_ckpt))

        if not os.path.exists(out_path):
            os.makedirs(out_path)

        model = tf.compat.v2.saved_model.load(export_dir)
        predict = model.signatures["serving_default"]

        tf.compat.v1.logging.info('Predicting images')

        test_input_fn = data.get_input_fn(
            data_dir=FLAGS.data_dir,
            split=FLAGS.pred_dataset,
            data_size=data_info[FLAGS.pred_dataset]['size'],
            batch_size=FLAGS.pred_batch_size,
            sup_cut=1.0,
            unsup_cut=0.0,
            unsup_ratio=0)

        dataset = test_input_fn()
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

        preds = []
        aug_preds = []

        example_cnt = 1
        patient_cnt = 0
        img_cnt = 0

        def pad_and_export(patient_pred: np.ndarray, mode: str):
            """
            Pad and export predicted segmentation maps as .nii.gz-files.
            :param patient_pred: Predicted segmentation maps
            :param mode: Standard prediction, augmented prediction
             or predicted augmentation
            """
            # Make the prediction dims (155,240,240) again for BRATS evaluation
            patient_pred = np.pad(patient_pred, ((0, 0), (8, 8), (8, 8)))

            below_padding = np.zeros(
                (data_info[FLAGS.pred_dataset]['crop_idx'][patient_cnt][0],
                 240, 240))
            above_padding = np.zeros(
                (155 -
                 data_info[FLAGS.pred_dataset]['crop_idx'][patient_cnt][1],
                 240, 240))
            patient_pred = np.concatenate(
                [below_padding, patient_pred, above_padding])

            if not os.path.exists(os.path.join(out_path, mode)):
                os.makedirs(os.path.join(out_path, mode))

            patient_pred_nii = sitk.GetImageFromArray(patient_pred)
            sitk.WriteImage(
                patient_pred_nii,
                os.path.join(out_path, mode, '{}.nii.gz'.format(patient_id)))

            tf.compat.v1.logging.info('Exported patient {}'.format(patient_id))

        for sample in iterator:
            imgs = sample['image'].numpy()

            aug_seeds = aug_seeds_list[img_cnt:img_cnt + imgs.shape[0]]
            imgs_aug = data_aug(imgs, aug_seeds, is_seg_maps=False)

            pred_batch = predict(image=tf.constant(imgs))['prediction'].numpy()
            aug_pred_batch = predict(
                image=tf.constant(imgs_aug))['prediction'].numpy()

            for i in range(pred_batch.shape[0]):
                pred = pred_batch[i, ...]
                aug_pred = aug_pred_batch[i, ...]

                pred[np.where(pred == 3)] = 4
                aug_pred[np.where(aug_pred == 3)] = 4

                preds.append(pred)
                aug_preds.append(aug_pred)

                if example_cnt == data_info[
                        FLAGS.pred_dataset]['slices'][patient_cnt]:
                    patient_id = data_info[FLAGS.pred_dataset]['paths'][
                        patient_cnt].split('/')[-1]

                    patient_pred = np.stack(preds)
                    patient_aug_pred = np.stack(aug_preds)
                    patient_pred_aug = data_aug(
                        patient_pred,
                        aug_seeds_dict[FLAGS.pred_dataset][patient_id],
                        is_seg_maps=True)

                    pad_and_export(patient_pred, 'standard')
                    pad_and_export(patient_aug_pred, 'aug_pred')
                    pad_and_export(patient_pred_aug, 'pred_aug')

                    example_cnt = 1
                    patient_cnt += 1
                    preds = []
                    aug_preds = []
                else:
                    example_cnt += 1
            img_cnt += imgs.shape[0]

        if FLAGS.pred_dataset == 'val':
            tf.compat.v1.logging.info('Calculating standard Dice scores')
            calc_and_export_standard_dice(os.path.join(out_path, 'standard'))

        tf.compat.v1.logging.info('Calculating equivariance Dice scores')
        calc_and_export_equivariance_dice(os.path.join(out_path, 'pred_aug'),
                                          os.path.join(out_path, 'aug_pred'))
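
The final calls rely on project helpers (calc_and_export_standard_dice, calc_and_export_equivariance_dice) whose code is not shown. For reference, a minimal sketch of the Dice coefficient for one label of two segmentation maps (the generic formula, not the project's implementation):

import numpy as np

def dice_coefficient(pred, truth, label):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) over the voxels equal to `label`."""
    a = pred == label
    b = truth == label
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom else 1.0
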
Example #7
from estimator import get_estimator
from data import get_input_fn

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('model_name', nargs='?', type=str, default='base')
    parser.add_argument('--batch_size', '-b', nargs='?', type=int, default=64)
    parser.add_argument('--max_steps', '-s', nargs='?', type=int, default=1000000)  # argparse does not apply type= to defaults, so use an int literal
    args = parser.parse_args()

    estimator = get_estimator(args.model_name)
    input_fn = get_input_fn(args.batch_size, shuffle=True)
    estimator.train(input_fn, max_steps=args.max_steps)
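
The script imports get_input_fn(batch_size, shuffle) from a project-local data module that is not shown. A minimal sketch of a compatible factory built on in-memory NumPy arrays (the feature name and random data are placeholders, purely for illustration):

import numpy as np
import tensorflow as tf

def get_input_fn(batch_size, shuffle=False):
    """Hypothetical input_fn factory over random in-memory data."""
    features = np.random.rand(1000, 32).astype(np.float32)   # placeholder features
    labels = np.random.randint(0, 10, size=1000).astype(np.int64)

    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices(({"x": features}, labels))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)
        return dataset.repeat().batch(batch_size)

    return input_fn
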