예제 #1
0
def main(argv):
  """Resolves UNet hyper-parameters and runs training and/or evaluation.

  Parameters are assembled in layers: the UNet defaults and restrictions,
  then `FLAGS.config_file`, then `FLAGS.params_overrides`, then individual
  command-line flags. The result is validated and locked before any input
  function is created. Which input functions are built depends on
  `FLAGS.mode`.

  Args:
    argv: Positional command-line arguments; not used.
  """
  del argv  # Unused.

  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  # Apply the config file first, then the inline overrides string on top.
  for override_source in (FLAGS.config_file, FLAGS.params_overrides):
    params = params_dict.override_params_dict(
        params, override_source, is_strict=False)

  # Individual flag values are applied last, after the file-based layers.
  flag_overrides = {
      'training_file_pattern': FLAGS.training_file_pattern,
      'eval_file_pattern': FLAGS.eval_file_pattern,
      'model_dir': FLAGS.model_dir,
      'min_eval_interval': FLAGS.min_eval_interval,
      'eval_timeout': FLAGS.eval_timeout,
      'tpu_config': tpu_executor.get_tpu_flags(),
  }
  params.override(flag_overrides, is_strict=False)
  params.validate()
  params.lock()

  mode = FLAGS.mode
  wants_train = mode in ('train', 'train_and_eval')
  wants_eval = mode in ('eval', 'train_and_eval')

  train_input_fn = None
  if wants_train:
    train_input_fn = input_reader.LiverInputFn(
        params.training_file_pattern, params,
        mode=tf.estimator.ModeKeys.TRAIN)

  eval_input_fn = None
  if wants_eval:
    eval_input_fn = input_reader.LiverInputFn(
        params.eval_file_pattern, params, mode=tf.estimator.ModeKeys.EVAL)

  run_executer(
      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
예제 #2
0
def main(_):
  """Exports a UNet checkpoint as a SavedModel.

  Builds UNet params from the defaults and `FLAGS.config_file`, wraps
  `serving_model_fn` in a TPUEstimator configured for PREDICT mode, and
  exports `FLAGS.checkpoint_path` under `FLAGS.export_dir` using
  `serving_input_fn` as the serving input receiver.
  """
  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=False)
  # The export batch size comes from the command line, not the config file.
  params.train_batch_size = FLAGS.batch_size
  params.eval_batch_size = FLAGS.batch_size
  # Export with bfloat16 disabled (NOTE(review): presumably so the exported
  # graph is usable on CPU as well, since export_to_cpu=True below).
  params.use_bfloat16 = False

  model_params = dict(
      params.as_dict(),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=FLAGS.model_dir,
      config=tf.estimator.tpu.RunConfig(
          tpu_config=tf.estimator.tpu.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  input_type = FLAGS.input_type
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving_input_fn,
          batch_size=FLAGS.batch_size,
          input_type=input_type,
          params=params,
          input_name=FLAGS.input_name),
      checkpoint_path=FLAGS.checkpoint_path)

  # `export_saved_model` returns the export directory as bytes; decode it so
  # the message shows a plain path instead of a b'...' literal.
  if isinstance(export_path, bytes):
    export_path = export_path.decode('utf-8')
  print(' - Done! path: %s' % export_path)
예제 #3
0
def config_generator(model):
    """Builds the default parameter dictionary for a model family.

    Args:
        model: Name of the model family; one of 'classification',
            'retinanet', 'mask_rcnn', 'segmentation' or 'shapemask'.

    Returns:
        A `params_dict.ParamsDict` built from the selected family's default
        config and its restrictions.

    Raises:
        ValueError: If `model` is not one of the supported names.
    """
    # Each config module is only dereferenced inside its own branch.
    if model == 'classification':
        default_config = classification_config.CLASSIFICATION_CFG
        restrictions = classification_config.CLASSIFICATION_RESTRICTIONS
    elif model == 'retinanet':
        default_config = retinanet_config.RETINANET_CFG
        restrictions = retinanet_config.RETINANET_RESTRICTIONS
    elif model == 'mask_rcnn':
        default_config = maskrcnn_config.MASKRCNN_CFG
        restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
    elif model == 'segmentation':
        default_config = segmentation_config.SEGMENTATION_CFG
        restrictions = segmentation_config.SEGMENTATION_RESTRICTIONS
    elif model == 'shapemask':
        default_config = shapemask_config.SHAPEMASK_CFG
        restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
    else:
        # Wrap in a 1-tuple so a tuple-valued `model` is formatted as a whole
        # instead of being unpacked into several %-arguments (TypeError).
        raise ValueError('Model %s is not supported.' % (model,))

    return params_dict.ParamsDict(default_config, restrictions)
예제 #4
0
"""Config template to train Mask R-CNN."""

from configs import detection_config
import sys
sys.path.insert(0, 'tpu/models')
import os.path
import sys
sys.path.append(
    '/home/quocbao/Desktop/github_codes/tpu/models/hyperparameters/')

from configs import base_config
import params_dict

# pylint: disable=line-too-long

# Mask R-CNN parameters start from the generic detection defaults
# (detection_config.DETECTION_CFG); Mask R-CNN-specific values are then
# layered on by the `override` call that follows.
MASKRCNN_CFG = params_dict.ParamsDict(detection_config.DETECTION_CFG)
MASKRCNN_CFG.override(
    {
        'type': 'mask_rcnn',
        'eval': {
            'type': 'box_and_mask',
        },
        'architecture': {
            'parser': 'maskrcnn_parser',
            'backbone': 'resnet',
            'min_level': 2,
            'max_level': 6,
            'multilevel_features': 'fpn',
            'include_mask': True,
            'mask_target_size': 28,
        },
예제 #5
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train classification models."""

import os.path
import sys
sys.path.append('./models/hyperparameters/')
from configs import base_config
import params_dict

# pylint: disable=line-too-long
# Classification parameters start from the shared base defaults
# (base_config.BASE_CFG); classification-specific values are then layered on
# by the `override` call that follows.
CLASSIFICATION_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
CLASSIFICATION_CFG.override({
    'type': 'classification',
    'architecture': {
        'parser': 'classification_parser',
        'backbone': 'resnet',
        # Note that `num_classes` is the total number of classes including one
        # background class whose index is 0.
        'num_classes': 1001,
    },
    'train': {
        'iterations_per_loop': 1000,
        'train_batch_size': 1024,  # 2x2.
        'total_steps': 112603,  # total images 1281167, so ~90 epochs.
        'learning_rate': {
            'type': 'cosine',
예제 #6
0
from configs import base_config
import params_dict

# pylint: disable=line-too-long

# For ResNet, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level layers of the backbone (e.g.,
# ResNet-50) already capture low-level features such as edges; therefore, they
# do not need to be fine-tuned for the detection task.
#
# The pattern matches `resnet<N>/` followed by `conv2d/`, `conv2d_1/` ...
# `conv2d_10/`, or `batch_normalization/`, `batch_normalization_1/` ...
# `batch_normalization_10/`.
# Note that we need the trailing `/` to avoid incorrect matches: when the
# pattern is applied as a prefix match (e.g. `re.match`), `conv2d` alone would
# also match later layers such as `conv2d_25`.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'

# Detection parameters start from the shared base defaults
# (base_config.BASE_CFG); detection-specific values are then layered on by the
# `override` call that follows.
DETECTION_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
DETECTION_CFG.override({
    'architecture': {
        # Note that `num_classes` is the total number of classes, including
        # one background class whose index is 0.
        'num_classes': 91
    },
    'eval': {
        'type': 'box',
        # Setting `eval_samples` = None will exhaust all the samples in the eval
        # dataset once. This only works if `type` != customized.
        'eval_samples': None,
        'use_json_file': True,
        'val_json_file': '',
        'per_category_metrics': False,
    },
예제 #7
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train Retinanet."""

from configs import detection_config
import os.path
import sys
sys.path.append('/home/quocbao/Desktop/github_codes/tpu/models/hyperparameters/')

from configs import base_config
import params_dict
# pylint: disable=line-too-long
# RetinaNet parameters start from the generic detection defaults
# (detection_config.DETECTION_CFG); RetinaNet-specific values are then layered
# on by the `override` call that follows.
RETINANET_CFG = params_dict.ParamsDict(detection_config.DETECTION_CFG)
RETINANET_CFG.override({
    'type': 'retinanet',
    'architecture': {
        'parser': 'retinanet_parser',
        'backbone': 'resnet',
        'multilevel_features': 'fpn',
    },
    'retinanet_parser': {
        'output_size': [640, 640],
        'match_threshold': 0.5,
        'unmatched_threshold': 0.5,
        'aug_rand_hflip': True,
        'aug_scale_min': 1.0,
        'aug_scale_max': 1.0,
        'aug_policy': '',
예제 #8
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train Segmentation."""

import os.path
import sys
sys.path.append('tpu/models/hyperparameters/')

from configs import base_config
import params_dict

# pylint: disable=line-too-long
# Regex selecting the first ResNet conv / batch-norm layers by variable name:
# `resnet<N>/` followed by `conv2d/`, `conv2d_1/` ... `conv2d_10/`, or
# `batch_normalization/`, `batch_normalization_1/` ... `batch_normalization_10/`.
# The trailing `/` keeps e.g. `conv2d_1` from matching `conv2d_11`.
# NOTE(review): presumably used to pick variables to freeze, like the
# identically named constant in the detection config; confirm against callers.
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'

# Segmentation parameters start from the shared base defaults
# (base_config.BASE_CFG); segmentation-specific values are then layered on by
# the `override` call that follows.
SEGMENTATION_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
SEGMENTATION_CFG.override(
    {
        'type': 'segmentation',
        'architecture': {
            'parser': 'segmentation_parser',
            'backbone': 'resnet',
            'multilevel_features': 'fpn',
            'use_aspp': False,
            'use_pyramid_fusion': False,
            'num_classes': 21,  # Include background class 0.
        },
        'train': {
            'train_batch_size': 64,
            'total_steps': 10000,
            'learning_rate': {