def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
  """Checks the optimizer type built for a precision / loss-scale combination.

  The trainer is expected to expose a `LossScaleOptimizer` only when the
  precision is 'float16' and a loss scale is supplied; in every other case a
  plain SGD optimizer is expected. One short training run must also report
  a 'training_loss' metric.

  Args:
    mixed_precision_dtype: Value forwarded to `RuntimeConfig`.
    loss_scale: Value forwarded to `RuntimeConfig`; may be None.
  """
  optimizer_config = cfg.OptimizationConfig({
      'optimizer': {
          'type': 'sgd'
      },
      'learning_rate': {
          'type': 'constant'
      }
  })
  runtime_config = cfg.RuntimeConfig(
      mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale)
  config = cfg.ExperimentConfig(
      runtime=runtime_config,
      trainer=cfg.TrainerConfig(optimizer_config=optimizer_config))
  trainer = self.create_test_trainer(config)

  # Only the wrapped case touches the `experimental` namespace, so the
  # attribute lookup stays confined to that branch.
  if mixed_precision_dtype == 'float16' and loss_scale is not None:
    expected_type = tf.keras.mixed_precision.experimental.LossScaleOptimizer
  else:
    expected_type = tf.keras.optimizers.SGD
  self.assertIsInstance(trainer.optimizer, expected_type)

  num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
  metrics = trainer.train(num_steps)
  self.assertIn('training_loss', metrics)
def test_export_best_ckpt(self, distribution):
  """Checks that train + evaluate writes the best-checkpoint info file.

  Builds a trainer with a best-checkpoint exporter keyed on the 'acc' metric,
  runs one training step and one evaluation step, and asserts that
  `<model_dir>/best_ckpt/info.json` exists afterwards.

  Args:
    distribution: Test parameter; not read by this test body.
  """
  optimizer_config = cfg.OptimizationConfig({
      'optimizer': {
          'type': 'sgd'
      },
      'learning_rate': {
          'type': 'constant'
      }
  })
  trainer_config = cfg.TrainerConfig(
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='acc',
      optimizer_config=optimizer_config)
  config = cfg.ExperimentConfig(trainer=trainer_config)

  model_dir = self.get_temp_dir()
  task = mock_task.MockTask(config.task, logging_dir=model_dir)
  ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
  trainer = trainer_lib.Trainer(
      config,
      task,
      model=task.build_model(),
      checkpoint_exporter=ckpt_exporter)

  trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
  trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))

  info_path = os.path.join(model_dir, 'best_ckpt', 'info.json')
  self.assertTrue(tf.io.gfile.exists(info_path))
def basnet() -> cfg.ExperimentConfig:
  """Returns the general BASNet experiment config.

  The config uses default trainer settings and requires `is_training` to be
  set on both the train and validation data configs.
  """
  # NOTE(review): `task` receives a BASNetModel instance rather than a task
  # config -- confirm this is intended.
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=BASNetModel(),
      trainer=cfg.TrainerConfig(),
      restrictions=restrictions)
def image_classification() -> cfg.ExperimentConfig:
  """Returns the general image classification experiment config.

  The config uses default task and trainer settings and requires
  `is_training` to be set on both the train and validation data configs.
  """
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=ImageClassificationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=restrictions)
def semantic_segmentation() -> cfg.ExperimentConfig:
  """Returns the general semantic segmentation experiment config.

  The config uses default trainer settings and requires `is_training` to be
  set on both the train and validation data configs.
  """
  # NOTE(review): `task` receives a SemanticSegmentationModel instance rather
  # than a task config -- confirm this is intended.
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=SemanticSegmentationModel(),
      trainer=cfg.TrainerConfig(),
      restrictions=restrictions)
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Returns a tiny 3D UNet segmentation experiment for testing.

  Uses 32x32x32 two-channel inputs, two classes, batch size 2, and a
  10-step SGD run with a constant learning rate. Inputs are read from
  'train.tfrecord' and 'val.tfrecord'.
  """
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10

  model = SemanticSegmentationModel3D(
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
      head=SegmentationHead3D(num_convs=0, num_classes=2),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))
  train_data = DataConfig(
      input_path='train.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=True,
      global_batch_size=train_batch_size)
  validation_data = DataConfig(
      input_path='val.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=False,
      global_batch_size=eval_batch_size)
  task = SemanticSegmentation3DTask(
      model=model,
      train_data=train_data,
      validation_data=validation_data,
      losses=Losses(loss_type='adaptive'))

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
      },
      'learning_rate': {
          'type': 'constant',
          'constant': {
              'learning_rate': 0.000001
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def video_classification() -> cfg.ExperimentConfig:
  """Returns the general video classification experiment config.

  Besides requiring `is_training` on both data configs, the restrictions
  demand that train and validation data agree on `num_classes`.
  """
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None',
      'task.train_data.num_classes == task.validation_data.num_classes',
  ]
  return cfg.ExperimentConfig(
      task=VideoClassificationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=restrictions)
def video_classification() -> cfg.ExperimentConfig:
  """Returns the general video classification experiment config (bfloat16).

  The runtime is configured with mixed_precision_dtype='bfloat16'. Besides
  requiring `is_training` on both data configs, the restrictions demand that
  train and validation data agree on `num_classes`.
  """
  runtime = cfg.RuntimeConfig(mixed_precision_dtype='bfloat16')
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None',
      'task.train_data.num_classes == task.validation_data.num_classes',
  ]
  return cfg.ExperimentConfig(
      runtime=runtime,
      task=VideoClassificationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=restrictions)
def setUp(self):
  """Stores a minimal experiment config (SGD, constant LR) on `self._config`."""
  super().setUp()
  optimizer_config = cfg.OptimizationConfig({
      'optimizer': {
          'type': 'sgd'
      },
      'learning_rate': {
          'type': 'constant'
      }
  })
  trainer_config = cfg.TrainerConfig(optimizer_config=optimizer_config)
  self._config = cfg.ExperimentConfig(trainer=trainer_config)
def add_trainer(experiment: cfg.ExperimentConfig,
                train_batch_size: int,
                eval_batch_size: int,
                learning_rate: float = 1.6,
                train_epochs: int = 44,
                warmup_epochs: int = 5):
  """Adds and configures a trainer on the experiment config.

  Sets the global batch sizes on the train/validation data configs and
  replaces `experiment.trainer` with a trainer that uses SGD (momentum 0.9,
  Nesterov), cosine learning-rate decay over the whole run, and linear
  warmup from 0. Loop, summary, checkpoint and validation intervals are all
  one epoch (`num_examples // train_batch_size` steps).

  Args:
    experiment: Experiment config to mutate; its task must carry
      `train_data` and `validation_data` with `num_examples` set.
    train_batch_size: Global batch size for training; must be positive.
    eval_batch_size: Global batch size for evaluation; must be positive.
    learning_rate: Initial learning rate of the cosine schedule.
    train_epochs: Number of epochs to train for.
    warmup_epochs: Number of epochs of linear warmup.

  Returns:
    The same `experiment` object, modified in place.

  Raises:
    ValueError: If either dataset size or either batch size is not positive.
  """
  train_data = experiment.task.train_data
  validation_data = experiment.task.validation_data
  if train_data.num_examples <= 0:
    raise ValueError('Wrong train dataset size {!r}'.format(train_data))
  if validation_data.num_examples <= 0:
    raise ValueError(
        'Wrong validation dataset size {!r}'.format(validation_data))
  # Reject unusable batch sizes before touching the experiment: a zero would
  # otherwise surface as a bare ZeroDivisionError after the config had already
  # been partially mutated, and a negative one would yield negative steps.
  if train_batch_size <= 0:
    raise ValueError('Wrong train batch size {!r}'.format(train_batch_size))
  if eval_batch_size <= 0:
    raise ValueError('Wrong eval batch size {!r}'.format(eval_batch_size))

  train_data.global_batch_size = train_batch_size
  validation_data.global_batch_size = eval_batch_size

  steps_per_epoch = train_data.num_examples // train_batch_size
  total_train_steps = train_epochs * steps_per_epoch
  experiment.trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=validation_data.num_examples // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'sgd',
              'sgd': {
                  'momentum': 0.9,
                  'nesterov': True,
              }
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  'initial_learning_rate': learning_rate,
                  # Decay spans the entire training run.
                  'decay_steps': total_train_steps,
              }
          },
          'warmup': {
              'type': 'linear',
              'linear': {
                  'warmup_steps': warmup_epochs * steps_per_epoch,
                  'warmup_learning_rate': 0
              }
          }
      }))
  return experiment
def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
  """Returns the COCO Mask R-CNN object detection experiment config.

  Uses 1024x1024 inputs with masks enabled, bfloat16 mixed precision, a
  backbone initialised from a checkpoint, and SGD with a stepwise learning
  rate and linear warmup over 22500 training steps.
  """
  steps_per_epoch = 500
  coco_val_samples = 5000
  eval_batch_size = 8

  runtime = cfg.RuntimeConfig(mixed_precision_dtype='bfloat16')

  train_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=64,
      parser=Parser(
          aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25))
  validation_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
      is_training=False,
      global_batch_size=eval_batch_size)
  task = MaskRCNNTask(
      init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
      init_checkpoint_modules='backbone',
      annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                   'instances_val2017.json'),
      model=MaskRCNN(
          num_classes=91, input_size=[1024, 1024, 3], include_mask=True),
      losses=Losses(l2_weight_decay=0.00004),
      train_data=train_data,
      validation_data=validation_data)

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': [15000, 20000],
              'values': [0.12, 0.012, 0.0012],
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 500,
              'warmup_learning_rate': 0.0067
          }
      }
  })
  trainer = cfg.TrainerConfig(
      train_steps=22500,
      validation_steps=coco_val_samples // eval_batch_size,
      validation_interval=steps_per_epoch,
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      runtime=runtime,
      task=task,
      trainer=trainer,
      restrictions=restrictions)
def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Returns the Pascal segmentation experiment with dilated ResNet DeepLabV3.

  Uses a dilated ResNet-101 backbone with an ASPP decoder, 21 classes,
  512x512 outputs, and SGD with polynomial learning-rate decay and five
  epochs of linear warmup over 45 epochs. The backbone is initialised from a
  checkpoint.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  # `np.math` was only an alias of the stdlib module; it is deprecated since
  # NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
  import math

  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  # NOTE(review): the original inline remark reads "[6, 12, 18] if
  # output_stride = 16", yet output_stride is 16 here -- confirm the rates.
  aspp_dilation_rates = [12, 24, 36]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  # Feature level fed to the decoder and head: log2 of the output stride.
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      multigrid=multigrid,
                      stem_type=stem_type)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              # TODO(arashwan): test changing size to 513 to match deeplab.
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
  """Returns the COCO Mask R-CNN experiment config with a SpineNet backbone.

  Uses a SpineNet-49 backbone with an identity decoder, 640x640 inputs with
  masks enabled, bfloat16 mixed precision, and SGD with a stepwise learning
  rate and linear warmup over 350 epochs of 463 steps each.
  """
  steps_per_epoch = 463
  coco_val_samples = 5000
  eval_batch_size = 8

  runtime = cfg.RuntimeConfig(mixed_precision_dtype='bfloat16')

  model = MaskRCNN(
      backbone=backbones.Backbone(
          type='spinenet', spinenet=backbones.SpineNet(model_id='49')),
      decoder=decoders.Decoder(
          type='identity', identity=decoders.Identity()),
      anchor=Anchor(anchor_size=3),
      norm_activation=common.NormActivation(use_sync_bn=True),
      num_classes=91,
      input_size=[640, 640, 3],
      min_level=3,
      max_level=7,
      include_mask=True)
  train_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=256,
      parser=Parser(
          aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0))
  validation_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
      is_training=False,
      global_batch_size=eval_batch_size)
  task = MaskRCNNTask(
      annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                   'instances_val2017.json'),
      model=model,
      losses=Losses(l2_weight_decay=0.00004),
      train_data=train_data,
      validation_data=validation_data)

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': [
                  steps_per_epoch * 320, steps_per_epoch * 340
              ],
              'values': [0.28, 0.028, 0.0028],
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 2000,
              'warmup_learning_rate': 0.0067
          }
      }
  })
  trainer = cfg.TrainerConfig(
      train_steps=steps_per_epoch * 350,
      validation_steps=coco_val_samples // eval_batch_size,
      validation_interval=steps_per_epoch,
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      runtime=runtime,
      task=task,
      trainer=trainer,
      restrictions=restrictions)
def detector_yolo() -> cfg.ExperimentConfig:
  """Returns a YOLO detection experiment config for a custom dataset.

  Uses a HardNet-70 backbone with a three-level PAN decoder, 256x256 inputs,
  six classes, batch size 1, and a 20-step SGD run with polynomial decay.
  Train and validation data both read from the same hard-coded local glob.
  """
  # Same glob is used for training and validation.
  data_glob = 'D:/data/whizz_tf/detect_env*'

  head = YoloHead(
      anchor_per_scale=3,
      strides=[16, 32, 64],
      anchors=[
          12, 16, 19, 36, 40, 28,
          36, 75, 76, 55, 72, 146,
          142, 110, 192, 243, 459, 401
      ],
      xy_scale=[1.2, 1.1, 1.05])
  model = YoloModel(
      num_classes=6,
      input_size=[256, 256, 3],
      backbone=backbones.Backbone(
          type='hardnet', hardnet=backbones.HardNet(model_id=70)),
      decoder=decoders.Decoder(type='pan', pan=decoders.PAN(levels=3)),
      head=head,
      norm_activation=common.NormActivation(
          activation='relu',
          norm_momentum=0.9997,
          norm_epsilon=0.001,
          use_sync_bn=True))
  train_data = DataConfig(
      input_path=data_glob,
      output_size=[256, 256],
      is_training=True,
      global_batch_size=1)
  validation_data = DataConfig(
      input_path=data_glob,
      output_size=[256, 256],
      is_training=False,
      global_batch_size=1,
      drop_remainder=False)
  # init_checkpoint is left at its default; only the modules are named.
  task = YoloTask(
      model=model,
      losses=YoloLosses(l2_weight_decay=1e-4, iou_loss_thres=0.5),
      train_data=train_data,
      validation_data=validation_data,
      init_checkpoint_modules='backbone')

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': {
              'initial_learning_rate': 0.007,
              'decay_steps': 20,
              'end_learning_rate': 0.0,
              'power': 0.9
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 2,
              'warmup_learning_rate': 0
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=2,
      summary_interval=2,
      checkpoint_interval=2,
      train_steps=20,
      validation_steps=20,
      validation_interval=2,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Returns the ImageNet classification experiment config with ResNet-50.

  Trains for 90 epochs at batch size 4096 using SGD with momentum, a
  stepwise learning rate scaled by batch size / 256 that drops at epochs
  30, 60 and 80, and five epochs of linear warmup.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      norm_activation=common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5))
  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=1e-4),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  lr_boundaries = [
      30 * steps_per_epoch, 60 * steps_per_epoch, 80 * steps_per_epoch
  ]
  lr_values = [
      0.1 * train_batch_size / 256,
      0.01 * train_batch_size / 256,
      0.001 * train_batch_size / 256,
      0.0001 * train_batch_size / 256,
  ]
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': lr_boundaries,
              'values': lr_values
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=90 * steps_per_epoch,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def mnv2_deeplabv3_cityscapes() -> cfg.ExperimentConfig:
  """Returns the Cityscapes segmentation experiment with MobileNetV2 DeepLabV3.

  Uses a MobileNetV2 backbone with an ASPP decoder (no dilation rates,
  pooling kernel 512x1024), 19 classes, and SGD with polynomial decay over
  100000 steps plus five epochs of linear warmup. The best checkpoint by
  'mean_iou' (higher is better) is exported to 'best_ckpt'.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  # `np.math` was only an alias of the stdlib module; it is deprecated since
  # NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
  import math

  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = []
  pool_kernel_size = [512, 1024]

  # Feature level fed to the decoder and head: log2 of the output stride.
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # Coco pre-trained mobilenetv2 checkpoint
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=100000,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 100000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
def basnet_duts() -> cfg.ExperimentConfig:
  """Returns the DUTS segmentation experiment config with BASNet.

  Trains for 300 epochs at batch size 16 using Adam. The stepwise schedule
  has one boundary at epoch 200, but both of its values are 0.001.
  """
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size

  model = BASNetModel(
      # Original remark on this field: "Resize to 256, 256".
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='basnet_en', basnet_en=backbones.BASNet_En()),
      decoder=decoders.Decoder(
          type='basnet_de', basnet_de=decoders.BASNet_De()),
      norm_activation=common.NormActivation(
          activation='relu',
          norm_momentum=0.99,
          norm_epsilon=1e-3,
          use_sync_bn=True))
  train_data = DataConfig(
      input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR, 'DUTS-TR-*'),
      is_training=True,
      global_batch_size=train_batch_size,
  )
  validation_data = DataConfig(
      input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL, 'DUTS-TE-*'),
      is_training=False,
      global_batch_size=eval_batch_size,
  )
  # init_checkpoint and init_checkpoint_modules are left at their defaults.
  task = BASNetTask(
      model=model,
      losses=Losses(l2_weight_decay=0),
      train_data=train_data,
      validation_data=validation_data)

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'adam',
          'adam': {
              'beta_1': 0.9,
              'beta_2': 0.999,
              'epsilon': 1e-8,
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': [200 * steps_per_epoch],
              'values': [0.001, 0.001]
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=300 * steps_per_epoch,
      validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def mock_experiment() -> cfg.ExperimentConfig:
  """Returns an experiment config with a mock task and a default trainer."""
  return cfg.ExperimentConfig(
      task=MockTaskConfig(), trainer=cfg.TrainerConfig())
def seg_deeplabv2_pascal() -> cfg.ExperimentConfig:
  """Returns the Pascal segmentation experiment with dilated VGGNet DeepLabV2.

  Uses a dilated VGGNet-16 backbone with a 'v2' ASPP decoder, 21 classes,
  512x512 outputs, and SGD with polynomial learning-rate decay and five
  epochs of linear warmup over 45 epochs. The backbone is initialised from a
  hard-coded local checkpoint directory.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  # `np.math` was only an alias of the stdlib module; it is deprecated since
  # NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
  import math

  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size

  # NOTE: only the ASPP variant is configured here. The Large-FOV settings
  # (dilation rate 12, kernel size 3) were never wired into this config.
  aspp_dilation_rates = [6, 12, 18, 24]

  output_stride = 16
  # Feature level fed to the decoder and head: log2 of the output stride.
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_vggnet',
                  dilated_vggnet=backbones.DilatedVGGNet(model_id=16)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      stem_type='v2',
                      num_filters=1024,
                      use_sync_bn=True)),
              head=SegmentationHead(
                  level=level,
                  num_convs=0,
                  low_level_num_filters=1024,
                  feature_fusion='deeplabv2'),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              # TODO(arashwan): test changing size to 513 to match deeplab.
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=1.5),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # Backbone checkpoint. This is a machine-specific local path whose
          # directory name suggests VGGNet-16 -- TODO confirm and relocate.
          init_checkpoint='/home/gunho1123/ckpt_vggnet16_deeplab/',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
  """Returns the ImageNet classification experiment config with MobileNetV2.

  Trains for 500 epochs at batch size 4096 using RMSprop, a staircase
  exponential learning-rate decay every 2.5 epochs, and five epochs of
  linear warmup.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

  model = ImageClassificationModel(
      num_classes=1001,
      dropout_rate=0.2,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='mobilenet',
          mobilenet=backbones.MobileNet(
              model_id='MobileNetV2', filter_size_scale=1.0)),
      norm_activation=common.NormActivation(
          norm_momentum=0.997, norm_epsilon=1e-3))
  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=1e-5, label_smoothing=0.1),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'rmsprop',
          'rmsprop': {
              'rho': 0.9,
              'momentum': 0.9,
              'epsilon': 0.002,
          }
      },
      'learning_rate': {
          'type': 'exponential',
          'exponential': {
              # Base rate 0.008 scaled by the whole number of 128-sample
              # groups in the training batch.
              'initial_learning_rate': 0.008 * (train_batch_size // 128),
              'decay_steps': int(2.5 * steps_per_epoch),
              'decay_rate': 0.98,
              'staircase': True
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0
          }
      },
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=500 * steps_per_epoch,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def retinanet_resnetfpn_coco() -> cfg.ExperimentConfig:
  """Returns the COCO RetinaNet object detection experiment config.

  Uses 640x640 inputs, bfloat16 mixed precision, a backbone initialised from
  a checkpoint, and SGD over 72 epochs with a stepwise learning rate scaled
  by batch size / 256 that drops at epochs 57 and 67.
  """
  train_batch_size = 256
  eval_batch_size = 8
  steps_per_epoch = COCO_TRIAN_EXAMPLES // train_batch_size

  runtime = cfg.RuntimeConfig(mixed_precision_dtype='bfloat16')

  train_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=train_batch_size,
      parser=Parser(
          aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0))
  validation_data = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
      is_training=False,
      global_batch_size=eval_batch_size)
  task = RetinaNetTask(
      init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
      init_checkpoint_modules='backbone',
      annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                   'instances_val2017.json'),
      model=RetinaNet(
          num_classes=91,
          input_size=[640, 640, 3],
          min_level=3,
          max_level=7),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_data,
      validation_data=validation_data)

  lr_values = [
      0.28 * train_batch_size / 256.0,
      0.028 * train_batch_size / 256.0,
      0.0028 * train_batch_size / 256.0
  ]
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': [57 * steps_per_epoch, 67 * steps_per_epoch],
              'values': lr_values,
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 500,
              'warmup_learning_rate': 0.0067
          }
      }
  })
  trainer = cfg.TrainerConfig(
      train_steps=72 * steps_per_epoch,
      validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      runtime=runtime,
      task=task,
      trainer=trainer,
      restrictions=restrictions)
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
  """Returns the Pascal segmentation experiment config with ResNet-FPN.

  Uses a ResNet-50 backbone with an FPN decoder, 21 classes, 512x512 inputs,
  and SGD over 450 epochs with polynomial learning-rate decay and five
  epochs of linear warmup.
  """
  train_batch_size = 256
  eval_batch_size = 32
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  total_train_steps = 450 * steps_per_epoch

  model = SemanticSegmentationModel(
      num_classes=21,
      input_size=[512, 512, 3],
      min_level=3,
      max_level=7,
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
      head=SegmentationHead(level=3, num_convs=3),
      norm_activation=common.NormActivation(
          activation='swish', use_sync_bn=True))
  train_data = DataConfig(
      input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
      is_training=True,
      global_batch_size=train_batch_size,
      aug_scale_min=0.2,
      aug_scale_max=1.5)
  validation_data = DataConfig(
      input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
      is_training=False,
      global_batch_size=eval_batch_size,
      resize_eval_groundtruth=False,
      groundtruth_padded_size=[512, 512],
      drop_remainder=False)
  task = SemanticSegmentationTask(
      model=model,
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_data,
      validation_data=validation_data,
  )

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {
              'momentum': 0.9
          }
      },
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': {
              'initial_learning_rate': 0.007,
              'decay_steps': total_train_steps,
              'end_learning_rate': 0.0,
              'power': 0.9
          }
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
def seg_deeplabv3plus_scooter() -> cfg.ExperimentConfig:
  """Returns the scooter-dataset segmentation experiment with DeepLabV3+.

  Barebones config for testing purposes: modify the batch size, initial
  learning rate, steps per epoch, and the train/validation input paths as
  needed. Train and validation data both read from the same hard-coded
  local glob, and no init checkpoint is set.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  # `np.math` was only an alias of the stdlib module; it is deprecated since
  # NumPy 1.25 and removed in NumPy 2.0, so use `math` directly.
  import math

  scooter_path_glob = 'D:/data/test_data/val**'
  steps_per_epoch = 1
  output_stride = 16
  aspp_dilation_rates = [6, 12, 18]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  # Feature level fed to the decoder and head: log2 of the output stride.
  level = int(math.log2(output_stride))

  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=19,
              # specifying this speeds up model inference, no change in size
              input_size=[512, 512, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      stem_type=stem_type,
                      multigrid=multigrid)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(
                  level=level,
                  num_convs=2,
                  feature_fusion='deeplabv3plus',
                  low_level=2,
                  low_level_num_filters=48),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4, ignore_label=250),
          train_data=DataConfig(
              input_path=scooter_path_glob,
              output_size=[512, 512],
              is_training=True,
              global_batch_size=1,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=scooter_path_glob,
              output_size=[512, 512],
              is_training=False,
              global_batch_size=1,
              resize_eval_groundtruth=True,
              drop_remainder=False)),
      # init_checkpoint / init_checkpoint_modules are left at their defaults.
      # Previously tried (resnet101) checkpoints, kept for reference:
      #   D:/repos/data_root/test_data/deeplab_cityscapes_pretrained/model.ckpt
      #     with init_checkpoint_modules='all'
      #   gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400
      #     with init_checkpoint_modules='backbone'
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=500 * steps_per_epoch,
          validation_steps=1021,
          validation_interval=steps_per_epoch,
          continuous_eval_timeout=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 500 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
    """Image segmentation on Cityscapes with ResNet-101 DeepLabV3+.

    Training reads the `train_fine**` files and evaluation the `val_fine*`
    files under CITYSCAPES_INPUT_PATH_BASE. Only the backbone is initialized
    from the configured checkpoint.

    Returns:
        The experiment config for 500 epochs of SGD training with a
        polynomial learning-rate decay and a 5-epoch linear warmup.
    """
    train_batch_size = 16
    eval_batch_size = 16
    steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
    output_stride = 16
    aspp_dilation_rates = [6, 12, 18]
    multigrid = [1, 2, 4]
    stem_type = 'v1'
    # `np.math` was only an alias of the stdlib `math` module and is not
    # available in current NumPy releases; `np.log2` gives the same level.
    level = int(np.log2(output_stride))
    config = cfg.ExperimentConfig(
        task=SemanticSegmentationTask(
            model=SemanticSegmentationModel(
                # Cityscapes uses only 19 semantic classes for
                # train/evaluation. The void (background) class is ignored
                # in train and evaluation.
                num_classes=19,
                input_size=[None, None, 3],
                backbone=backbones.Backbone(
                    type='dilated_resnet',
                    dilated_resnet=backbones.DilatedResNet(
                        model_id=101,
                        output_stride=output_stride,
                        stem_type=stem_type,
                        multigrid=multigrid)),
                decoder=decoders.Decoder(
                    type='aspp',
                    aspp=decoders.ASPP(
                        level=level,
                        dilation_rates=aspp_dilation_rates,
                        pool_kernel_size=[512, 1024])),
                head=SegmentationHead(
                    level=level,
                    num_convs=2,
                    feature_fusion='deeplabv3plus',
                    low_level=2,
                    low_level_num_filters=48),
                norm_activation=common.NormActivation(
                    activation='swish',
                    norm_momentum=0.99,
                    norm_epsilon=1e-3,
                    use_sync_bn=True)),
            losses=Losses(l2_weight_decay=1e-4),
            train_data=DataConfig(
                input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                        'train_fine**'),
                crop_size=[512, 1024],
                output_size=[1024, 2048],
                is_training=True,
                global_batch_size=train_batch_size,
                aug_scale_min=0.5,
                aug_scale_max=2.0),
            validation_data=DataConfig(
                input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                        'val_fine*'),
                output_size=[1024, 2048],
                is_training=False,
                global_batch_size=eval_batch_size,
                resize_eval_groundtruth=True,
                drop_remainder=False),
            # resnet101
            init_checkpoint=
            'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
            init_checkpoint_modules='backbone'),
        trainer=cfg.TrainerConfig(
            steps_per_loop=steps_per_epoch,
            summary_interval=steps_per_epoch,
            checkpoint_interval=steps_per_epoch,
            train_steps=500 * steps_per_epoch,
            validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
            validation_interval=steps_per_epoch,
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd',
                    'sgd': {
                        'momentum': 0.9
                    }
                },
                'learning_rate': {
                    'type': 'polynomial',
                    'polynomial': {
                        'initial_learning_rate': 0.01,
                        # Decay spans the whole run (same as train_steps).
                        'decay_steps': 500 * steps_per_epoch,
                        'end_learning_rate': 0.0,
                        'power': 0.9
                    }
                },
                'warmup': {
                    'type': 'linear',
                    'linear': {
                        'warmup_steps': 5 * steps_per_epoch,
                        'warmup_learning_rate': 0
                    }
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
    return config
def seg_deeplabv3plus_ade20k_32(backbone: str,
                                init_backbone: bool = True
                                ) -> cfg.ExperimentConfig:
    """Semantic segmentation on the ADE20K dataset with DeepLabV3+.

    Args:
        backbone: MobileNet-EdgeTPU model id. It is also the key used to look
            up the ASPP/head level, the low-level feature level and the
            pretrained checkpoint in the module-level BACKBONE_* tables.
        init_backbone: When True, the backbone config is given the pretrained
            checkpoint path registered for `backbone`; when False the path is
            left as None.

    Returns:
        The experiment config (200 epochs, Adam, polynomial decay with a
        4-epoch linear warmup).
    """
    num_epochs = 200
    train_bs = 128
    eval_bs = 32
    side = 512
    epoch_steps = ADE20K_TRAIN_EXAMPLES // train_bs

    if init_backbone:
        backbone_ckpt = BACKBONE_PRETRAINED_CHECKPOINT[backbone]
    else:
        backbone_ckpt = None
    # The ASPP decoder and the segmentation head attach at the same level.
    feature_level = BACKBONE_HEADPOINT[backbone]

    edgetpu_backbone = Backbone(
        type='mobilenet_edgetpu',
        mobilenet_edgetpu=MobileNetEdgeTPU(
            model_id=backbone,
            pretrained_checkpoint_path=backbone_ckpt,
            freeze_large_filters=500,
        ))
    aspp_decoder = decoders.Decoder(
        type='aspp',
        aspp=decoders.ASPP(
            level=feature_level,
            use_depthwise_convolution=True,
            dilation_rates=[5, 10, 15],
            pool_kernel_size=[256, 256],
            num_filters=128,
            dropout_rate=0.3,
        ))
    seg_head = base_cfg.SegmentationHead(
        level=feature_level,
        num_convs=2,
        num_filters=256,
        use_depthwise_convolution=True,
        feature_fusion='deeplabv3plus',
        low_level=BACKBONE_LOWER_FEATURES[backbone],
        low_level_num_filters=48)
    norm_act = common.NormActivation(
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=2e-3,
        use_sync_bn=False)

    model = base_cfg.SemanticSegmentationModel(
        # ADE20K uses only 32 semantic classes for train/evaluation.
        # The void (background) class is ignored in train and evaluation.
        num_classes=32,
        input_size=[None, None, 3],
        backbone=edgetpu_backbone,
        decoder=aspp_decoder,
        head=seg_head,
        norm_activation=norm_act)

    train_data = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'train-*'),
        output_size=[side, side],
        is_training=True,
        global_batch_size=train_bs)
    validation_data = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'val-*'),
        output_size=[side, side],
        is_training=False,
        global_batch_size=eval_bs,
        resize_eval_groundtruth=True,
        drop_remainder=False)

    task = CustomSemanticSegmentationTaskConfig(
        model=model,
        train_data=train_data,
        validation_data=validation_data,
        evaluation=base_cfg.Evaluation(report_train_mean_iou=False),
    )

    schedule = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adam',
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 0.0001,
                'decay_steps': num_epochs * epoch_steps,
                'end_learning_rate': 0.0,
                'power': 0.9
            }
        },
        'warmup': {
            'type': 'linear',
            'linear': {
                'warmup_steps': 4 * epoch_steps,
                'warmup_learning_rate': 0
            }
        }
    })
    # Logging, checkpointing and validation all happen once per epoch.
    trainer = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=num_epochs * epoch_steps,
        validation_steps=ADE20K_VAL_EXAMPLES // eval_bs,
        validation_interval=epoch_steps,
        optimizer_config=schedule)

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
    """COCO object detection with RetinaNet using a SpineNet backbone.

    Returns:
        The experiment config: SpineNet-49 backbone with an identity decoder,
        640x640 inputs, float32 runtime precision, and 350 epochs of SGD with
        a stepwise learning rate and linear warmup.
    """
    train_bs = 256
    eval_bs = 8
    epoch_steps = COCO_TRIAN_EXAMPLES // train_bs
    image_side = 640

    runtime = cfg.RuntimeConfig(mixed_precision_dtype='float32')

    detector = RetinaNet(
        backbone=backbones.Backbone(
            type='spinenet', spinenet=backbones.SpineNet(model_id='49')),
        decoder=decoders.Decoder(
            type='identity', identity=decoders.Identity()),
        anchor=Anchor(anchor_size=3),
        norm_activation=common.NormActivation(use_sync_bn=True),
        num_classes=91,
        input_size=[image_side, image_side, 3],
        min_level=3,
        max_level=7)
    task = RetinaNetTask(
        annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                     'instances_val2017.json'),
        model=detector,
        losses=Losses(l2_weight_decay=4e-5),
        train_data=DataConfig(
            input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
            is_training=True,
            global_batch_size=train_bs,
            parser=Parser(
                aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0)),
        validation_data=DataConfig(
            input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
            is_training=False,
            global_batch_size=eval_bs))

    # The learning rate drops tenfold at epochs 320 and 340; each value is
    # scaled linearly with the train batch size relative to 256.
    lr_boundaries = [epoch * epoch_steps for epoch in (320, 340)]
    lr_values = [
        base_lr * train_bs / 256.0 for base_lr in (0.28, 0.028, 0.0028)
    ]
    schedule = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'sgd',
            'sgd': {
                'momentum': 0.9
            }
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': lr_boundaries,
                'values': lr_values,
            }
        },
        'warmup': {
            'type': 'linear',
            'linear': {
                'warmup_steps': 2000,
                'warmup_learning_rate': 0.0067
            }
        }
    })
    trainer = cfg.TrainerConfig(
        train_steps=350 * epoch_steps,
        validation_steps=COCO_VAL_EXAMPLES // eval_bs,
        validation_interval=epoch_steps,
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        optimizer_config=schedule)

    return cfg.ExperimentConfig(
        runtime=runtime,
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
def yolo_v4_coco() -> cfg.ExperimentConfig:
    """COCO object detection with YOLO (`base='v4'`).

    The schedule literals (total steps, learning-rate boundaries, warmup
    steps and learning rates) are written relative to a batch size of 64 and
    rescaled to `train_batch_size`: step counts are multiplied by
    64 / train_batch_size, learning rates by train_batch_size / 64.

    All step counts are computed with integer arithmetic. Previously
    `train_steps` was a float, and `warmup_steps` was divided by the total
    number of steps, which always produced 0.0 and so disabled warmup.

    Returns:
        The experiment config. No input paths are set on the data configs;
        the caller must override them.
    """
    train_batch_size = 1
    eval_batch_size = 1
    # Batch size the literals below are expressed for (presumably the batch
    # size of the reference training schedule -- confirm with the author).
    reference_batch_size = 64
    base_default = 1200000
    # Total number of training steps at `train_batch_size`, kept integral.
    num_batches = base_default * reference_batch_size // train_batch_size
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            # mixed_precision_dtype='float16',
            # loss_scale='dynamic',
            num_gpus=2),
        task=YoloTask(
            model=Yolo(base='v4'),
            train_data=DataConfig(
                # input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
                is_training=True,
                global_batch_size=train_batch_size,
                parser=Parser(),
                shuffle_buffer_size=2),
            validation_data=DataConfig(
                # input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
                is_training=False,
                global_batch_size=eval_batch_size,
                shuffle_buffer_size=2)),
        trainer=cfg.TrainerConfig(
            steps_per_loop=2000,
            summary_interval=8000,
            checkpoint_interval=10000,
            train_steps=num_batches,
            validation_steps=1000,
            validation_interval=10,
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd',
                    'sgd': {
                        'momentum': 0.9
                    }
                },
                'learning_rate': {
                    'type': 'stepwise',
                    'stepwise': {
                        # Boundaries at 400k and 450k of the 1.2M base
                        # steps, rescaled to the actual step count.
                        'boundaries': [
                            400000 * num_batches // base_default,
                            450000 * num_batches // base_default
                        ],
                        'values': [
                            0.00261 * train_batch_size / reference_batch_size,
                            0.000261 * train_batch_size / reference_batch_size,
                            0.0000261 * train_batch_size / reference_batch_size
                        ]
                    }
                },
                'warmup': {
                    'type': 'linear',
                    'linear': {
                        # 1000 warmup steps at the reference batch size,
                        # rescaled like every other step count.
                        'warmup_steps':
                            1000 * reference_batch_size // train_batch_size,
                        'warmup_learning_rate': 0
                    }
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
    return config
def basnet_duts() -> cfg.ExperimentConfig:
    """Image segmentation on DUTS with BASNet.

    NOTE(review): a second function named `basnet_duts` appears later in this
    file; if both are module-level definitions the later one shadows this
    one. Confirm which variant is meant to be kept.

    Returns:
        The experiment config: TFRecord inputs, backbone initialized from the
        configured checkpoint, 300 epochs of Adam at a constant learning
        rate.
    """
    train_bs = 64
    eval_bs = 16
    epoch_steps = DUTS_TRAIN_EXAMPLES // train_bs

    model = BASNetModel(
        input_size=[None, None, 3],
        use_bias=True,
        norm_activation=common.NormActivation(
            activation='relu',
            norm_momentum=0.99,
            norm_epsilon=1e-3,
            use_sync_bn=True))
    train_data = DataConfig(
        input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR, 'tf_record_train'),
        file_type='tfrecord',
        crop_size=[224, 224],
        output_size=[256, 256],
        is_training=True,
        global_batch_size=train_bs,
    )
    validation_data = DataConfig(
        input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL, 'tf_record_test'),
        file_type='tfrecord',
        output_size=[256, 256],
        is_training=False,
        global_batch_size=eval_bs,
    )
    task = BASNetTask(
        model=model,
        losses=Losses(l2_weight_decay=0),
        train_data=train_data,
        validation_data=validation_data,
        init_checkpoint=
        'gs://cloud-basnet-checkpoints/basnet_encoder_imagenet/ckpt-340306',
        init_checkpoint_modules='backbone')

    schedule = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adam',
            'adam': {
                'beta_1': 0.9,
                'beta_2': 0.999,
                'epsilon': 1e-8,
            }
        },
        'learning_rate': {
            'type': 'constant',
            'constant': {
                'learning_rate': 0.001
            }
        }
    })
    # Logging, checkpointing and validation all happen once per epoch.
    trainer = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=300 * epoch_steps,
        validation_steps=DUTS_VAL_EXAMPLES // eval_bs,
        validation_interval=epoch_steps,
        optimizer_config=schedule)

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
def basnet_duts() -> cfg.ExperimentConfig:
    """Image segmentation on DUTS with BASNet (fixed 224x224 input variant).

    NOTE(review): this redefines `basnet_duts`, which is already defined
    earlier in this file; if both are module-level definitions, this later
    one is what the name resolves to. Confirm which variant should be kept.

    Returns:
        The experiment config: BASNet encoder/decoder, no checkpoint
        initialization, 500 epochs of Adam with a stepwise learning rate.
    """
    train_batch_size = 16
    eval_batch_size = 16
    steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size
    config = cfg.ExperimentConfig(
        task=BASNetTask(
            model=BASNetModel(
                # num_classes=21,
                # TODO(arashwan): test changing size to 513 to match deeplab.
                input_size=[224, 224, 3],  # Resize to 256, 256
                backbone=backbones.Backbone(
                    type='basnet_en', basnet_en=backbones.BASNet_En()),
                decoder=decoders.Decoder(
                    type='basnet_de', basnet_de=decoders.BASNet_De()),
                # head=BASNetHead(level=3, num_convs=0),
                norm_activation=common.NormActivation(
                    activation='relu',
                    norm_momentum=0.99,
                    norm_epsilon=1e-3,
                    use_sync_bn=True)),
            losses=Losses(l2_weight_decay=0),
            train_data=DataConfig(
                # input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
                # Dataset Path
                # NOTE(review): the source formatting leaves it unclear
                # whether this DUTS train path was active or commented out;
                # it is treated as active here, mirroring validation_data.
                # Confirm.
                input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR, 'DUTS-TR-*'),
                is_training=True,
                global_batch_size=train_batch_size,
                # aug_scale_min=0.5,
                # aug_scale_max=2.0
            ),
            validation_data=DataConfig(
                input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL,
                                        'DUTS-TE-*'),
                is_training=False,
                global_batch_size=eval_batch_size,
            ),
            # init_checkpoint='',
            # init_checkpoint_modules='backbone'
        ),
        trainer=cfg.TrainerConfig(
            steps_per_loop=steps_per_epoch,
            summary_interval=steps_per_epoch,
            checkpoint_interval=steps_per_epoch,
            train_steps=500 * steps_per_epoch,  # (gunho) more epochs
            # No validation in BASNet
            validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
            validation_interval=steps_per_epoch,
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'adam',  # BASNet
                    'adam': {
                        'beta_1': 0.9,
                        'beta_2': 0.999,
                        'epsilon': 1e-8,
                    }
                },
                'learning_rate': {
                    'type': 'stepwise',
                    'stepwise': {
                        'boundaries': [
                            70 * steps_per_epoch,
                            100 * steps_per_epoch,
                            150 * steps_per_epoch
                        ],
                        # NOTE(review): 0.001 is listed twice, so the
                        # boundary at epoch 100 does not change the learning
                        # rate. Confirm this is intended.
                        'values': [0.01, 0.001, 0.001, 0.0001]
                    }
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
    return config