def main(argv):
  """Assembles the UNet params from defaults and flags, then runs the executer.

  The params start from the UNet defaults. The config file and the
  `params_overrides` flag are layered on top, followed by the
  data/model-dir/eval/TPU flags. The params are then validated and locked.
  Input functions are built only for the phases requested by `FLAGS.mode`.

  Args:
    argv: command-line arguments forwarded by the app runner; ignored.
  """
  del argv  # Unused.

  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  # The config file is applied first, then the inline overrides (both
  # non-strict).
  for override_source in (FLAGS.config_file, FLAGS.params_overrides):
    params = params_dict.override_params_dict(
        params, override_source, is_strict=False)

  flag_overrides = {
      'training_file_pattern': FLAGS.training_file_pattern,
      'eval_file_pattern': FLAGS.eval_file_pattern,
      'model_dir': FLAGS.model_dir,
      'min_eval_interval': FLAGS.min_eval_interval,
      'eval_timeout': FLAGS.eval_timeout,
      'tpu_config': tpu_executor.get_tpu_flags(),
  }
  params.override(flag_overrides, is_strict=False)
  params.validate()
  params.lock()

  mode = FLAGS.mode

  # Input functions stay None for phases that the selected mode does not run.
  train_input_fn = None
  if mode in ('train', 'train_and_eval'):
    train_input_fn = input_reader.LiverInputFn(
        params.training_file_pattern,
        params,
        mode=tf.estimator.ModeKeys.TRAIN)

  eval_input_fn = None
  if mode in ('eval', 'train_and_eval'):
    eval_input_fn = input_reader.LiverInputFn(
        params.eval_file_pattern,
        params,
        mode=tf.estimator.ModeKeys.EVAL)

  run_executer(
      params, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn)
def main(_):
  """Exports a trained UNet checkpoint as a SavedModel.

  Builds the params from the UNet defaults plus `FLAGS.config_file`. It then
  pins the train/eval batch sizes to `FLAGS.batch_size` and disables
  bfloat16. Next it wraps `serving_model_fn` in a TPUEstimator in PREDICT mode
  and exports the checkpoint at `FLAGS.checkpoint_path` under
  `FLAGS.export_dir`.

  Args:
    _: command-line arguments forwarded by the app runner; unused.
  """
  params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                  unet_config.UNET_RESTRICTIONS)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=False)
  # Export runs with a single fixed batch size and full-precision floats.
  params.train_batch_size = FLAGS.batch_size
  params.eval_batch_size = FLAGS.batch_size
  params.use_bfloat16 = False

  model_params = dict(
      params.as_dict(),
      use_tpu=FLAGS.use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=FLAGS.model_dir,
      config=tf.estimator.tpu.RunConfig(
          tpu_config=tf.estimator.tpu.TPUConfig(
              iterations_per_loop=FLAGS.iterations_per_loop),
          master='local',
          evaluation_master='local'),
      params=model_params,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      predict_batch_size=FLAGS.batch_size,
      export_to_tpu=FLAGS.use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=functools.partial(
          serving_input_fn,
          batch_size=FLAGS.batch_size,
          input_type=FLAGS.input_type,
          params=params,
          input_name=FLAGS.input_name),
      checkpoint_path=FLAGS.checkpoint_path)

  # export_saved_model returns the export directory as bytes. Decode it so the
  # message shows a plain path rather than a b'...' literal.
  if isinstance(export_path, bytes):
    export_path = export_path.decode('utf-8', errors='replace')
  # Deliberately return None: app runners pass main's return value to
  # sys.exit, where a non-None value would become a failing exit status.
  print(' - Done! path: %s' % export_path)
def config_generator(model):
  """Builds the default hyperparameter `ParamsDict` for a model family.

  Args:
    model: name of the model family. Supported values are 'classification',
      'retinanet', 'mask_rcnn', 'segmentation' and 'shapemask'.

  Returns:
    A `params_dict.ParamsDict` built from the family's default config and
    its restrictions.

  Raises:
    ValueError: if `model` is not one of the supported names.
  """
  # Each entry is a zero-argument callable. Only the selected family's config
  # module attributes are looked up, as in the original if/elif chain.
  config_getters = {
      'classification': lambda: (
          classification_config.CLASSIFICATION_CFG,
          classification_config.CLASSIFICATION_RESTRICTIONS),
      'retinanet': lambda: (
          retinanet_config.RETINANET_CFG,
          retinanet_config.RETINANET_RESTRICTIONS),
      'mask_rcnn': lambda: (
          maskrcnn_config.MASKRCNN_CFG,
          maskrcnn_config.MASKRCNN_RESTRICTIONS),
      'segmentation': lambda: (
          segmentation_config.SEGMENTATION_CFG,
          segmentation_config.SEGMENTATION_RESTRICTIONS),
      'shapemask': lambda: (
          shapemask_config.SHAPEMASK_CFG,
          shapemask_config.SHAPEMASK_RESTRICTIONS),
  }
  if model not in config_getters:
    raise ValueError('Model %s is not supported.' % model)
  default_config, restrictions = config_getters[model]()
  return params_dict.ParamsDict(default_config, restrictions)
"""Config template to train Mask R-CNN.""" from configs import detection_config import sys sys.path.insert(0, 'tpu/models') import os.path import sys sys.path.append( '/home/quocbao/Desktop/github_codes/tpu/models/hyperparameters/') from configs import base_config import params_dict # pylint: disable=line-too-long MASKRCNN_CFG = params_dict.ParamsDict(detection_config.DETECTION_CFG) MASKRCNN_CFG.override( { 'type': 'mask_rcnn', 'eval': { 'type': 'box_and_mask', }, 'architecture': { 'parser': 'maskrcnn_parser', 'backbone': 'resnet', 'min_level': 2, 'max_level': 6, 'multilevel_features': 'fpn', 'include_mask': True, 'mask_target_size': 28, },
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Config template to train classification models.""" import os.path import sys sys.path.append('./models/hyperparameters/') from configs import base_config import params_dict # pylint: disable=line-too-long CLASSIFICATION_CFG = params_dict.ParamsDict(base_config.BASE_CFG) CLASSIFICATION_CFG.override({ 'type': 'classification', 'architecture': { 'parser': 'classification_parser', 'backbone': 'resnet', # Note that `num_classes` is the total number of classes including one # background class whose index is 0. 'num_classes': 1001, }, 'train': { 'iterations_per_loop': 1000, 'train_batch_size': 1024, # 2x2. 'total_steps': 112603, # total images 1281167, so ~90 epochs. 'learning_rate': { 'type': 'cosine',
from configs import base_config import params_dict # pylint: disable=line-too-long # For ResNet, this freezes the variables of the first conv1 and conv2_x # layers [1], which leads to higher training speed and slightly better testing # accuracy. The intuition is that the low-level architecture (e.g., ResNet-50) # is able to capture low-level features such as edges; therefore, it does not # need to be fine-tuned for the detection task. # Note that we need to trailing `/` to avoid the incorrect match. # [1]: RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/' DETECTION_CFG = params_dict.ParamsDict(base_config.BASE_CFG) DETECTION_CFG.override({ 'architecture': { # Note that `num_classes` is the total number of classes including # one background classes whose index is 0. 'num_classes': 91 }, 'eval': { 'type': 'box', # Setting `eval_samples` = None will exhaust all the samples in the eval # dataset once. This only works if `type` != customized. 'eval_samples': None, 'use_json_file': True, 'val_json_file': '', 'per_category_metrics': False, },
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Config template to train Retinanet.""" from configs import detection_config import os.path import sys sys.path.append('/home/quocbao/Desktop/github_codes/tpu/models/hyperparameters/') from configs import base_config import params_dict # pylint: disable=line-too-long RETINANET_CFG = params_dict.ParamsDict(detection_config.DETECTION_CFG) RETINANET_CFG.override({ 'type': 'retinanet', 'architecture': { 'parser': 'retinanet_parser', 'backbone': 'resnet', 'multilevel_features': 'fpn', }, 'retinanet_parser': { 'output_size': [640, 640], 'match_threshold': 0.5, 'unmatched_threshold': 0.5, 'aug_rand_hflip': True, 'aug_scale_min': 1.0, 'aug_scale_max': 1.0, 'aug_policy': '',
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Config template to train Segmentation.""" import os.path import sys sys.path.append('tpu/models/hyperparameters/') from configs import base_config import params_dict # pylint: disable=line-too-long RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/' SEGMENTATION_CFG = params_dict.ParamsDict(base_config.BASE_CFG) SEGMENTATION_CFG.override( { 'type': 'segmentation', 'architecture': { 'parser': 'segmentation_parser', 'backbone': 'resnet', 'multilevel_features': 'fpn', 'use_aspp': False, 'use_pyramid_fusion': False, 'num_classes': 21, # Include background class 0. }, 'train': { 'train_batch_size': 64, 'total_steps': 10000, 'learning_rate': {