Example #1
File: main.py Project: grananqvist/tpu
def main(unused_argv):
    train_dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset_name, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
    eval_dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset_name, FLAGS.eval_split, dataset_dir=FLAGS.dataset_dir)

    num_train_images = train_dataset.num_samples
    num_classes = train_dataset.num_classes
    ignore_label = train_dataset.ignore_label

    # Number of batches per epoch (may be fractional; used for epoch logging).
    num_batches_per_epoch = num_train_images / FLAGS.train_batch_size

    # Resolve the TPU endpoint from the flags.
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_shards))

    params = get_params(ignore_label, num_classes, num_batches_per_epoch)

    # TPUEstimator shards the batches across TPU cores and drives the TPU loop.
    deeplab_estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model.model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        params=params)

    if FLAGS.mode == 'train':
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total).' %
            (FLAGS.train_steps, FLAGS.train_steps / num_batches_per_epoch))
        train_input_fn = data_pipeline.InputReader(
            train_dataset,
            FLAGS.train_split,
            is_training=True,
            model_variant=FLAGS.model_variant)
        deeplab_estimator.train(input_fn=train_input_fn,
                                max_steps=FLAGS.train_steps)
    elif FLAGS.mode == 'train_and_eval':
        train_and_eval(deeplab_estimator, train_dataset, eval_dataset,
                       num_batches_per_epoch)
    elif FLAGS.mode == 'eval':
        eval_input_fn = data_pipeline.InputReader(
            eval_dataset,
            FLAGS.eval_split,
            is_training=False,
            model_variant=FLAGS.model_variant)

        # Run evaluation when there's a new checkpoint
        for ckpt in tf.contrib.training.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):

            tf.logging.info('Starting to evaluate.')
            try:
                eval_results = deeplab_estimator.evaluate(
                    input_fn=eval_input_fn,
                    steps=eval_dataset.num_samples // FLAGS.eval_batch_size)
                tf.logging.info('Eval results: %s' % eval_results)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)
    else:
        tf.logging.error('Mode not found.')
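
The eval loop above recovers the global step by parsing the checkpoint filename: Estimator checkpoints are named model.ckpt-<global_step>, so taking the basename and splitting on '-' yields the step. A standalone illustration of that parsing, with a hypothetical checkpoint path:

import os

ckpt = 'gs://my-bucket/model_dir/model.ckpt-25000'  # hypothetical path
# os.path.basename(ckpt) -> 'model.ckpt-25000'; split('-')[1] -> '25000'
current_step = int(os.path.basename(ckpt).split('-')[1])
print(current_step)  # prints: 25000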
Example #2
# (continuation of get_params(); the beginning of the snippet is truncated)
                   'outputs_to_num_classes': outputs_to_num_classes})

    tf.logging.debug('Params: ')
    for k, v in sorted(params.items()):
        tf.logging.debug('%s: %s', k, v)
    return params
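

# %%
# Assumed follow-up cell (not in the original script): tf.logging.debug
# messages are suppressed at the default verbosity, so lower the threshold
# to DEBUG to actually see the per-key log that get_params emits above.
tf.logging.set_verbosity(tf.logging.DEBUG)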


# %%
# def main(unused_argv):
# Get dataset
# ----------------------------------------------------------------------
dataset_name = 'xview2'
train_split = 'train'
dataset_dir = 'gs://lkk-xview2/xBD/spacenet_gt/images'
train_dataset = segmentation_dataset.get_dataset(
    dataset_name, train_split, dataset_dir=dataset_dir)
train_dataset  # display the dataset object in the interactive session
# %%
eval_split = 'val'
eval_dataset = segmentation_dataset.get_dataset(
    dataset_name, eval_split, dataset_dir=dataset_dir)

# %%
train_batch_size = 64

num_train_images = train_dataset.num_samples
num_classes = train_dataset.num_classes
ignore_label = train_dataset.ignore_label
num_batches_per_epoch = num_train_images / train_batch_size

print(num_train_images, num_classes, ignore_label, num_batches_per_epoch)
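
A follow-up cell would continue replaying main() from Example #1: build the params dict and construct the TPUEstimator. A minimal sketch, assuming get_params, model, and tf are already imported in the session; the TPU name, zone, project, model_dir, batch size, and step counts below are all placeholders:

# %%
eval_batch_size = 8  # placeholder
model_dir = 'gs://lkk-xview2/model_dir'  # hypothetical output location

params = get_params(ignore_label, num_classes, num_batches_per_epoch)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    'my-tpu', zone='us-central1-b', project='my-gcp-project')  # placeholders
config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=model_dir,
    save_checkpoints_steps=1000,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=100, num_shards=8))

deeplab_estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=True,
    model_fn=model.model_fn,
    config=config,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    params=params)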