Example #1
def run(callbacks=None):
  # keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  params = config_factory.config_generator(FLAGS.model)

  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
      },
      is_strict=False)
  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: {}'.format(params_str))

  train_input_fn = None
  eval_input_fn = None
  training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)
  # estimator_run(params, train_input_fn)
  # Pass eval_input_fn through as well, so it is not built and then dropped.
  return run_executor(
      params,
      mode=ModeKeys.TRAIN,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)
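
For context, a minimal sketch of how a run() function like the one above is typically wired up as an absl entry point. The flag definitions below are illustrative assumptions; only app.run and the flags module come from absl itself:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('model', 'retinanet', 'Model configuration to generate.')
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and logs.')

def main(argv):
  del argv  # Unused.
  run()  # the run() defined in Example #1

if __name__ == '__main__':
  app.run(main)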
Example #2
def main(argv):
    del argv  # Unused.

    params = params_dict.ParamsDict(retinanet_config.RETINANET_CFG,
                                    retinanet_config.RETINANET_RESTRICTIONS)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_overrides,
                                              is_strict=True)
    params.override(
        {
            'platform': {
                'eval_master': FLAGS.eval_master,
                'tpu': FLAGS.tpu,
                'tpu_zone': FLAGS.tpu_zone,
                'gcp_project': FLAGS.gcp_project,
            },
            'use_tpu': FLAGS.use_tpu,
            'model_dir': FLAGS.model_dir,
            'train': {
                'num_shards': FLAGS.num_cores,
            },
        },
        is_strict=False)
    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    tf.logging.info('Model Parameters: {}'.format(params_str))

    # Builds detection model on TPUs.
    model_fn = model_builder.ModelFn(params)
    executor = tpu_executor.TpuExecutor(model_fn, params)

    # Prepares input functions for train and eval.
    train_input_fn = input_reader.InputFn(params.train.train_file_pattern,
                                          params,
                                          mode=ModeKeys.TRAIN)
    eval_input_fn = input_reader.InputFn(params.eval.eval_file_pattern,
                                         params,
                                         mode=ModeKeys.PREDICT_WITH_GT)

    # Runs the model.
    if FLAGS.mode == 'train':
        save_config(params, params.model_dir)
        executor.train(train_input_fn, params.train.total_steps)
        if FLAGS.eval_after_training:
            executor.evaluate(
                eval_input_fn,
                params.eval.eval_samples // params.predict.predict_batch_size,
                params.train.total_steps)

    elif FLAGS.mode == 'eval':

        def terminate_eval():
            tf.logging.info(
                'Terminating eval after %d seconds of no checkpoints' %
                params.eval.eval_timeout)
            return True

        # Runs evaluation when there's a new checkpoint.
        for ckpt in tf.contrib.training.checkpoints_iterator(
                params.model_dir,
                min_interval_secs=params.eval.min_eval_interval,
                timeout=params.eval.eval_timeout,
                timeout_fn=terminate_eval):
            # Terminates eval job when final checkpoint is reached.
            current_step = int(os.path.basename(ckpt).split('-')[1])

            tf.logging.info('Starting to evaluate.')
            try:
                executor.evaluate(
                    eval_input_fn, params.eval.eval_samples //
                    params.predict.predict_batch_size, current_step)

                if current_step >= params.train.total_steps:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break
            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)

    elif FLAGS.mode == 'train_and_eval':
        save_config(params, params.model_dir)
        num_cycles = int(params.train.total_steps /
                         params.eval.num_steps_per_eval)
        for cycle in range(num_cycles):
            tf.logging.info('Start training cycle %d.' % cycle)
            current_step = (cycle + 1) * params.eval.num_steps_per_eval
            executor.train(train_input_fn, params.eval.num_steps_per_eval)
            executor.evaluate(
                eval_input_fn,
                params.eval.eval_samples // params.predict.predict_batch_size,
                current_step)
    else:
        tf.logging.info('Mode not found.')
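
The eval loops in these examples recover the global step from the checkpoint filename via os.path.basename(ckpt).split('-')[1], which assumes the standard model.ckpt-STEP naming. A slightly more defensive version of the same parsing (the helper name is hypothetical):

import os
import re

def step_from_checkpoint(ckpt_path):
  # Expects paths like '/model_dir/model.ckpt-12345' and returns 12345.
  match = re.search(r'-(\d+)$', os.path.basename(ckpt_path))
  if match is None:
    raise ValueError('No step suffix in checkpoint path: %s' % ckpt_path)
  return int(match.group(1))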
Example #3
def main(argv):
  del argv  # Unused.

  params = factory.config_generator(FLAGS.model)

  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  if not FLAGS.use_tpu:
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'batch_norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)
  params.override({
      'platform': {
          'eval_master': FLAGS.eval_master,
          'tpu': FLAGS.tpu,
          'tpu_zone': FLAGS.tpu_zone,
          'gcp_project': FLAGS.gcp_project,
      },
      'tpu_job_name': FLAGS.tpu_job_name,
      'use_tpu': FLAGS.use_tpu,
      'model_dir': FLAGS.model_dir,
      'train': {
          'num_shards': FLAGS.num_cores,
      },
  }, is_strict=False)
  # Only run spatial partitioning in training mode.
  if FLAGS.mode != 'train':
    params.train.input_partition_dims = None
    params.train.num_cores_per_replica = None

  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: %s', params_str)

  # Builds detection model on TPUs.
  model_fn = model_builder.ModelFn(params)
  executor = tpu_executor.TpuExecutor(model_fn, params)

  # Prepares input functions for train and eval.
  train_input_fn = input_reader.InputFn(
      params.train.train_file_pattern, params, mode=ModeKeys.TRAIN,
      dataset_type=params.train.train_dataset_type)
  if params.eval.type == 'customized':
    eval_input_fn = input_reader.InputFn(
        params.eval.eval_file_pattern, params, mode=ModeKeys.EVAL,
        dataset_type=params.eval.eval_dataset_type)
  else:
    eval_input_fn = input_reader.InputFn(
        params.eval.eval_file_pattern, params, mode=ModeKeys.PREDICT_WITH_GT,
        dataset_type=params.eval.eval_dataset_type)

  # Runs the model.
  if FLAGS.mode == 'train':
    config_utils.save_config(params, params.model_dir)
    executor.train(train_input_fn, params.train.total_steps)
    if FLAGS.eval_after_training:
      executor.evaluate(
          eval_input_fn,
          params.eval.eval_samples // params.eval.eval_batch_size)

  elif FLAGS.mode == 'eval':
    def terminate_eval():
      logging.info('Terminating eval after %d seconds of no checkpoints',
                   params.eval.eval_timeout)
      return True
    # Runs evaluation when there's a new checkpoint.
    for ckpt in tf.train.checkpoints_iterator(
        params.model_dir,
        min_interval_secs=params.eval.min_eval_interval,
        timeout=params.eval.eval_timeout,
        timeout_fn=terminate_eval):
      # Terminates eval job when final checkpoint is reached.
      current_step = int(os.path.basename(ckpt).split('-')[1])

      logging.info('Starting to evaluate.')
      try:
        executor.evaluate(
            eval_input_fn,
            params.eval.eval_samples // params.eval.eval_batch_size, ckpt)

        if current_step >= params.train.total_steps:
          logging.info('Evaluation finished after training step %d',
                       current_step)
          break
      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        logging.info('Checkpoint %s no longer exists, skipping checkpoint',
                     ckpt)

  elif FLAGS.mode == 'train_and_eval':
    config_utils.save_config(params, params.model_dir)
    num_cycles = int(params.train.total_steps / params.eval.num_steps_per_eval)
    for cycle in range(num_cycles):
      logging.info('Start training cycle %d.', cycle)
      current_cycle_last_train_step = ((cycle + 1)
                                       * params.eval.num_steps_per_eval)
      executor.train(train_input_fn, current_cycle_last_train_step)
      executor.evaluate(
          eval_input_fn,
          params.eval.eval_samples // params.eval.eval_batch_size)

  elif FLAGS.mode == 'predict':
    file_pattern = FLAGS.predict_file_pattern
    if not file_pattern:
      raise ValueError('"predict_file_pattern" parameter is required.')

    output_dir = FLAGS.predict_output_dir
    if not output_dir:
      raise ValueError('"predict_output_dir" parameter is required.')

    test_input_fn = input_reader.InputFn(
        file_pattern, params, mode=ModeKeys.PREDICT_WITH_GT,
        dataset_type=params.eval.eval_dataset_type)

    checkpoint_prefix = 'model.ckpt-' + FLAGS.predict_checkpoint_step
    checkpoint_path = os.path.join(FLAGS.model_dir, checkpoint_prefix)
    if not tf.train.checkpoint_exists(checkpoint_path):
      checkpoint_path = os.path.join(
          FLAGS.model_dir, 'best_checkpoints', checkpoint_prefix)
      if not tf.train.checkpoint_exists(checkpoint_path):
        raise ValueError('Checkpoint not found: %s/%s' %
                         (FLAGS.model_dir, checkpoint_prefix))

    executor.predict(test_input_fn, checkpoint_path, output_dir=output_dir)

  else:
    logging.info('Mode not found.')
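
The predict branch above looks for the requested checkpoint in model_dir and falls back to a best_checkpoints subdirectory. Here is the same lookup as a standalone sketch; the helper name is hypothetical, and tf.train.checkpoint_exists is the TF1-era API the example itself uses:

import os
import tensorflow as tf

def resolve_checkpoint(model_dir, step):
  # Try model_dir first, then its 'best_checkpoints' subdirectory.
  prefix = 'model.ckpt-%s' % step
  for subdir in ('', 'best_checkpoints'):
    path = os.path.join(model_dir, subdir, prefix)
    if tf.train.checkpoint_exists(path):
      return path
  raise ValueError('Checkpoint not found: %s/%s' % (model_dir, prefix))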
Example #4
def main(argv):
    del argv  # Unused.

    params = factory.config_generator(FLAGS.model)

    if FLAGS.config_file:
        params = params_dict.override_params_dict(params,
                                                  FLAGS.config_file,
                                                  is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override({
        'use_tpu': FLAGS.use_tpu,
        'model_dir': FLAGS.model_dir,
    },
                    is_strict=True)
    if not FLAGS.use_tpu:
        params.override(
            {
                'architecture': {
                    'use_bfloat16': False,
                },
                'batch_norm_activation': {
                    'use_sync_bn': False,
                },
            },
            is_strict=True)
    # Only run spatial partitioning in training mode.
    if FLAGS.mode != 'train':
        params.train.input_partition_dims = None
        params.train.num_cores_per_replica = None
    params_to_save = params_dict.ParamsDict(params)
    params.override(
        {
            'platform': {
                'eval_master': FLAGS.eval_master,
                'tpu': FLAGS.tpu,
                'tpu_zone': FLAGS.tpu_zone,
                'gcp_project': FLAGS.gcp_project,
            },
            'tpu_job_name': FLAGS.tpu_job_name,
            'train': {
                'num_shards': FLAGS.num_cores,
            },
        },
        is_strict=False)

    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    logging.info('Model Parameters: %s', params_str)

    # Builds detection model on TPUs.
    model_fn = model_builder.ModelFn(params)
    executor = tpu_executor.TpuExecutor(model_fn, params)

    # Prepares input functions for train and eval.
    train_input_fn = input_reader.InputFn(
        params.train.train_file_pattern,
        params,
        mode=ModeKeys.TRAIN,
        dataset_type=params.train.train_dataset_type)
    if params.eval.type == 'customized':
        eval_input_fn = input_reader.InputFn(
            params.eval.eval_file_pattern,
            params,
            mode=ModeKeys.EVAL,
            dataset_type=params.eval.eval_dataset_type)
    else:
        eval_input_fn = input_reader.InputFn(
            params.eval.eval_file_pattern,
            params,
            mode=ModeKeys.PREDICT_WITH_GT,
            dataset_type=params.eval.eval_dataset_type)

    if params.eval.eval_samples:
        eval_times = params.eval.eval_samples // params.eval.eval_batch_size
    else:
        eval_times = None
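    # When eval_samples is unset, eval_times stays None; presumably the
    # executor then evaluates until the eval input is exhausted (an
    # assumption based on this guard, not on a documented contract).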

    # Runs the model.
    if FLAGS.mode == 'train':
        config_utils.save_config(params_to_save, params.model_dir)
        executor.train(train_input_fn, params.train.total_steps)
        if FLAGS.eval_after_training:
            executor.evaluate(eval_input_fn, eval_times)

    elif FLAGS.mode == 'eval':

        def terminate_eval():
            logging.info('Terminating eval after %d seconds of no checkpoints',
                         params.eval.eval_timeout)
            return True

        # Runs evaluation when there's a new checkpoint.
        for ckpt in tf.train.checkpoints_iterator(
                params.model_dir,
                min_interval_secs=params.eval.min_eval_interval,
                timeout=params.eval.eval_timeout,
                timeout_fn=terminate_eval):
            # Terminates eval job when final checkpoint is reached.
            current_step = int(
                six.ensure_str(os.path.basename(ckpt)).split('-')[1])

            logging.info('Starting to evaluate.')
            try:
                executor.evaluate(eval_input_fn, eval_times, ckpt)

                if current_step >= params.train.total_steps:
                    logging.info('Evaluation finished after training step %d',
                                 current_step)
                    break
            except tf.errors.NotFoundError as e:
                logging.info(
                    'Error occurred during evaluation: NotFoundError: %s', e)

    elif FLAGS.mode == 'train_and_eval':
        config_utils.save_config(params_to_save, params.model_dir)
        num_cycles = int(params.train.total_steps /
                         params.eval.num_steps_per_eval)
        for cycle in range(num_cycles):
            logging.info('Start training cycle %d.', cycle)
            current_cycle_last_train_step = ((cycle + 1) *
                                             params.eval.num_steps_per_eval)
            executor.train(train_input_fn, current_cycle_last_train_step)
            executor.evaluate(eval_input_fn, eval_times)
    else:
        logging.info('Mode not found.')
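
All four examples build their configuration through the same layered override chain, with later layers taking precedence. A condensed sketch of that ordering, using only the params_dict calls that appear above:

params = factory.config_generator(FLAGS.model)      # 1. built-in defaults
params = params_dict.override_params_dict(
    params, FLAGS.config_file, is_strict=True)      # 2. config file
params = params_dict.override_params_dict(
    params, FLAGS.params_override, is_strict=True)  # 3. --params_override
params.override({'model_dir': FLAGS.model_dir},
                is_strict=False)                    # 4. individual flags
params.validate()  # enforce restrictions before freezing
params.lock()      # make the config immutable for the rest of the run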