def test_archive_ckpt(self):
        """Check `utils.archive_ckpt` results for better, worse and tied objectives.

        A checkpoint is written under a temporary model directory, then
        `archive_ckpt` is called with a sequence of objective values. The test
        asserts the call's truthiness and the presence of the `archive` and
        `backup` sub-directories after the first two calls and the last one.
        """
        root_dir = os.path.join(tf.test.get_temp_dir(), 'model_dir')
        ckpt_prefix = os.path.join(root_dir, 'ckpt')
        archive_dir = os.path.join(root_dir, 'archive')
        backup_dir = os.path.join(root_dir, 'backup')

        # Produce a real checkpoint on disk for archive_ckpt to work with.
        self.build_model()
        saver = tf.train.Saver()
        with self.session() as session:
            session.run(tf.global_variables_initializer())
            saver.save(session, ckpt_prefix)

        # First objective: accepted; only `archive` exists afterwards.
        first_saved = utils.archive_ckpt('eval1', 0.1, ckpt_prefix)
        self.assertTrue(first_saved)
        logging.info(os.listdir(root_dir))
        self.assertTrue(tf.io.gfile.exists(archive_dir))
        self.assertFalse(tf.io.gfile.exists(backup_dir))

        # Higher objective: accepted; `backup` now exists next to `archive`.
        second_saved = utils.archive_ckpt('eval2', 0.2, ckpt_prefix)
        self.assertTrue(second_saved)
        self.assertTrue(tf.io.gfile.exists(archive_dir))
        self.assertTrue(tf.io.gfile.exists(backup_dir))

        # Lower objective than the current best: rejected.
        worse_saved = utils.archive_ckpt('eval3', 0.1, ckpt_prefix)
        self.assertFalse(worse_saved)

        # Higher objective again: accepted.
        better_saved = utils.archive_ckpt('eval4', 0.3, ckpt_prefix)
        self.assertTrue(better_saved)

        # Objective equal to the current best: still accepted.
        tied_saved = utils.archive_ckpt('eval5', 0.3, ckpt_prefix)
        self.assertTrue(tied_saved)
        self.assertTrue(tf.io.gfile.exists(archive_dir))
        self.assertTrue(tf.io.gfile.exists(backup_dir))
示例#2
0
 def run_train_and_eval(e):
   """Train up to the step target for epoch `e`, evaluate, and archive.

   Args:
     e: Epoch number used to compute the `max_steps` training target.
   """
   print('\n   =====> Starting training, epoch: %d.' % e)
   # Step target passed as `max_steps`: e epochs' worth of batches.
   step_target = e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
   train_est.train(input_fn=train_input_fn, max_steps=step_target)

   print('\n   =====> Starting evaluation, epoch: %d.' % e)
   metrics = eval_est.evaluate(input_fn=eval_input_fn, steps=eval_steps)

   # Hand the newest checkpoint to the archiver, keyed on the 'AP' metric.
   newest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
   utils.archive_ckpt(metrics, metrics['AP'], newest_ckpt)
示例#3
0
 def run_train_and_eval(e):
     """Run one train/eval round for epoch `e` and archive the latest checkpoint.

     Args:
       e: Epoch number used to compute the step count passed to `_train`.
     """
     rule = '-----------------------------------------------------\n'

     print(rule + '=====> Starting training, epoch: %d.' % e)
     step_target = e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
     _train(step_target)

     print(rule + '=====> Starting evaluation, epoch: %d.' % e)
     metrics = _eval(eval_steps)

     # Hand the newest checkpoint to the archiver, keyed on the 'AP' metric.
     newest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
     utils.archive_ckpt(metrics, metrics['AP'], newest_ckpt)
示例#4
0
def main(argv):
    """Train and/or evaluate a detection model according to `FLAGS.mode`.

    Supported modes:
      * 'train': train with `HorovodEstimator` up to
        `num_epochs * num_examples_per_epoch / (train_batch_size * num_shards)`
        steps, plus one.
      * 'eval': evaluate each new checkpoint that appears in
        `FLAGS.model_dir` and pass it to `utils.archive_ckpt` keyed on the
        'AP' metric.
      * 'train_and_eval': for `FLAGS.num_epochs` cycles, train one epoch and
        then evaluate and archive the latest checkpoint.
    Any other mode only logs 'Mode not found.'.

    Args:
      argv: Positional command-line arguments; unused.

    Raises:
      RuntimeError: If a data-path flag required by the selected mode is
        missing, or if spatial partitioning is enabled and
        `--num_cores_per_replica` differs from the product of
        `--input_partition_dims`.
    """
    del argv  # Unused.

    hvd_try_init()
    set_env(use_amp=FLAGS.use_amp)

    # Check data path
    if FLAGS.mode in (
            'train', 'train_and_eval') and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')
        if not FLAGS.val_json_file and not FLAGS.testdev_dir:
            raise RuntimeError(
                'You must specify --val_json_file or --testdev_dir for '
                'evaluation.')

    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)

    # The following is for spatial partitioning. `features` has one tensor while
    # `labels` had 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`, see the partition logic below.
    # In the TPUEstimator context, the meaning of `shard` and `replica` is the
    # same; following the API, here has mixed use of both.
    if FLAGS.use_spatial_partition:
        # Checks input_partition_dims agrees with num_cores_per_replica.
        if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
            raise RuntimeError(
                '--num_cores_per_replica must be a product of array '
                'elements in --input_partition_dims.')

        labels_partition_dims = {
            'mean_num_positives': None,
            'source_ids': None,
            'groundtruth_data': None,
            'image_scales': None,
        }

        def _can_partition(spatial_dim):
            """Return True if every partition dim evenly divides `spatial_dim`."""
            remainders = spatial_dim % np.array(FLAGS.input_partition_dims)
            return bool(np.all(remainders == 0))

        # The Input Partition Logic: We partition only the partition-able tensors.
        # Spatial partition requires that the to-be-partitioned tensors must have a
        # dimension that is a multiple of `partition_dims`. Depending on the
        # `partition_dims` and the `image_size` and the `max_level` in config, some
        # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
        # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
        # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
        # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
        # case, the level-8 and level-9 target tensors are not partition-able, and
        # the highest partition-able level is 7.
        image_size = config.get('image_size')
        for level in range(config.get('min_level'),
                           config.get('max_level') + 1):
            spatial_dim = image_size // (2**level)
            level_dims = (FLAGS.input_partition_dims
                          if _can_partition(spatial_dim) else None)
            labels_partition_dims['box_targets_%d' % level] = level_dims
            labels_partition_dims['cls_targets_%d' % level] = level_dims
        num_cores_per_replica = FLAGS.num_cores_per_replica
        input_partition_dims = [
            FLAGS.input_partition_dims, labels_partition_dims
        ]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        num_shards = FLAGS.num_cores
        # When Horovod is available, its world size replaces the core count.
        if hvd is not None:
            num_shards = hvd.size()

    params = dict(
        config.as_dict(),
        model_name=FLAGS.model_name,
        num_epochs=FLAGS.num_epochs,
        iterations_per_loop=FLAGS.iterations_per_loop,
        model_dir=FLAGS.model_dir,
        num_shards=num_shards,
        num_examples_per_epoch=FLAGS.num_examples_per_epoch,
        use_tpu=FLAGS.use_tpu,
        backbone_ckpt=FLAGS.backbone_ckpt,
        ckpt=FLAGS.ckpt,
        val_json_file=FLAGS.val_json_file,
        testdev_dir=FLAGS.testdev_dir,
        mode=FLAGS.mode,
    )

    run_config = tf.estimator.RunConfig(
        session_config=get_session_config(use_xla=FLAGS.use_xla),
        save_checkpoints_steps=600)

    model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)

    logging.info(params)
    if FLAGS.mode == 'train':
        # Training goes through HorovodEstimator, which reads the batch size
        # from `params` rather than from a constructor argument.
        params['batch_size'] = FLAGS.train_batch_size
        train_estimator = HorovodEstimator(model_fn=model_fn_instance,
                                           model_dir=FLAGS.model_dir,
                                           config=run_config,
                                           params=params)

        input_fn = dataloader.InputReader(FLAGS.training_file_pattern,
                                          is_training=True,
                                          params=params,
                                          use_fake_data=FLAGS.use_fake_data)
        # The total step count is divided across shards.
        max_steps = int((FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                        (FLAGS.train_batch_size * num_shards)) + 1

        train_estimator.train(input_fn=input_fn, max_steps=max_steps)

    elif FLAGS.mode == 'eval':
        config_proto = tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)
        if FLAGS.use_xla and not FLAGS.use_tpu:
            config_proto.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_1)

        tpu_config = tf.estimator.tpu.TPUConfig(
            FLAGS.iterations_per_loop,
            num_shards=num_shards,
            num_cores_per_replica=num_cores_per_replica,
            input_partition_dims=input_partition_dims,
            per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.
            PER_HOST_V2)

        run_config = tf.estimator.tpu.RunConfig(
            cluster=None,
            evaluation_master=FLAGS.eval_master,
            model_dir=FLAGS.model_dir,
            log_step_count_steps=FLAGS.iterations_per_loop,
            session_config=config_proto,
            tpu_config=tpu_config,
        )

        # Eval only runs on CPU or GPU host with batch_size = 1.
        # Override the default options: disable randomization in the input pipeline
        # and don't run on the TPU.
        # Also, disable use_bfloat16 for eval on CPU/GPU.
        eval_params = dict(
            params,
            use_tpu=False,
            input_rand_hflip=False,
            is_training_bn=False,
            use_bfloat16=False,
        )

        eval_estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            use_tpu=False,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            config=run_config,
            params=eval_params)

        def terminate_eval():
            """Timeout callback: log and tell the iterator to stop."""
            logging.info('Terminating eval after %d seconds of no checkpoints',
                         FLAGS.eval_timeout)
            return True

        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval):

            logging.info('Starting to evaluate.')
            try:
                eval_results = eval_estimator.evaluate(
                    input_fn=dataloader.InputReader(
                        FLAGS.validation_file_pattern, is_training=False),
                    steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
                logging.info('Eval results: %s', eval_results)

                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                # NOTE(review): unlike the 'train' branch, this total is not
                # divided by num_shards, so it may never be reached and the
                # loop would then end via the timeout instead - confirm.
                total_step = int(
                    (FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                    FLAGS.train_batch_size)
                if current_step >= total_step:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    elif FLAGS.mode == 'train_and_eval':
        # NOTE(review): `run_config` here is a plain tf.estimator.RunConfig,
        # while the 'eval' branch builds a tf.estimator.tpu.RunConfig for
        # TPUEstimator - verify this branch still runs as intended.
        for cycle in range(FLAGS.num_epochs):
            logging.info('Starting training cycle, epoch: %d.', cycle)
            train_estimator = tf.estimator.tpu.TPUEstimator(
                model_fn=model_fn_instance,
                use_tpu=FLAGS.use_tpu,
                train_batch_size=FLAGS.train_batch_size,
                config=run_config,
                params=params)
            train_estimator.train(input_fn=dataloader.InputReader(
                FLAGS.training_file_pattern,
                is_training=True,
                use_fake_data=FLAGS.use_fake_data),
                                  steps=int(FLAGS.num_examples_per_epoch /
                                            FLAGS.train_batch_size))

            logging.info('Starting evaluation cycle, epoch: %d.', cycle)
            # Run evaluation after every epoch.
            eval_params = dict(
                params,
                use_tpu=False,
                input_rand_hflip=False,
                is_training_bn=False,
            )

            eval_estimator = tf.estimator.tpu.TPUEstimator(
                model_fn=model_fn_instance,
                use_tpu=False,
                train_batch_size=FLAGS.train_batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                config=run_config,
                params=eval_params)
            eval_results = eval_estimator.evaluate(
                input_fn=dataloader.InputReader(FLAGS.validation_file_pattern,
                                                is_training=False),
                steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
            logging.info('Evaluation results: %s', eval_results)
            ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
            utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

    else:
        logging.info('Mode not found.')
示例#5
0
def main(unused_argv):
    """Train, evaluate, or train-and-evaluate on ImageNet input per `FLAGS.mode`.

    Builds a `TPUEstimator` around `model_fn`, then:
      * 'eval': evaluates every new checkpoint found in `FLAGS.model_dir`
        until a checkpoint at or beyond `FLAGS.train_steps` is evaluated, a
        checkpoint name carries no step suffix, or the iterator times out.
      * 'train': trains up to `FLAGS.train_steps`, optionally with an async
        checkpoint hook.
      * any other mode (asserted to be 'train_and_eval'): alternates training
        for `FLAGS.steps_per_eval` steps with an evaluation pass.
    When `FLAGS.archive_ckpt` is set, evaluated checkpoints are passed to
    `utils.archive_ckpt` keyed on 'top_1_accuracy'. Finally, the model is
    exported if `FLAGS.export_dir` is set.

    Note that this function mutates `FLAGS.num_train_images` and
    `FLAGS.num_eval_images` when `FLAGS.holdout_shards` is set.

    Args:
      unused_argv: Positional command-line arguments; unused.

    Raises:
      ImportError: If async checkpointing is requested but the contrib
        module cannot be imported.
    """

    # Fall back to the model's own input size when the flag is unset.
    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        input_image_size = model_builder_factory.get_model_input_size(
            FLAGS.model_name)

    if FLAGS.holdout_shards:
        # Move a holdout_shards/1024 fraction of the training images into
        # the eval count. Presumably 1024 is the total number of training
        # shards - TODO confirm.
        holdout_images = int(FLAGS.num_train_images * FLAGS.holdout_shards /
                             1024.0)
        FLAGS.num_train_images -= holdout_images
        FLAGS.num_eval_images = holdout_images

    # For imagenet dataset, include background label if number of output classes
    # is 1001
    include_background_label = (FLAGS.num_label_classes == 1001)

    if FLAGS.tpu or FLAGS.use_tpu:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    # With async checkpointing the estimator's own periodic saving is turned
    # off (None); the hook appended in 'train' mode saves instead.
    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
    config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long
    # Initializes model parameters.
    # steps_per_epoch uses true division, so it may be fractional.
    params = dict(steps_per_epoch=FLAGS.num_train_images /
                  FLAGS.train_batch_size,
                  use_bfloat16=FLAGS.use_bfloat16)
    est = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params)

    if (FLAGS.model_name.startswith('efficientnet-lite')
            or FLAGS.model_name.startswith('efficientnet-edgetpu')):
        # lite or edgetpu use bilinear for easier post-quantization.
        resize_method = tf.image.ResizeMethod.BILINEAR
    else:
        resize_method = None
    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    def build_imagenet_input(is_training):
        """Generate ImageNetInput for training and eval."""
        if FLAGS.bigtable_instance:
            logging.info('Using Bigtable dataset, table %s',
                         FLAGS.bigtable_table)
            select_train, select_eval = _select_tables_from_flags()
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=select_train if is_training else select_eval,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method)
        else:
            if FLAGS.data_dir == FAKE_DATA_DIR:
                logging.info('Using fake dataset.')
            else:
                logging.info('Using dataset: %s', FLAGS.data_dir)

            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16,
                num_label_classes=FLAGS.num_label_classes,
                include_background_label=include_background_label,
                augment_name=FLAGS.augment_name,
                mixup_alpha=FLAGS.mixup_alpha,
                randaug_num_layers=FLAGS.randaug_num_layers,
                randaug_magnitude=FLAGS.randaug_magnitude,
                resize_method=resize_method,
                holdout_shards=FLAGS.holdout_shards)

    imagenet_train = build_imagenet_input(is_training=True)
    imagenet_eval = build_imagenet_input(is_training=False)

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                  timeout=FLAGS.eval_timeout):
            logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt,
                                            name=FLAGS.eval_name)
                elapsed_time = int(time.time() - start_timestamp)
                logging.info('Eval results: %s. Elapsed seconds: %d',
                             eval_results, elapsed_time)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(eval_results,
                                       eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                try:
                    # The step is parsed from the text after the first '-' in
                    # the checkpoint's base name.
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                if current_step >= FLAGS.train_steps:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        # Resume from whatever global step the model directory already holds.
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                # Imported lazily: the contrib module may be unavailable, in
                # which case the failure is logged and re-raised.
                try:
                    from tensorflow.contrib.tpu.python.tpu import async_checkpoint  # pylint: disable=g-import-not-at-top
                except ImportError as e:
                    logging.exception(
                        'Async checkpointing is not supported in TensorFlow 2.x'
                    )
                    raise e

                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(input_fn=imagenet_train.input_fn,
                      max_steps=FLAGS.train_steps,
                      hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=imagenet_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                logging.info('Starting to evaluate.')
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=FLAGS.num_eval_images //
                                            FLAGS.eval_batch_size,
                                            name=FLAGS.eval_name)
                logging.info('Eval results at step %d: %s', next_checkpoint,
                             eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(eval_results,
                                       eval_results['top_1_accuracy'], ckpt)

            # Final summary; only logged on the train_and_eval path.
            elapsed_time = int(time.time() - start_timestamp)
            logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
    # Export runs after any mode, as long as an export directory is given.
    if FLAGS.export_dir:
        export(est, FLAGS.export_dir, input_image_size)
示例#6
0
def main(_):
    """Train a Keras detection model or continuously evaluate its checkpoints.

    The function resolves hparams, picks a `tf.distribute` strategy from
    `FLAGS.strategy` ('tpu', 'gpus', or a single-device fallback), sets the
    global mixed-precision policy, and builds the model inside the strategy
    scope. If 'train' is part of `FLAGS.mode` it calls `model.fit` (with
    validation data when 'eval' is also part of the mode); otherwise it
    evaluates each new checkpoint in `FLAGS.model_dir` with
    `train_lib.COCOCallback` and archives it keyed on the 'AP' metric.

    Args:
      _: Positional command-line arguments; unused.

    Raises:
      ValueError: If strategy is 'gpus' and no GPU is visible, if the batch
        size is not divisible by the number of GPUs, or if the file pattern
        flag required for a dataset is empty.
    """
    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    if FLAGS.use_xla and FLAGS.strategy != 'tpu':
        tf.config.optimizer.set_jit(True)
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    if FLAGS.debug:
        tf.debugging.set_log_device_placement(True)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.random.set_seed(FLAGS.tf_random_seed)
        logging.set_verbosity(logging.DEBUG)

    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)
        ds_strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
        logging.info('All devices: %s', tf.config.list_logical_devices('TPU'))
    elif FLAGS.strategy == 'gpus':
        gpus = tf.config.list_physical_devices('GPU')
        if not gpus:
            # Fail with a clear message instead of a ZeroDivisionError from
            # the modulo check below.
            raise ValueError(
                'Strategy "gpus" requires at least one GPU, but none were '
                'found.')
        if FLAGS.batch_size % len(gpus):
            raise ValueError(
                'Batch size divide gpus number must be interger, but got %f' %
                (FLAGS.batch_size / len(gpus)))
        if platform.system() == 'Windows':
            # Windows doesn't support nccl use HierarchicalCopyAllReduce instead
            # TODO(fsx950223): investigate HierarchicalCopyAllReduce performance issue
            cross_device_ops = tf.distribute.HierarchicalCopyAllReduce()
        else:
            cross_device_ops = None
        ds_strategy = tf.distribute.MirroredStrategy(
            cross_device_ops=cross_device_ops)
        logging.info('All devices: %s', gpus)
    else:
        # Single-device fallback: first GPU if one exists, else the CPU.
        if tf.config.list_physical_devices('GPU'):
            ds_strategy = tf.distribute.OneDeviceStrategy('device:GPU:0')
        else:
            ds_strategy = tf.distribute.OneDeviceStrategy('device:CPU:0')

    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.batch_size
    params = dict(profile=FLAGS.profile,
                  model_name=FLAGS.model_name,
                  steps_per_execution=FLAGS.steps_per_execution,
                  model_dir=FLAGS.model_dir,
                  steps_per_epoch=steps_per_epoch,
                  strategy=FLAGS.strategy,
                  batch_size=FLAGS.batch_size,
                  tf_random_seed=FLAGS.tf_random_seed,
                  debug=FLAGS.debug,
                  val_json_file=FLAGS.val_json_file,
                  eval_samples=FLAGS.eval_samples,
                  num_shards=ds_strategy.num_replicas_in_sync)
    config.override(params, True)
    # set mixed precision policy by keras api.
    precision = utils.get_precision(config.strategy, config.mixed_precision)
    policy = tf.keras.mixed_precision.Policy(precision)
    tf.keras.mixed_precision.set_global_policy(policy)

    def get_dataset(is_training, config):
        """Build the train or validation dataset from the matching flag.

        Raises:
          ValueError: If the selected file pattern flag is empty.
        """
        file_pattern = (FLAGS.train_file_pattern
                        if is_training else FLAGS.val_file_pattern)
        if not file_pattern:
            raise ValueError('No matching files.')

        return dataloader.InputReader(
            file_pattern,
            is_training=is_training,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=config.max_instances_per_image,
            debug=FLAGS.debug)(config.as_dict())

    with ds_strategy.scope():
        if config.model_optimizations:
            tfmot.set_config(config.model_optimizations.as_dict())
        if FLAGS.hub_module_url:
            model = train_lib.EfficientDetNetTrainHub(
                config=config, hub_module_url=FLAGS.hub_module_url)
        else:
            model = train_lib.EfficientDetNetTrain(config=config)
        model = setup_model(model, config)
        if FLAGS.debug:
            tf.config.run_functions_eagerly(True)
        if FLAGS.pretrained_ckpt and not FLAGS.hub_module_url:
            ckpt_path = tf.train.latest_checkpoint(FLAGS.pretrained_ckpt)
            util_keras.restore_ckpt(model,
                                    ckpt_path,
                                    config.moving_average_decay,
                                    exclude_layers=['class_net'])
        init_experimental(config)
        if 'train' in FLAGS.mode:
            # Validation data is only built when the mode also mentions eval.
            val_dataset = get_dataset(False,
                                      config) if 'eval' in FLAGS.mode else None
            model.fit(
                get_dataset(True, config),
                epochs=config.num_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=train_lib.get_callbacks(config.as_dict(),
                                                  val_dataset),
                validation_data=val_dataset,
                validation_steps=(FLAGS.eval_samples // FLAGS.batch_size))
        else:
            # Continuous eval.
            for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                      min_interval_secs=180):
                logging.info('Starting to evaluate.')
                # Terminate eval job when final checkpoint is reached.
                # A checkpoint name without a '-<epoch>' suffix counts as
                # epoch 0, which ends the loop after this evaluation.
                try:
                    current_epoch = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    current_epoch = 0

                val_dataset = get_dataset(False, config)
                logging.info('start loading model.')
                # NOTE(review): weights come from the latest checkpoint in
                # model_dir, which may be newer than `ckpt` if training wrote
                # another one in the meantime - confirm this is intended.
                model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
                logging.info('finish loading model.')
                coco_eval = train_lib.COCOCallback(val_dataset, 1)
                coco_eval.set_model(model)
                eval_results = coco_eval.on_epoch_end(current_epoch)
                logging.info('eval results for %s: %s', ckpt, eval_results)

                try:
                    utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                except tf.errors.NotFoundError:
                    # Checkpoint might already be deleted by the time eval
                    # finished.
                    logging.info('Checkpoint %s no longer exists, skipping.',
                                 ckpt)

                if current_epoch >= config.num_epochs or not current_epoch:
                    logging.info('Eval epoch %d / %d', current_epoch,
                                 config.num_epochs)
                    break
示例#7
0
def main(_):
    """Train and/or evaluate a detection model through the Estimator API.

    Everything is driven by the module-level ``FLAGS``. ``FLAGS.mode`` selects
    the flow:

    * ``'train'``: train up to the configured number of steps, optionally
      followed by one evaluation (``--eval_after_training``).
    * ``'eval'``: evaluate each new checkpoint appearing in ``--model_dir``
      and archive it via ``utils.archive_ckpt``.
    * ``'train_and_eval'``: delegate to ``tf.estimator.train_and_evaluate``.

    Any other mode is only logged.

    Args:
        _: Unused (presumably the argv list handed over by the app runner).

    Raises:
        RuntimeError: If the file pattern required by the selected mode is
            missing, or if ``--num_cores_per_replica`` does not equal the
            product of ``--input_partition_dims`` when spatial partitioning
            is enabled.
    """
    if FLAGS.strategy == 'tpu':
        tf.disable_eager_execution()
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tpu_grpc_url = tpu_cluster_resolver.get_master()
        tf.Session.reset(tpu_grpc_url)
    else:
        tpu_cluster_resolver = None

    # Check data path
    if FLAGS.mode in ('train', 'train_and_eval'):
        if FLAGS.training_file_pattern is None:
            raise RuntimeError(
                'Must specify --training_file_pattern for train.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError(
                'Must specify --validation_file_pattern for eval.')

    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    # The following is for spatial partitioning. `features` has one tensor
    # while `labels` has a fixed set of tensors plus
    # (`max_level` - `min_level` + 1) * 2 per-level target tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`, see the partition logic below.
    # In the TPUEstimator context, the meaning of `shard` and `replica` is the
    # same; following the API, both terms are used here.
    if FLAGS.use_spatial_partition:
        # Checks input_partition_dims agrees with num_cores_per_replica.
        if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
            raise RuntimeError(
                '--num_cores_per_replica must be a product of array '
                'elements in --input_partition_dims.')

        def _can_partition(spatial_dim):
            """Return True if `spatial_dim` is divisible by every partition dim."""
            partitionable_index = np.where(
                spatial_dim % np.array(FLAGS.input_partition_dims) == 0)
            return len(partitionable_index[0]) == len(
                FLAGS.input_partition_dims)

        labels_partition_dims = {
            'mean_num_positives': None,
            'source_ids': None,
            'groundtruth_data': None,
            'image_scales': None,
            'image_masks': None,
        }
        # The Input Partition Logic: We partition only the partition-able tensors.
        feat_sizes = utils.get_feat_sizes(config.get('image_size'),
                                          config.get('max_level'))
        for level in range(config.get('min_level'),
                           config.get('max_level') + 1):
            feat_size = feat_sizes[level]
            if _can_partition(feat_size['height']) and _can_partition(
                    feat_size['width']):
                level_partition_dims = FLAGS.input_partition_dims
            else:
                # Targets at this level cannot be split evenly; leave them
                # unpartitioned.
                level_partition_dims = None
            labels_partition_dims['box_targets_%d' %
                                  level] = level_partition_dims
            labels_partition_dims['cls_targets_%d' %
                                  level] = level_partition_dims
        num_cores_per_replica = FLAGS.num_cores_per_replica
        input_partition_dims = [
            FLAGS.input_partition_dims, labels_partition_dims
        ]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        num_shards = FLAGS.num_cores

    params = dict(config.as_dict(),
                  model_name=FLAGS.model_name,
                  iterations_per_loop=FLAGS.iterations_per_loop,
                  model_dir=FLAGS.model_dir,
                  num_shards=num_shards,
                  num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                  strategy=FLAGS.strategy,
                  backbone_ckpt=FLAGS.backbone_ckpt,
                  ckpt=FLAGS.ckpt,
                  val_json_file=FLAGS.val_json_file,
                  testdev_dir=FLAGS.testdev_dir,
                  mode=FLAGS.mode)
    config_proto = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    if FLAGS.strategy != 'tpu':
        if FLAGS.use_xla:
            config_proto.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_1)
        config_proto.gpu_options.allow_growth = True

    model_dir = FLAGS.model_dir
    strategy = None
    if FLAGS.strategy == 'tpu':
        # This branch only runs for the TPU strategy, so iterations_per_loop
        # is passed through unconditionally.
        tpu_config = tf.estimator.tpu.TPUConfig(
            FLAGS.iterations_per_loop,
            num_cores_per_replica=num_cores_per_replica,
            input_partition_dims=input_partition_dims,
            per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.
            PER_HOST_V2)
        run_config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=model_dir,
            log_step_count_steps=FLAGS.iterations_per_loop,
            session_config=config_proto,
            tpu_config=tpu_config,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            tf_random_seed=FLAGS.tf_random_seed,
        )
    else:
        if FLAGS.strategy == 'gpus':
            strategy = tf.distribute.MirroredStrategy()
        run_config = tf.estimator.RunConfig(
            model_dir=model_dir,
            train_distribute=strategy,
            log_step_count_steps=FLAGS.iterations_per_loop,
            session_config=config_proto,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            tf_random_seed=FLAGS.tf_random_seed,
        )

    model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)
    max_instances_per_image = config.max_instances_per_image
    eval_steps = int(FLAGS.eval_samples // FLAGS.eval_batch_size)
    total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
    train_steps = total_examples // FLAGS.train_batch_size
    logging.info(params)
    # On a fresh run the model directory may not exist yet; create it so that
    # dumping the config below does not fail.
    tf.io.gfile.makedirs(model_dir)
    with tf.io.gfile.GFile(os.path.join(model_dir, 'config.yaml'), 'w') as f:
        f.write(str(config))

    train_input_fn = dataloader.InputReader(
        FLAGS.training_file_pattern,
        is_training=True,
        use_fake_data=FLAGS.use_fake_data,
        max_instances_per_image=max_instances_per_image)
    eval_input_fn = dataloader.InputReader(
        FLAGS.validation_file_pattern,
        is_training=False,
        use_fake_data=FLAGS.use_fake_data,
        max_instances_per_image=max_instances_per_image)

    if FLAGS.strategy == 'tpu':
        estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            config=run_config,
            params=params)
    else:
        # Without a distribution strategy there is a single replica.
        num_replicas = getattr(strategy, 'num_replicas_in_sync', 1)
        params['batch_size'] = FLAGS.train_batch_size // num_replicas
        params['num_shards'] = num_replicas
        estimator = tf.estimator.Estimator(model_fn=model_fn_instance,
                                           config=run_config,
                                           params=params)

    # start train/eval flow.
    if FLAGS.mode == 'train':
        estimator.train(input_fn=train_input_fn, max_steps=train_steps)
        if FLAGS.eval_after_training:
            estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    elif FLAGS.mode == 'eval':
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout):

            logging.info('Starting to evaluate.')
            try:
                eval_results = estimator.evaluate(eval_input_fn,
                                                  steps=eval_steps)
                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                if current_step >= train_steps:
                    logging.info('Eval finished step %d/%d', current_step,
                                 train_steps)
                    break

            except tf.errors.NotFoundError:
                # The checkpoint might already have been deleted by the time
                # eval finished. We simply skip such a case.
                logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

    elif FLAGS.mode == 'train_and_eval':
        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=train_steps)
        eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                          steps=eval_steps,
                                          throttle_secs=600)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    else:
        logging.info('Invalid mode: %s', FLAGS.mode)
示例#8
0
def main(_):
    """Train and/or evaluate a detection model with TPUEstimator.

    Everything is driven by the module-level ``FLAGS``. ``FLAGS.mode`` selects
    the flow:

    * ``'train'``: train up to the configured number of steps, optionally
      followed by one evaluation (``--eval_after_training``).
    * ``'eval'``: evaluate each new checkpoint appearing in ``--model_dir``
      and archive it via ``utils.archive_ckpt``.
    * ``'train_and_eval'``: starting after the epoch of the latest checkpoint
      (if any), alternate one training cycle with one evaluation and archive
      the resulting checkpoint.

    Any other mode is only logged.

    Args:
        _: Unused (presumably the argv list handed over by the app runner).

    Raises:
        RuntimeError: If the file pattern required by the selected mode is
            missing, or if ``--num_cores_per_replica`` does not equal the
            product of ``--input_partition_dims`` when spatial partitioning
            is enabled.
    """

    if FLAGS.strategy == 'horovod':
        import horovod.tensorflow as hvd  # pylint: disable=g-import-not-at-top
        logging.info('Use horovod with multi gpus')
        hvd.init()
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())
    # TensorFlow is imported only after CUDA_VISIBLE_DEVICES has been set.
    import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top
    tf.enable_v2_tensorshape()
    tf.disable_eager_execution()

    if FLAGS.strategy == 'tpu':
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        tpu_grpc_url = tpu_cluster_resolver.get_master()
        tf.Session.reset(tpu_grpc_url)
    else:
        tpu_cluster_resolver = None

    # Check data path
    if FLAGS.mode in (
            'train', 'train_and_eval') and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')

    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)
    if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
        config.num_epochs = FLAGS.num_epochs

    # Parse image size in case it is in string format.
    config.image_size = utils.parse_image_size(config.image_size)

    # The following is for spatial partitioning. `features` has one tensor while
    # `labels` has 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`, see the partition logic below.
    # In the TPUEstimator context, the meaning of `shard` and `replica` is the
    # same; following the API, both terms are used here.
    if FLAGS.use_spatial_partition:
        # Checks input_partition_dims agrees with num_cores_per_replica.
        if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
            raise RuntimeError(
                '--num_cores_per_replica must be a product of array '
                'elements in --input_partition_dims.')

        def _can_partition(spatial_dim):
            """Return True if `spatial_dim` is divisible by every partition dim."""
            partitionable_index = np.where(
                spatial_dim % np.array(FLAGS.input_partition_dims) == 0)
            return len(partitionable_index[0]) == len(
                FLAGS.input_partition_dims)

        labels_partition_dims = {
            'mean_num_positives': None,
            'source_ids': None,
            'groundtruth_data': None,
            'image_scales': None,
        }
        # The Input Partition Logic: We partition only the partition-able tensors.
        # Spatial partition requires that the to-be-partitioned tensors must have a
        # dimension that is a multiple of `partition_dims`. Depending on the
        # `partition_dims` and the `image_size` and the `max_level` in config, some
        # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
        # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
        # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
        # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
        # case, the level-8 and level-9 target tensors are not partition-able, and
        # the highest partition-able level is 7.
        feat_sizes = utils.get_feat_sizes(config.get('image_size'),
                                          config.get('max_level'))
        for level in range(config.get('min_level'),
                           config.get('max_level') + 1):
            feat_size = feat_sizes[level]
            if _can_partition(feat_size['height']) and _can_partition(
                    feat_size['width']):
                level_partition_dims = FLAGS.input_partition_dims
            else:
                level_partition_dims = None
            labels_partition_dims['box_targets_%d' %
                                  level] = level_partition_dims
            labels_partition_dims['cls_targets_%d' %
                                  level] = level_partition_dims
        num_cores_per_replica = FLAGS.num_cores_per_replica
        input_partition_dims = [
            FLAGS.input_partition_dims, labels_partition_dims
        ]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        num_shards = FLAGS.num_cores

    params = dict(config.as_dict(),
                  model_name=FLAGS.model_name,
                  iterations_per_loop=FLAGS.iterations_per_loop,
                  model_dir=FLAGS.model_dir,
                  num_shards=num_shards,
                  num_examples_per_epoch=FLAGS.num_examples_per_epoch,
                  strategy=FLAGS.strategy,
                  backbone_ckpt=FLAGS.backbone_ckpt,
                  ckpt=FLAGS.ckpt,
                  val_json_file=FLAGS.val_json_file,
                  testdev_dir=FLAGS.testdev_dir,
                  mode=FLAGS.mode)
    config_proto = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    if FLAGS.strategy != 'tpu':
        if FLAGS.use_xla:
            config_proto.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_1)
        config_proto.gpu_options.allow_growth = True

    tpu_config = tf.estimator.tpu.TPUConfig(
        FLAGS.iterations_per_loop if FLAGS.strategy == 'tpu' else 1,
        num_cores_per_replica=num_cores_per_replica,
        input_partition_dims=input_partition_dims,
        per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.
        PER_HOST_V2)

    if FLAGS.strategy == 'horovod':
        # Only the rank-0 worker gets a model_dir; the others get None.
        model_dir = FLAGS.model_dir if hvd.rank() == 0 else None
    else:
        model_dir = FLAGS.model_dir

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=model_dir,
        log_step_count_steps=FLAGS.iterations_per_loop,
        session_config=config_proto,
        tpu_config=tpu_config,
        tf_random_seed=FLAGS.tf_random_seed,
    )

    model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)
    max_instances_per_image = config.max_instances_per_image
    eval_steps = int(FLAGS.eval_samples // FLAGS.eval_batch_size)
    use_tpu = (FLAGS.strategy == 'tpu')
    logging.info(params)

    def _train(steps):
        """Build a train estimator and train until global step `steps`."""
        train_estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            use_tpu=use_tpu,
            train_batch_size=FLAGS.train_batch_size,
            config=run_config,
            params=params)
        train_estimator.train(input_fn=dataloader.InputReader(
            FLAGS.training_file_pattern,
            is_training=True,
            use_fake_data=FLAGS.use_fake_data,
            max_instances_per_image=max_instances_per_image),
                              max_steps=steps)

    def _eval(steps):
        """Build an eval estimator, run `steps` eval steps, return the results."""
        eval_params = dict(
            params,
            strategy=FLAGS.strategy,
            input_rand_hflip=False,
            is_training_bn=False,
        )
        eval_estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            use_tpu=use_tpu,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            config=run_config,
            params=eval_params)
        eval_results = eval_estimator.evaluate(input_fn=dataloader.InputReader(
            FLAGS.validation_file_pattern,
            is_training=False,
            max_instances_per_image=max_instances_per_image),
                                               steps=steps,
                                               name=FLAGS.eval_name)
        logging.info('Evaluation results: %s', eval_results)
        return eval_results

    # start train/eval flow.
    if FLAGS.mode == 'train':
        total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
        _train(total_examples // FLAGS.train_batch_size)
        if FLAGS.eval_after_training:
            _eval(eval_steps)

    elif FLAGS.mode == 'eval':
        # Step of the final checkpoint; it does not change between
        # checkpoints, so compute it once.
        total_step = int(
            (config.num_epochs * FLAGS.num_examples_per_epoch) /
            FLAGS.train_batch_size)
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout):

            logging.info('Starting to evaluate.')
            try:
                eval_results = _eval(eval_steps)
                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except IndexError:
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                if current_step >= total_step:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

    elif FLAGS.mode == 'train_and_eval':
        ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        try:
            # `ckpt` is None when no checkpoint exists, which makes
            # os.path.basename raise TypeError.
            step = int(os.path.basename(ckpt).split("-")[1])
            current_epoch = (step * FLAGS.train_batch_size //
                             FLAGS.num_examples_per_epoch)
            logging.info('found ckpt at step %d (epoch %d)', step,
                         current_epoch)
        except (IndexError, TypeError):
            # The format string needs a placeholder for the folder argument;
            # without one the logging call fails to format the record.
            logging.info('Folder %s has no ckpt with valid step.',
                         FLAGS.model_dir)
            current_epoch = 0

        epochs_per_cycle = 1  # higher number has less graph construction overhead.
        for e in range(current_epoch + 1, config.num_epochs + 1,
                       epochs_per_cycle):
            print('-----------------------------------------------------\n'
                  '=====> Starting training, epoch: %d.' % e)
            _train(e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
            print('-----------------------------------------------------\n'
                  '=====> Starting evaluation, epoch: %d.' % e)
            eval_results = _eval(eval_steps)
            ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
            utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

    else:
        logging.info('Invalid mode: %s', FLAGS.mode)
示例#9
0
def main(unused_argv):
    """Configure MnasNet from flags, then export, evaluate or train it.

    The parameter set starts from ``mnasnet_config.MNASNET_CFG`` and is then
    overridden by ``FLAGS.config_file``, ``FLAGS.params_override`` and the
    individual input flags before being validated and locked.

    ``FLAGS.mode`` picks the flow:

    * ``'export_only'``: export the estimator and return.
    * ``'eval'``: evaluate and archive each new checkpoint in ``--model_dir``
      until the final training step is seen, then export if ``--export_dir``
      is set.
    * ``'train'``: train up to ``params.train_steps``.
    * otherwise (asserted to be ``'train_and_eval'``): alternate training for
      ``--steps_per_eval`` steps with an evaluation, archive each checkpoint,
      and finally export if ``--export_dir`` is set.

    Args:
        unused_argv: Unused.

    Raises:
        ValueError: If bfloat16 precision is combined with Keras layers.
    """
    # Assemble the parameter set, layer by layer.
    params = params_dict.ParamsDict(mnasnet_config.MNASNET_CFG,
                                    mnasnet_config.MNASNET_RESTRICTIONS)
    for strict_override in (FLAGS.config_file, FLAGS.params_override):
        params = params_dict.override_params_dict(params,
                                                  strict_override,
                                                  is_strict=True)
    params = flags_to_params.override_params_from_input_flags(params, FLAGS)
    derived_params = {
        'steps_per_epoch': params.num_train_images / params.train_batch_size,
        'quantized_training': FLAGS.quantized_training,
    }
    params = params_dict.override_params_dict(params,
                                              derived_params,
                                              is_strict=False)
    params.validate()
    params.lock()

    # A cluster resolver is only needed when a TPU is requested.
    tpu_cluster_resolver = None
    if FLAGS.tpu or params.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    # With async checkpointing the saver hook (train mode) owns checkpointing.
    save_checkpoints_steps = (None if params.use_async_checkpointing else max(
        100, params.iterations_per_loop))
    session_config = tf.ConfigProto(graph_options=tf.GraphOptions(
        rewrite_options=rewriter_config_pb2.RewriterConfig(
            disable_meta_optimizer=True)))
    tpu_config = tf.contrib.tpu.TPUConfig(
        iterations_per_loop=params.iterations_per_loop,
        per_host_input_for_training=(
            tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=session_config,
        tpu_config=tpu_config)

    # Validates Flags.
    if params.use_keras and params.precision == 'bfloat16':
        raise ValueError(
            'Keras layers do not have full support to bfloat16 activation training.'
            ' You have set precision as %s and use_keras as %s' %
            (params.precision, params.use_keras))

    # Initializes model parameters.
    mnasnet_est = tf.contrib.tpu.TPUEstimator(
        use_tpu=params.use_tpu,
        model_fn=mnasnet_model_fn,
        config=config,
        train_batch_size=params.train_batch_size,
        eval_batch_size=params.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params.as_dict())

    if FLAGS.mode == 'export_only':
        export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
        return

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()

        def _bigtable_input(is_training, selection):
            """Build the Bigtable-backed input for one split."""
            return imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=False,
                transpose_input=params.transpose_input,
                selection=selection)

        imagenet_train = _bigtable_input(True, select_train)
        imagenet_eval = _bigtable_input(False, select_eval)
    else:
        if FLAGS.data_dir != FAKE_DATA_DIR:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        else:
            tf.logging.info('Using fake dataset.')

        def _file_input(is_training):
            """Build the file-backed input for one split."""
            return imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params.transpose_input,
                cache=params.use_cache and is_training,
                image_size=params.input_image_size,
                num_parallel_calls=params.num_parallel_calls,
                use_bfloat16=(params.precision == 'bfloat16'))

        imagenet_train = _file_input(True)
        imagenet_eval = _file_input(False)

    if FLAGS.mode == 'eval':
        eval_steps = params.num_eval_images // params.eval_batch_size
        # Run evaluation when there's a new checkpoint
        new_checkpoints = evaluation.checkpoints_iterator(
            FLAGS.model_dir, timeout=FLAGS.eval_timeout)
        for ckpt in new_checkpoints:
            tf.logging.info('Starting to evaluate.')
            try:
                # This time will include compilation time
                eval_started = time.time()
                eval_results = mnasnet_est.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                eval_seconds = int(time.time() - eval_started)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, eval_seconds)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached
                ckpt_step = int(os.path.basename(ckpt).split('-')[1])
                if ckpt_step >= params.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        ckpt_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
        return

    # From here on FLAGS.mode is 'train' or 'train_and_eval'.
    current_step = estimator._load_global_step_from_checkpoint_dir(  # pylint: disable=protected-access
        FLAGS.model_dir)

    tf.logging.info(
        'Training for %d steps (%.2f epochs in total). Current'
        ' step %d.', params.train_steps,
        params.train_steps / params.steps_per_epoch, current_step)

    # This time will include compilation time
    start_timestamp = time.time()

    if FLAGS.mode == 'train':
        if params.use_async_checkpointing:
            hooks = [
                async_checkpoint.AsyncCheckpointSaverHook(
                    checkpoint_dir=FLAGS.model_dir,
                    save_steps=max(100, params.iterations_per_loop))
            ]
        else:
            hooks = []
        mnasnet_est.train(input_fn=imagenet_train.input_fn,
                          max_steps=params.train_steps,
                          hooks=hooks)
        return

    assert FLAGS.mode == 'train_and_eval'
    while current_step < params.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        current_step = min(current_step + FLAGS.steps_per_eval,
                           params.train_steps)
        mnasnet_est.train(input_fn=imagenet_train.input_fn,
                          max_steps=current_step)

        tf.logging.info(
            'Finished training up to step %d. Elapsed seconds %d.',
            current_step, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        cycle_eval_steps = params.num_eval_images // params.eval_batch_size
        eval_results = mnasnet_est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=cycle_eval_steps)
        tf.logging.info('Eval results at step %d: %s', current_step,
                        eval_results)
        latest_ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
        utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'],
                           latest_ckpt)

    total_seconds = int(time.time() - start_timestamp)
    tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                    params.train_steps, total_seconds)
    if FLAGS.export_dir:
        export(mnasnet_est, FLAGS.export_dir, params, FLAGS.post_quantize)
示例#10
0
文件: train.py 项目: yichenj/tpu
def main(unused_argv):
    """Evaluate, train, or alternate train/eval according to FLAGS.mode.

    In 'eval' mode each checkpoint yielded for FLAGS.model_dir is evaluated and
    handed to utils.archive_ckpt together with its 'val_accuracy'; the loop
    ends once a checkpoint whose step is at least FLAGS.train_steps has been
    evaluated (or the checkpoint iterator stops). Any other mode trains the
    estimator; 'train_and_eval' trains in chunks of FLAGS.steps_per_eval and
    evaluates/archives after every chunk. Finally, if FLAGS.export_dir is set,
    the estimator is exported regardless of mode.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If FLAGS.input_image_size is unset and FLAGS.model_name does
        not start with 'efficientnet'.
    """
    # An explicit flag wins; otherwise the size comes from the EfficientNet
    # parameters for the requested model name.
    image_size = FLAGS.input_image_size
    if not image_size:
        if not FLAGS.model_name.startswith('efficientnet'):
            raise ValueError(
                'input_image_size must be set except for EfficientNet')
        _, _, image_size, _ = efficientnet_builder.efficientnet_params(
            FLAGS.model_name)

    # The session is configured with the grappler meta optimizer disabled.
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        disable_meta_optimizer=True)
    session_config = tf.ConfigProto(
        graph_options=tf.GraphOptions(rewrite_options=rewrite_options))
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=session_config)

    # Model parameters handed to model_fn.
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    model_params = {'steps_per_epoch': steps_per_epoch}
    classifier = tf.estimator.Estimator(model_fn=model_fn,
                                        config=run_config,
                                        params=model_params)

    def make_input(for_training):
        """Create the EggCandlerInput pipeline for training or evaluation."""
        tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        return egg_candler_input.EggCandlerInput(
            is_training=for_training,
            data_dir=FLAGS.data_dir,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            image_size=image_size)

    train_input, eval_input = make_input(True), make_input(False)

    if FLAGS.mode == 'eval':
        steps_per_evaluation = FLAGS.num_eval_images // FLAGS.eval_batch_size
        checkpoint_stream = tf.train.checkpoints_iterator(
            FLAGS.model_dir, timeout=FLAGS.eval_timeout)
        # One evaluation per newly appearing checkpoint.
        for checkpoint in checkpoint_stream:
            tf.logging.info('Starting to evaluate.')
            try:
                # The measured time includes compilation time.
                began_at = time.time()
                metrics = classifier.evaluate(input_fn=eval_input.input_fn,
                                              steps=steps_per_evaluation,
                                              checkpoint_path=checkpoint)
                seconds_taken = int(time.time() - began_at)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                metrics, seconds_taken)
                utils.archive_ckpt(metrics, metrics['val_accuracy'],
                                   checkpoint)

                # The step is the second '-'-separated field of the checkpoint
                # basename; stop once the final training step is reached.
                reached_step = int(os.path.basename(checkpoint).split('-')[1])
                if reached_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        reached_step)
                    break

            except tf.errors.NotFoundError:
                # The checkpoint may have been deleted between being listed
                # and being read (e.g. a slow-starting worker); move on to
                # the next one.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    checkpoint)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        # pylint: disable=protected-access
        global_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)
        # pylint: enable=protected-access

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / steps_per_epoch, global_step)

        # The measured time includes compilation time.
        began_at = time.time()

        if FLAGS.mode == 'train':
            classifier.train(input_fn=train_input.input_fn,
                             max_steps=FLAGS.train_steps,
                             hooks=[])
        else:
            assert FLAGS.mode == 'train_and_eval'
            while global_step < FLAGS.train_steps:
                # Train one chunk of at most steps_per_eval steps; a checkpoint
                # is written to --model_dir when the chunk ends.
                target_step = min(global_step + FLAGS.steps_per_eval,
                                  FLAGS.train_steps)
                classifier.train(input_fn=train_input.input_fn,
                                 max_steps=target_step,
                                 hooks=[])
                global_step = target_step

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    target_step, int(time.time() - began_at))

                # Evaluate the newest model in --model_dir. Evaluation runs in
                # whole batches of --eval_batch_size, so images beyond the last
                # full batch are left out; the evaluated set stays the same as
                # long as the batch size does.
                tf.logging.info('Starting to evaluate.')
                metrics = classifier.evaluate(
                    input_fn=eval_input.input_fn,
                    steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', target_step,
                                metrics)
                newest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(metrics, metrics['val_accuracy'],
                                   newest_checkpoint)

            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, int(time.time() - began_at))
    if FLAGS.export_dir:
        export(classifier, FLAGS.export_dir, image_size)
示例#11
0
def main(unused_argv):
    """Train and/or evaluate a TPUEstimator model as selected by FLAGS.mode.

    The estimator is warm-started from a fixed checkpoint path and fed either
    from Bigtable (when FLAGS.bigtable_instance is set) or from
    imagenet_input.ImageNetInput.

    Modes:
      * 'eval': evaluate every checkpoint yielded for FLAGS.model_dir, archive
        it via utils.archive_ckpt using 'top_1_accuracy', stop once a
        checkpoint at or past FLAGS.train_steps was evaluated, then export if
        FLAGS.export_dir is set.
      * 'train': train up to FLAGS.train_steps (with an async checkpoint hook
        when FLAGS.use_async_checkpointing is set). No export happens here.
      * 'train_and_eval': alternate training chunks of FLAGS.steps_per_eval
        steps with evaluation, then export if FLAGS.export_dir is set.
        Checkpoint archiving is disabled in this mode.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If FLAGS.input_image_size is unset and FLAGS.model_name does
        not start with 'efficientnet'.
    """

    # An explicit flag wins; otherwise the size is taken from the EfficientNet
    # parameters of the requested model.
    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        if FLAGS.model_name.startswith('efficientnet'):
            _, _, input_image_size, _ = efficientnet_builder.efficientnet_params(
                FLAGS.model_name)
        else:
            raise ValueError(
                'input_image_size must be set except for EfficientNet.')

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

    # With async checkpointing the RunConfig-driven saver is switched off; the
    # saving is done by the AsyncCheckpointSaverHook added in 'train' mode.
    if FLAGS.use_async_checkpointing:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig
            .PER_HOST_V2))  # pylint: disable=line-too-long
    # Initializes model parameters.
    params = dict(steps_per_epoch=FLAGS.num_train_images /
                  FLAGS.train_batch_size,
                  use_bfloat16=FLAGS.use_bfloat16)
    # NOTE(review): the warm-start checkpoint path is hard-coded relative to
    # the working directory; presumably it has to be adjusted per setup.
    est = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params,
        warm_start_from='./weights_efficientnet-b0/model.ckpt')  #update
    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=FLAGS.use_bfloat16,
                transpose_input=FLAGS.transpose_input,
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=FLAGS.transpose_input,
                cache=FLAGS.use_cache and is_training,
                image_size=input_image_size,
                num_parallel_calls=FLAGS.num_parallel_calls,
                use_bfloat16=FLAGS.use_bfloat16)
            for is_training in [True, False]
        ]

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached. The
                # step is the second '-'-separated field of the basename.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(est, FLAGS.export_dir, input_image_size)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(input_fn=imagenet_train.input_fn,
                      max_steps=FLAGS.train_steps,
                      hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=imagenet_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = est.evaluate(input_fn=imagenet_eval.input_fn,
                                            steps=FLAGS.num_eval_images //
                                            FLAGS.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                # Archiving of the best checkpoint is disabled in this mode:
                # ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                # utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(est, FLAGS.export_dir, input_image_size)
示例#12
0
def main(unused_argv):
    """Train and/or evaluate a tf.estimator.Estimator as selected by FLAGS.mode.

    Input pipelines are built with mnist_input.ImageNetInput for both training
    and evaluation.

    Modes:
      * 'eval': evaluate every checkpoint yielded for FLAGS.model_dir, archive
        it via utils.archive_ckpt using 'top_1_accuracy', stop once a
        checkpoint at or past FLAGS.train_steps was evaluated, then export if
        FLAGS.export_dir is set.
      * 'train': train up to FLAGS.train_steps (with an async checkpoint hook
        when FLAGS.use_async_checkpointing is set). No export happens here.
      * 'train_and_eval': alternate training chunks of FLAGS.steps_per_eval
        steps with evaluation and archiving, then export if FLAGS.export_dir
        is set.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If FLAGS.input_image_size is unset and FLAGS.model_name does
        not start with 'efficientnet'.
    """

    # An explicit flag wins; otherwise the size is taken from the EfficientNet
    # parameters of the requested model.
    input_image_size = FLAGS.input_image_size
    if not input_image_size:
        if FLAGS.model_name.startswith('efficientnet'):
            _, _, input_image_size, _ = efficientnet_builder.efficientnet_params(
                FLAGS.model_name)
        else:
            raise ValueError(
                'input_image_size must be set except for EfficientNet.')

    # Checkpoint at least every 100 steps' worth, or once per eval interval.
    save_checkpoints_steps = max(100, FLAGS.steps_per_eval)

    config = tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
    )

    # Parameters handed to model_fn.
    params = dict(steps_per_epoch=FLAGS.num_train_images /
                  FLAGS.train_batch_size,
                  use_bfloat16=FLAGS.use_bfloat16,
                  batch_size=FLAGS.train_batch_size)
    est = tf.estimator.Estimator(model_fn=model_fn,
                                 config=config,
                                 params=params)

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.

    if FLAGS.data_dir == FAKE_DATA_DIR:
        tf.logging.info('Using fake dataset.')
    else:
        tf.logging.info('Using dataset: %s', FLAGS.data_dir)
    # Caching is only requested for the training pipeline.
    data_train, data_eval = [
        mnist_input.ImageNetInput(is_training=is_training,
                                  data_dir=FLAGS.data_dir,
                                  transpose_input=FLAGS.transpose_input,
                                  cache=FLAGS.use_cache and is_training,
                                  image_size=input_image_size,
                                  use_bfloat16=FLAGS.use_bfloat16)
        for is_training in [True, False]
    ]

    if FLAGS.mode == 'eval':
        eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = est.evaluate(input_fn=data_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

                # Terminate eval job when final checkpoint is reached. The
                # step is the second '-'-separated field of the basename.
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= FLAGS.train_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

        if FLAGS.export_dir:
            export(est, FLAGS.export_dir, input_image_size)
    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', FLAGS.train_steps,
            FLAGS.train_steps / params['steps_per_epoch'], current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if FLAGS.use_async_checkpointing:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(100, FLAGS.iterations_per_loop)))
            est.train(input_fn=data_train.input_fn,
                      max_steps=FLAGS.train_steps,
                      hooks=hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < FLAGS.train_steps:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      FLAGS.train_steps)
                est.train(input_fn=data_train.input_fn,
                          max_steps=next_checkpoint)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = est.evaluate(input_fn=data_eval.input_fn,
                                            steps=FLAGS.num_eval_images //
                                            FLAGS.eval_batch_size)
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                utils.archive_ckpt(eval_results,
                                   eval_results['top_1_accuracy'], ckpt)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                FLAGS.train_steps, elapsed_time)
            if FLAGS.export_dir:
                export(est, FLAGS.export_dir, input_image_size)
示例#13
0
def main(unused_argv):
    """Train or evaluate a TPUEstimator model configured from hparams and flags.

    The configuration is a deep copy of hparams.base_config overridden, in
    order, by the model config for FLAGS.model_name, the dataset config for
    FLAGS.dataset_cfg, FLAGS.hparam_str and FLAGS.sweeps.

    Modes:
      * 'eval': evaluate every checkpoint yielded for FLAGS.model_dir,
        optionally archive it (FLAGS.archive_ckpt) using 'eval/acc_top1', and
        stop when the checkpoint name carries no parsable step or the step
        reaches the total number of training steps.
      * 'train': save the config to '<model_dir>/config.yaml' and train, either
        in a single run or, when config.train.stages is set, in progressive
        stages with growing image size and (if config.train.sched)
        progressively stronger ram/mixup/cutmix settings.

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If FLAGS.mode is neither 'eval' nor 'train'.
    """
    config = copy.deepcopy(hparams.base_config)
    config.override(effnetv2_configs.get_model_config(FLAGS.model_name))
    config.override(datasets.get_dataset_config(FLAGS.dataset_cfg))
    config.override(FLAGS.hparam_str)
    config.override(FLAGS.sweeps)

    train_size = config.train.isize
    eval_size = config.eval.isize
    # A train size of at most 16 is treated as a multiplier of the eval size;
    # the product is rounded down to a multiple of 16.
    if train_size <= 16.:
        train_size = int(eval_size * train_size) // 16 * 16
    input_image_size = eval_size if FLAGS.mode == 'eval' else train_size

    if FLAGS.mode == 'train':
        if not tf.io.gfile.exists(FLAGS.model_dir):
            tf.io.gfile.makedirs(FLAGS.model_dir)
        config.save_to_yaml(os.path.join(FLAGS.model_dir, 'config.yaml'))

    train_split = config.train.split or 'train'
    eval_split = config.eval.split or 'eval'
    num_train_images = config.data.splits[train_split].num_images
    num_eval_images = config.data.splits[eval_split].num_images

    if FLAGS.tpu or FLAGS.use_tpu:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    run_config = tf.estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=max(100, config.runtime.iterations_per_loop),
        keep_checkpoint_max=config.runtime.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=(
            config.runtime.keep_checkpoint_every_n_hours),
        log_step_count_steps=config.runtime.log_step_count_steps,
        session_config=tf.ConfigProto(isolate_session_state=True,
                                      log_device_placement=False),
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=config.runtime.iterations_per_loop,
            tpu_job_name=FLAGS.tpu_job_name,
            per_host_input_for_training=(
                tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2)))
    # Initializes model parameters.
    params = dict(steps_per_epoch=num_train_images / config.train.batch_size,
                  image_size=input_image_size,
                  config=config)

    est = tf.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=config.train.batch_size,
        eval_batch_size=config.eval.batch_size,
        export_to_tpu=FLAGS.export_to_tpu,
        params=params)

    # With mixed precision the input dtype depends on the device type.
    image_dtype = None
    if config.runtime.mixed_precision:
        image_dtype = 'bfloat16' if FLAGS.use_tpu else 'float16'

    # Total steps: the configured epochs, but never fewer than min_steps.
    train_steps = max(
        config.train.min_steps,
        config.train.epochs * num_train_images // config.train.batch_size)
    dataset_eval = datasets.build_dataset_input(False, input_image_size,
                                                image_dtype, FLAGS.data_dir,
                                                eval_split, config.data)

    if FLAGS.mode == 'eval':
        eval_steps = num_eval_images // config.eval.batch_size
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(FLAGS.model_dir,
                                                  timeout=60 * 60 * 24):
            logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time(
                )  # This time will include compilation time
                eval_results = est.evaluate(input_fn=dataset_eval.input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=ckpt,
                                            name=FLAGS.eval_name)
                elapsed_time = int(time.time() - start_timestamp)
                logging.info('Eval results: %s. Elapsed seconds: %d',
                             eval_results, elapsed_time)
                if FLAGS.archive_ckpt:
                    utils.archive_ckpt(eval_results,
                                       eval_results['eval/acc_top1'], ckpt)

                # Terminate eval job when final checkpoint is reached. A name
                # without a '-' field (IndexError) or with a non-numeric one
                # (ValueError) carries no usable step, so evaluation stops.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except (IndexError, ValueError):
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                logging.info('Finished step: %d, total %d', current_step,
                             train_steps)
                if current_step >= train_steps:
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info('Checkpoint %s no longer exists, skip saving.',
                             ckpt)
    else:  # FLAGS.mode == 'train'
        # Best-effort read of the global step for the log line below; any
        # failure (e.g. no checkpoint yet) falls back to step 0. Only
        # Exception is caught so Ctrl-C / SystemExit still propagate.
        try:
            checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(
                tf.train.latest_checkpoint(FLAGS.model_dir))
            current_step = checkpoint_reader.get_tensor(
                tf.compat.v1.GraphKeys.GLOBAL_STEP)
        except Exception:  # pylint: disable=broad-except
            current_step = 0

        logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', train_steps, config.train.epochs, current_step)

        start_timestamp = time.time(
        )  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []  # add hooks if needed.
            if not config.train.stages:
                # Single-stage training at the full training image size.
                dataset_train = datasets.build_dataset_input(
                    True, input_image_size, image_dtype, FLAGS.data_dir,
                    train_split, config.data)
                est.train(input_fn=dataset_train.input_fn,
                          max_steps=train_steps,
                          hooks=hooks)
            else:
                # Step of the latest checkpoint, or 0 when there is none
                # (TypeError on a None path) or its name has no parsable step.
                curr_step = 0
                try:
                    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                    curr_step = int(os.path.basename(ckpt).split('-')[1])
                except (IndexError, TypeError, ValueError):
                    logging.info('%s has no ckpt with valid step.',
                                 FLAGS.model_dir)

                total_stages = config.train.stages
                if config.train.sched:
                    # Per-stage schedules, linearly spaced up to the
                    # configured final values.
                    # NOTE: the dropout rate is not scheduled per stage.
                    ram_list = np.linspace(5, config.data.ram, total_stages)
                    mixup_list = np.linspace(0, config.data.mixup_alpha,
                                             total_stages)
                    cutmix_list = np.linspace(0, config.data.cutmix_alpha,
                                              total_stages)

                ibase = config.data.ibase or (input_image_size / 2)
                # NOTE(review): curr_step // train_steps is 0 for any
                # checkpoint before train_steps, so a resumed run re-enters at
                # stage 0; presumably est.train returns quickly for stages
                # whose max_steps was already reached - confirm.
                for stage in range(curr_step // train_steps, total_stages):
                    tf.compat.v1.reset_default_graph()
                    # Each stage trains up to its proportional share of the
                    # total steps, with the image size interpolated linearly
                    # from ibase to the full input size.
                    ratio = float(stage + 1) / float(total_stages)
                    max_steps = int(ratio * train_steps)
                    image_size = int(ibase +
                                     (input_image_size - ibase) * ratio)
                    params['image_size'] = image_size

                    if config.train.sched:
                        config.data.ram = ram_list[stage]
                        config.data.mixup_alpha = mixup_list[stage]
                        config.data.cutmix_alpha = cutmix_list[stage]

                    ds_lab_cls = datasets.build_dataset_input(
                        True, image_size, image_dtype, FLAGS.data_dir,
                        train_split, config.data)

                    # A fresh estimator is built per stage so it picks up the
                    # updated params.
                    est = tf.estimator.tpu.TPUEstimator(
                        use_tpu=FLAGS.use_tpu,
                        model_fn=model_fn,
                        config=run_config,
                        train_batch_size=config.train.batch_size,
                        eval_batch_size=config.eval.batch_size,
                        export_to_tpu=FLAGS.export_to_tpu,
                        params=params)
                    est.train(input_fn=ds_lab_cls.input_fn,
                              max_steps=max_steps,
                              hooks=hooks)
        else:
            raise ValueError('Unknown mode %s' % FLAGS.mode)
示例#14
0
def main(argv):
    """Train and/or evaluate a detection model according to ``FLAGS.mode``.

    The function, in order:
      * drops the default backbone checkpoint when ``--ckpt`` is given,
      * prepares the Horovod environment and loads the Habana module,
      * validates the data-related flags for the selected mode,
      * builds the hparams config and the (optional) spatial partitioning,
      * builds estimator params / run config and writes the hparams,
      * dispatches on ``FLAGS.mode`` ('train', 'eval' or 'train_and_eval').

    An unrecognized mode is only logged; no exception is raised for it.

    Args:
        argv: Positional command-line arguments; unused.

    Raises:
        RuntimeError: If a file pattern / annotation flag required by the
            selected mode is missing, or if ``--num_cores_per_replica`` does
            not equal the product of ``--input_partition_dims`` when spatial
            partitioning is enabled.
    """
    del argv  # Unused.

    # if given an efficentdet ckpt don't use default backbone ckpt
    if FLAGS.backbone_ckpt == BACKBONE_CKPT_DEFAULT_DIR and FLAGS.ckpt is not None:
        print("Using ckpt flag: {}, ignoring default backbone_ckpt: {}".format(
            FLAGS.ckpt, FLAGS.backbone_ckpt))
        FLAGS.backbone_ckpt = None

    if FLAGS.use_horovod is not None:
        if FLAGS.dump_all_ranks:
            # NOTE(review): hvd.rank() is queried before hvd_init() below;
            # confirm the hvd wrapper supports that ordering.
            FLAGS.model_dir += "/worker_" + str(hvd.rank())
        # Only provide defaults; values already exported by the user win.
        os.environ.setdefault('HOROVOD_CYCLE_TIME', '0.5')
        os.environ.setdefault('HABANA_HCCL_COMM_API', '0')
        hvd_init()

    if not FLAGS.no_hpu:
        from habana_frameworks.tensorflow import load_habana_module
        load_habana_module()

        if FLAGS.use_horovod:
            assert (horovod_enabled())

    set_env(use_amp=FLAGS.use_amp)

    # deterministic setting
    if FLAGS.sbs_test or FLAGS.deterministic:
        set_deterministic()

    # Check data path
    if FLAGS.mode in (
            'train', 'train_and_eval') and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    if FLAGS.mode in ('eval', 'train_and_eval'):
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')
        if not FLAGS.val_json_file and not FLAGS.testdev_dir:
            raise RuntimeError(
                'You must specify --val_json_file or --testdev for evaluation.'
            )

    # Parse and override hparams
    config = hparams_config.get_detection_config(FLAGS.model_name)
    config.override(FLAGS.hparams)

    # The following is for spatial partitioning. `features` has one tensor while
    # `labels` had 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`, see the partition logic below.
    # In the TPUEstimator context, the meaning of `shard` and `replica` is the
    # same; following the API, here has mixed use of both.
    if FLAGS.use_spatial_partition:
        # Checks input_partition_dims agrees with num_cores_per_replica.
        if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
            raise RuntimeError(
                '--num_cores_per_replica must be a product of array '
                'elements in --input_partition_dims.')

        labels_partition_dims = {
            'mean_num_positives': None,
            'source_ids': None,
            'groundtruth_data': None,
            'image_scales': None,
        }

        def _can_partition(spatial_dim):
            """Return True if `spatial_dim` is divisible by every partition dim."""
            partitionable_index = np.where(
                spatial_dim % np.array(FLAGS.input_partition_dims) == 0)
            return len(partitionable_index[0]) == len(
                FLAGS.input_partition_dims)

        # The Input Partition Logic: We partition only the partition-able tensors.
        # Spatial partition requires that the to-be-partitioned tensors must have a
        # dimension that is a multiple of `partition_dims`. Depending on the
        # `partition_dims` and the `image_size` and the `max_level` in config, some
        # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
        # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
        # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
        # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
        # case, the level-8 and level-9 target tensors are not partition-able, and
        # the highest partition-able level is 7.
        image_size = config.get('image_size')
        for level in range(config.get('min_level'),
                           config.get('max_level') + 1):
            spatial_dim = image_size // (2**level)
            if _can_partition(spatial_dim):
                labels_partition_dims['box_targets_%d' %
                                      level] = FLAGS.input_partition_dims
                labels_partition_dims['cls_targets_%d' %
                                      level] = FLAGS.input_partition_dims
            else:
                labels_partition_dims['box_targets_%d' % level] = None
                labels_partition_dims['cls_targets_%d' % level] = None
        num_cores_per_replica = FLAGS.num_cores_per_replica
        input_partition_dims = [
            FLAGS.input_partition_dims, labels_partition_dims
        ]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        # One shard per Horovod worker, or a single shard without Horovod.
        if horovod_enabled():
            num_shards = hvd.size()
        else:
            num_shards = 1

    params = build_estimator_params('train', config, num_shards)
    # disabling input data scaling/flip manipulations.
    if FLAGS.sbs_test:
        sbs_params = dict(input_rand_hflip=False,
                          train_scale_min=1,
                          train_scale_max=1,
                          dropout_rate=0.0)
        params.update(sbs_params)

    run_config = build_estimator_config('train', config, num_shards,
                                        num_cores_per_replica,
                                        input_partition_dims)
    write_hparams_v1(FLAGS.model_dir, {
        'batch_size': FLAGS.train_batch_size,
        **FLAGS.flag_values_dict()
    })

    model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)

    # TPU Estimator
    logging.info(params)

    if FLAGS.mode == 'train':
        train_estimator = HorovodEstimator(model_fn=model_fn_instance,
                                           model_dir=FLAGS.model_dir,
                                           config=run_config,
                                           params=params)

        # for deterministic input, we pass to dataloader False for not manipulating input data
        is_training = not FLAGS.deterministic
        use_fake_data = FLAGS.use_fake_data or FLAGS.deterministic

        input_fn = dataloader.InputReader(FLAGS.training_file_pattern,
                                          is_training=is_training,
                                          params=params,
                                          use_fake_data=use_fake_data,
                                          is_deterministic=FLAGS.deterministic)
        max_steps = int((FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                        (FLAGS.train_batch_size * num_shards)) + 1

        # for sbs test, train under sbs callbacks
        if FLAGS.sbs_test:
            from TensorFlow.common.debug import dump_callback
            SBS_TEST_CONFIG = os.path.join(
                os.environ['TF_TESTS_ROOT'],
                "tests/tf_training_tests/side_by_side/topologies/efficientdet/dump_config.json"
            )
            with dump_callback(SBS_TEST_CONFIG):
                train_estimator.train(input_fn=input_fn, max_steps=max_steps)
        else:
            if FLAGS.ckpt is not None:
                # With an explicit --ckpt the step budget is passed as `steps`
                # instead of `max_steps`.
                train_estimator.train(input_fn=input_fn, steps=max_steps)
            else:
                train_estimator.train(input_fn=input_fn, max_steps=max_steps)

    elif FLAGS.mode == 'eval':
        eval_params = build_estimator_params('eval', config, num_shards)
        eval_config = build_estimator_config('eval', config, num_shards,
                                             num_cores_per_replica,
                                             input_partition_dims)

        # Eval only runs on CPU or GPU host with batch_size = 1.
        # Override the default options: disable randomization in the input pipeline
        # and don't run on the TPU.
        # Also, disable use_bfloat16 for eval on CPU/GPU.

        eval_estimator = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn_instance,
            use_tpu=False,
            train_batch_size=FLAGS.train_batch_size,
            eval_batch_size=FLAGS.eval_batch_size,
            config=eval_config,
            params=eval_params)

        def terminate_eval():
            """Timeout callback: log and tell the iterator to stop."""
            logging.info('Terminating eval after %d seconds of no checkpoints',
                         FLAGS.eval_timeout)
            return True

        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(
                FLAGS.model_dir,
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval):

            logging.info('Starting to evaluate.')
            try:
                eval_results = eval_estimator.evaluate(
                    input_fn=dataloader.InputReader(
                        FLAGS.validation_file_pattern, is_training=False),
                    steps=FLAGS.eval_samples // FLAGS.eval_batch_size)
                logging.info('Eval results: %s', eval_results)

                # Terminate eval job when final checkpoint is reached.
                try:
                    current_step = int(os.path.basename(ckpt).split('-')[1])
                except (IndexError, ValueError):
                    # No '-<step>' suffix, or the suffix is not an integer.
                    logging.info('%s has no global step info: stop!', ckpt)
                    break

                write_summary(eval_results, ckpt, current_step)

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
                total_step = int(
                    (FLAGS.num_epochs * FLAGS.num_examples_per_epoch) /
                    FLAGS.train_batch_size)
                if current_step >= total_step:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)

    elif FLAGS.mode == 'train_and_eval':
        train_params = build_estimator_params('train', config, num_shards)
        train_config = build_estimator_config('train', config, num_shards,
                                              num_cores_per_replica,
                                              input_partition_dims)
        train_estimator = HorovodEstimator(model_fn=model_fn_instance,
                                           model_dir=FLAGS.model_dir,
                                           config=train_config,
                                           params=train_params)

        # Built lazily on the first evaluation and reused afterwards.
        eval_estimator = None

        for cycle in range(FLAGS.num_epochs):
            logging.info('Starting training cycle, epoch: %d.', cycle)

            train_estimator.train(
                input_fn=dataloader.InputReader(
                    FLAGS.training_file_pattern,
                    is_training=True,
                    use_fake_data=FLAGS.use_fake_data),
                max_steps=(cycle + 1) *
                int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))

            # synchronization point for all ranks
            if horovod_enabled():
                hvd.allreduce(tf.constant(0))

            logging.info('Starting evaluation cycle, epoch: %d.', cycle)
            # Run evaluation after every epoch.

            if eval_estimator is None:
                eval_params = build_estimator_params('eval', config,
                                                     num_shards)
                eval_config = build_estimator_config('eval', config,
                                                     num_shards,
                                                     num_cores_per_replica,
                                                     input_partition_dims)
                eval_estimator = tf.estimator.tpu.TPUEstimator(
                    model_fn=model_fn_instance,
                    use_tpu=False,
                    train_batch_size=FLAGS.train_batch_size,
                    eval_batch_size=FLAGS.eval_batch_size,
                    config=eval_config,
                    params=eval_params)

            if is_rank0():
                eval_results = eval_estimator.evaluate(
                    input_fn=dataloader.InputReader(
                        FLAGS.validation_file_pattern, is_training=False),
                    steps=FLAGS.eval_samples // FLAGS.eval_batch_size)

                # Resolve the latest checkpoint once, straight from the flag
                # value: routing it through pathlib would collapse the '//'
                # of URL-style directories (e.g. 'gs://bucket/dir').
                ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
                current_step = int(os.path.basename(ckpt).split('-')[1])
                write_summary(eval_results, FLAGS.model_dir, current_step)
                logging.info('Evaluation results: %s', eval_results)

                utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

    else:
        logging.info('Mode not found.')
示例#15
0
def main(argv):
  """Train and/or evaluate a detection model according to ``FLAGS.mode``.

  The function, in order:
    * rejects unknown positional arguments,
    * resolves the TPU cluster when ``--use_tpu`` is set,
    * validates the file-pattern flags for the selected mode,
    * builds the hparams config and the (optional) spatial partitioning,
    * builds the estimator params, session config and run config,
    * dispatches on ``FLAGS.mode`` ('train', 'eval' or 'train_and_eval').

  An unrecognized mode is only logged; no exception is raised for it.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    ValueError: If extra positional arguments are present.
    RuntimeError: If a file pattern required by the selected mode is missing,
      or if ``--num_cores_per_replica`` does not equal the product of
      ``--input_partition_dims`` when spatial partitioning is enabled.
  """
  assert len(argv) >= 1
  if len(argv) > 1:  # Do not accept unknown args.
    raise ValueError('Received unknown arguments: {}'.format(argv[1:]))

  if FLAGS.use_tpu:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu,
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)
    tpu_grpc_url = tpu_cluster_resolver.get_master()
    tf.Session.reset(tpu_grpc_url)
  else:
    tpu_cluster_resolver = None

  # Check data path
  if FLAGS.mode in ('train',
                    'train_and_eval') and FLAGS.training_file_pattern is None:
    raise RuntimeError('You must specify --training_file_pattern for training.')
  if FLAGS.mode in ('eval', 'train_and_eval'):
    if FLAGS.validation_file_pattern is None:
      raise RuntimeError('You must specify --validation_file_pattern '
                         'for evaluation.')

  # Parse and override hparams
  config = hparams_config.get_detection_config(FLAGS.model_name)
  config.override(FLAGS.hparams)
  if FLAGS.num_epochs:  # NOTE: remove this flag after updating all docs.
    config.num_epochs = FLAGS.num_epochs

  # Parse image size in case it is in string format.
  config.image_size = utils.parse_image_size(config.image_size)

  # The following is for spatial partitioning. `features` has one tensor while
  # `labels` had 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
  # partition is performed on `features` and all partitionable tensors of
  # `labels`, see the partition logic below.
  # In the TPUEstimator context, the meaning of `shard` and `replica` is the
  # same; following the API, here has mixed use of both.
  if FLAGS.use_spatial_partition:
    # Checks input_partition_dims agrees with num_cores_per_replica.
    if FLAGS.num_cores_per_replica != np.prod(FLAGS.input_partition_dims):
      raise RuntimeError('--num_cores_per_replica must be a product of array '
                         'elements in --input_partition_dims.')

    labels_partition_dims = {
        'mean_num_positives': None,
        'source_ids': None,
        'groundtruth_data': None,
        'image_scales': None,
    }

    def _can_partition(spatial_dim):
      """Return True if `spatial_dim` is divisible by every partition dim."""
      partitionable_index = np.where(
          spatial_dim % np.array(FLAGS.input_partition_dims) == 0)
      return len(partitionable_index[0]) == len(FLAGS.input_partition_dims)

    # The Input Partition Logic: We partition only the partition-able tensors.
    # Spatial partition requires that the to-be-partitioned tensors must have a
    # dimension that is a multiple of `partition_dims`. Depending on the
    # `partition_dims` and the `image_size` and the `max_level` in config, some
    # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
    # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
    # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
    # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
    # case, the level-8 and level-9 target tensors are not partition-able, and
    # the highest partition-able level is 7.
    feat_sizes = utils.get_feat_sizes(
        config.get('image_size'), config.get('max_level'))
    for level in range(config.get('min_level'), config.get('max_level') + 1):
      spatial_dim = feat_sizes[level]
      if _can_partition(spatial_dim['height']) and _can_partition(
          spatial_dim['width']):
        labels_partition_dims['box_targets_%d' %
                              level] = FLAGS.input_partition_dims
        labels_partition_dims['cls_targets_%d' %
                              level] = FLAGS.input_partition_dims
      else:
        labels_partition_dims['box_targets_%d' % level] = None
        labels_partition_dims['cls_targets_%d' % level] = None
    num_cores_per_replica = FLAGS.num_cores_per_replica
    input_partition_dims = [
        FLAGS.input_partition_dims, labels_partition_dims]
    num_shards = FLAGS.num_cores // num_cores_per_replica
  else:
    num_cores_per_replica = None
    input_partition_dims = None
    num_shards = FLAGS.num_cores

  # Estimator params: the hparams config plus the run-specific flag values.
  params = dict(
      config.as_dict(),
      model_name=FLAGS.model_name,

      iterations_per_loop=FLAGS.iterations_per_loop,
      model_dir=FLAGS.model_dir,
      num_shards=num_shards,
      num_examples_per_epoch=FLAGS.num_examples_per_epoch,
      use_tpu=FLAGS.use_tpu,
      backbone_ckpt=FLAGS.backbone_ckpt,
      ckpt=FLAGS.ckpt,
      val_json_file=FLAGS.val_json_file,
      testdev_dir=FLAGS.testdev_dir,
      mode=FLAGS.mode,
  )
  config_proto = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=False)
  # XLA JIT is only requested explicitly when not running on TPU.
  if FLAGS.use_xla and not FLAGS.use_tpu:
    config_proto.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)

  tpu_config = tf.estimator.tpu.TPUConfig(
      FLAGS.iterations_per_loop,
      num_shards=num_shards,
      num_cores_per_replica=num_cores_per_replica,
      input_partition_dims=input_partition_dims,
      per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
      .PER_HOST_V2)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      evaluation_master=FLAGS.eval_master,
      model_dir=FLAGS.model_dir,
      log_step_count_steps=FLAGS.iterations_per_loop,
      session_config=config_proto,
      tpu_config=tpu_config,
      tf_random_seed=FLAGS.tf_random_seed,
  )

  model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)

  # TPU Estimator
  logging.info(params)
  if FLAGS.mode == 'train':
    train_estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn_instance,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        config=run_config,
        params=params)
    train_estimator.train(
        input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                        is_training=True,
                                        use_fake_data=FLAGS.use_fake_data),
        max_steps=int((config.num_epochs * FLAGS.num_examples_per_epoch) /
                      FLAGS.train_batch_size))

    if FLAGS.eval_after_training:
      # Run evaluation after training finishes.
      eval_params = dict(
          params,
          use_tpu=FLAGS.use_tpu,
          input_rand_hflip=False,
          is_training_bn=False,
          precision=None,
      )
      eval_estimator = tf.estimator.tpu.TPUEstimator(
          model_fn=model_fn_instance,
          use_tpu=FLAGS.use_tpu,
          train_batch_size=FLAGS.train_batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          config=run_config,
          params=eval_params)
      eval_results = eval_estimator.evaluate(
          input_fn=dataloader.InputReader(FLAGS.validation_file_pattern,
                                          is_training=False),
          steps=FLAGS.eval_samples//FLAGS.eval_batch_size)
      logging.info('Eval results: %s', eval_results)
      ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
      utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)

  elif FLAGS.mode == 'eval':
    # Eval only runs on CPU or GPU host with batch_size = 1.
    # Override the default options: disable randomization in the input pipeline
    # and don't run on the TPU.
    eval_params = dict(
        params,
        use_tpu=FLAGS.use_tpu,
        input_rand_hflip=False,
        is_training_bn=False,
        precision=None,
    )

    eval_estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn_instance,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        config=run_config,
        params=eval_params)

    def terminate_eval():
      """Timeout callback: log and tell the iterator to stop."""
      logging.info('Terminating eval after %d seconds of no checkpoints',
                   FLAGS.eval_timeout)
      return True

    # Run evaluation when there's a new checkpoint
    for ckpt in tf.train.checkpoints_iterator(
        FLAGS.model_dir,
        min_interval_secs=FLAGS.min_eval_interval,
        timeout=FLAGS.eval_timeout,
        timeout_fn=terminate_eval):

      logging.info('Starting to evaluate.')
      try:
        eval_results = eval_estimator.evaluate(
            input_fn=dataloader.InputReader(FLAGS.validation_file_pattern,
                                            is_training=False),
            steps=FLAGS.eval_samples//FLAGS.eval_batch_size)
        logging.info('Eval results: %s', eval_results)

        # Terminate eval job when final checkpoint is reached.
        try:
          current_step = int(os.path.basename(ckpt).split('-')[1])
        except (IndexError, ValueError):
          # No '-<step>' suffix, or the suffix is not an integer.
          logging.info('%s has no global step info: stop!', ckpt)
          break

        utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
        total_step = int((config.num_epochs * FLAGS.num_examples_per_epoch) /
                         FLAGS.train_batch_size)
        if current_step >= total_step:
          logging.info('Evaluation finished after training step %d',
                       current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                     ckpt)

  elif FLAGS.mode == 'train_and_eval':
    # When resuming, global_step already holds the number of steps taken, so
    # the cycle index has no bearing on it (each cycle trains a fixed number
    # of additional `steps`).
    for cycle in range(config.num_epochs):
      logging.info('Starting training cycle, epoch: %d.', cycle)
      train_estimator = tf.estimator.tpu.TPUEstimator(
          model_fn=model_fn_instance,
          use_tpu=FLAGS.use_tpu,
          train_batch_size=FLAGS.train_batch_size,
          config=run_config,
          params=params)
      train_estimator.train(
          input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                          is_training=True,
                                          use_fake_data=FLAGS.use_fake_data),
          steps=int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))

      logging.info('Starting evaluation cycle, epoch: %d.', cycle)
      # Run evaluation after every epoch.
      eval_params = dict(
          params,
          use_tpu=FLAGS.use_tpu,
          input_rand_hflip=False,
          is_training_bn=False,
      )

      eval_estimator = tf.estimator.tpu.TPUEstimator(
          model_fn=model_fn_instance,
          use_tpu=FLAGS.use_tpu,
          train_batch_size=FLAGS.train_batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          config=run_config,
          params=eval_params)
      eval_results = eval_estimator.evaluate(
          input_fn=dataloader.InputReader(FLAGS.validation_file_pattern,
                                          is_training=False),
          steps=FLAGS.eval_samples//FLAGS.eval_batch_size)
      logging.info('Evaluation results: %s', eval_results)
      ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)

      # Bracket the archive step with start/end markers, both in the regular
      # log and in a side log file.
      logging.info('save_checkpoint_start for epoch: %s', cycle)
      now = datetime.now().strftime('%Y%m%d%H%M%S')
      with open('/tmp/main_inner.log', 'a') as f:
        f.write(f'{now}: save_checkpoint_start {cycle}\n')
      utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
      # Take a fresh timestamp so the end marker records when archiving
      # actually finished rather than repeating the start time.
      now = datetime.now().strftime('%Y%m%d%H%M%S')
      with open('/tmp/main_inner.log', 'a') as f:
        f.write(f'{now}: save_checkpoint_end {cycle}\n')
      logging.info('save_checkpoint_end for epoch: %s', cycle)

  else:
    logging.info('Mode not found.')