def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=256,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  return resnet_model_fn(features, labels, mode, Imagenet16Model,
                         resnet_size=params['resnet_size'],
                         weight_decay=1e-4,
                         learning_rate_fn=learning_rate_fn,
                         momentum=0.9,
                         data_format=params['data_format'],
                         version=params['version'],
                         loss_filter_fn=None,
                         multi_gpu=params['multi_gpu'])
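# The schedule produced by learning_rate_with_decay is assumed to be
# piecewise-constant: the base rate is scaled by batch_size / batch_denom,
# then multiplied by the decay rate for the epoch region the current step
# falls into (boundary_epochs defines the region edges, so there is one more
# decay rate than boundary). A minimal epoch-level sketch of that assumed
# behavior, not the library implementation:
def piecewise_learning_rate(epoch, batch_size, batch_denom=256,
                            boundary_epochs=(30, 60, 80, 90),
                            decay_rates=(1, 0.1, 0.01, 0.001, 1e-4),
                            base_lr=0.1):
  # Scale the base rate linearly with the global batch size.
  initial_lr = base_lr * batch_size / batch_denom
  for boundary, rate in zip(boundary_epochs, decay_rates):
    if epoch < boundary:
      return initial_lr * rate
  return initial_lr * decay_rates[-1]

# Usage: with batch_size=256 the rate steps down at each boundary epoch.
print([piecewise_learning_rate(e, 256) for e in (0, 30, 60, 80, 95)])
# [0.1, 0.01, 0.001, 0.0001, 1e-05]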
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
  # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'] * params.get('num_workers', 1),
      batch_denom=128, num_images=NUM_IMAGES['train'],
      boundary_epochs=[91, 136, 182], decay_rates=[1, 0.1, 0.01, 0.001])

  # Weight decay of 2e-4 diverges from the 1e-4 decay used in the ResNet
  # paper but seems more stable in testing. The difference was nominal for
  # ResNet-56.
  weight_decay = 2e-4

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype'],
      fine_tune=params['fine_tune'])
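# The loss_filter_fn above returns True for every variable, so batch
# normalization parameters are also weight-decayed on CIFAR-10. The sketch
# below illustrates how such a filter is assumed to be applied when building
# the L2 penalty (called with each trainable variable's name); the helper
# and variable names are hypothetical, not taken from resnet_run_loop.
def l2_penalty(variables, weight_decay, loss_filter_fn):
  # tf.nn.l2_loss(w) computes sum(w^2) / 2; mirror that here in plain Python.
  return weight_decay * sum(
      sum(w * w for w in values) / 2.0
      for name, values in variables.items()
      if loss_filter_fn(name))

# An ImageNet-style default filter would typically skip batch norm variables,
# whereas the CIFAR-10 filter above keeps them:
def exclude_batch_norm(name):
  return 'batch_normalization' not in name

weights = {'conv2d/kernel': [0.5, -0.5], 'batch_normalization/gamma': [1.0]}
print(l2_penalty(weights, 2e-4, exclude_batch_norm))   # 5e-05, BN excluded
print(l2_penalty(weights, 2e-4, lambda _: True))       # 0.00015, BN included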
def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""

  # Warmup and higher lr may not be valid for fine tuning with small batches
  # and smaller numbers of training images.
  if params['fine_tune']:
    warmup = False
    base_lr = .1
  else:
    warmup = True
    base_lr = .128

  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'] * params.get('num_workers', 1),
      batch_denom=256, num_images=NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
      warmup=warmup, base_lr=base_lr)

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=flags.FLAGS.weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=None,
      dtype=params['dtype'],
      fine_tune=params['fine_tune'],
      label_smoothing=flags.FLAGS.label_smoothing)
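# When warmup=True, the schedule is assumed to ramp linearly from zero up to
# the scaled base rate before the piecewise decay takes over, which is why
# it is disabled for fine tuning with small batches. A minimal sketch under
# that assumption; the 5-epoch ramp length is hypothetical, not taken from
# learning_rate_with_decay.
def warmup_learning_rate(epoch, base_lr=0.128, batch_size=256,
                         batch_denom=256, warmup_epochs=5):
  peak_lr = base_lr * batch_size / batch_denom
  if epoch < warmup_epochs:
    # Linear ramp from peak_lr / warmup_epochs up to peak_lr.
    return peak_lr * (epoch + 1) / warmup_epochs
  return peak_lr  # afterwards the piecewise decay sketched earlier applies

print([round(warmup_learning_rate(e), 4) for e in range(7)])
# [0.0256, 0.0512, 0.0768, 0.1024, 0.128, 0.128, 0.128]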