def imagenet_model_fn(features, labels, mode, params):
    """Build the ImageNet ResNet model_fn used with the Estimator.

    The learning rate is decayed at epochs 30, 60, 80 and 90. When
    ``params['flag_bias']`` is truthy (bias-correction training) each of
    those boundaries is multiplied by ``_BIAS_EPOCHS``.

    Args:
      features: Input batch, forwarded unchanged to ``resnet.resnet_model_fn``.
      labels: Labels, forwarded unchanged to ``resnet.resnet_model_fn``.
      mode: Estimator mode, forwarded unchanged.
      params: Dict of hyper-parameters. Keys read here: ``flag_bias``,
        ``batch_size``, ``num_train_images``, ``itera``, ``resnet_size``,
        ``data_format``, ``nb_groups`` and ``restore_model_dir``.

    Returns:
      Whatever ``resnet.resnet_model_fn`` returns for these arguments.
    """
    # A single schedule serves both training phases; only the epoch
    # boundaries are stretched for bias-correction learning.
    epoch_scale = _BIAS_EPOCHS if params['flag_bias'] else 1
    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=256,
        num_images=params['num_train_images'],
        boundary_epochs=[epoch * epoch_scale for epoch in (30, 60, 80, 90)],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

    # Weight decay is largest at the first iteration and shrinks as
    # params['itera'] grows. This formula has been left unchanged for a
    # long time.
    # NOTE(review): `nb_groups` is resolved from the enclosing module scope
    # here, while params['nb_groups'] is what gets passed to the model below
    # -- confirm the two always agree.
    weight_decay = 1e-4 * nb_groups / (params['itera'] + 1)

    return resnet.resnet_model_fn(
        features, labels, mode, ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        itera=params['itera'],
        nb_groups=params['nb_groups'],
        restore_model_dir=params['restore_model_dir'],
        flag_bias=params['flag_bias'],
        loss_filter_fn=None)
def cifar10_model_fn(features, labels, mode, params):
    """Estimator model_fn for a ResNet trained on CIFAR-10.

    Reshapes the incoming features to ``[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]``,
    builds a step-wise learning-rate schedule and delegates everything else
    to ``resnet.resnet_model_fn``.

    Args:
      features: Input batch; reshaped before being handed to the model.
      labels: Labels, forwarded unchanged.
      mode: Estimator mode, forwarded unchanged.
      params: Dict of hyper-parameters. Keys read here: ``batch_size``,
        ``resnet_size`` and ``data_format``.

    Returns:
      Whatever ``resnet.resnet_model_fn`` returns for these arguments.
    """
    image_shape = [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]
    features = tf.reshape(features, image_shape)

    schedule_kwargs = {
        'batch_size': params['batch_size'],
        'batch_denom': 128,
        'num_images': _NUM_IMAGES['train'],
        'boundary_epochs': [100, 150, 200],
        'decay_rates': [1, 0.1, 0.01, 0.001],
    }
    learning_rate_fn = resnet.learning_rate_with_decay(**schedule_kwargs)

    # Batch-normalization variables are deliberately kept in the regularized
    # loss: empirically this helped validation accuracy on CIFAR-10, perhaps
    # because the extra regularization limits overfitting on the small data
    # set. The filter therefore accepts every variable name.
    def _accept_every_variable(name):
        return True

    return resnet.resnet_model_fn(
        features, labels, mode, Cifar10Model,
        resnet_size=params['resnet_size'],
        # 0.0002 performs better than the originally suggested 0.0001.
        weight_decay=2e-4,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        loss_filter_fn=_accept_every_variable)
def cifar10_model_fn(features, labels, mode, params):
    """Estimator model_fn for CIFAR-10 ResNet with the multi-GPU switch.

    The features are reshaped to ``[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]``, a
    piecewise learning-rate schedule is created, and the actual model
    construction is delegated to ``resnet.resnet_model_fn``.

    Args:
      features: Input batch; reshaped before use.
      labels: Labels, forwarded unchanged.
      mode: Estimator mode, forwarded unchanged.
      params: Dict of hyper-parameters. Keys read here: ``batch_size``,
        ``resnet_size``, ``data_format`` and ``multi_gpu``.

    Returns:
      Whatever ``resnet.resnet_model_fn`` returns for these arguments.
    """
    features = tf.reshape(
        features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

    # Learning rate drops by 10x at epochs 100, 150 and 200.
    lr_boundaries = [100, 150, 200]
    lr_decay_rates = [1, 0.1, 0.01, 0.001]
    learning_rate_fn = resnet.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=128,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=lr_boundaries,
        decay_rates=lr_decay_rates)

    # A weight decay of 0.0002 performs better than the 0.0001 that was
    # originally suggested.
    weight_decay = 2e-4

    # Every variable -- batch-normalization ones included -- takes part in
    # the regularized loss. Empirical testing showed this helped validation
    # accuracy on CIFAR-10, perhaps because it curbs overfitting on the
    # small data set.
    regularize_everything = lambda name: True  # noqa: E731

    return resnet.resnet_model_fn(
        features, labels, mode, Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        loss_filter_fn=regularize_everything,
        multi_gpu=params['multi_gpu'])
def imagenet_model_fn(features, labels, mode, params):
    """Estimator model_fn for a ResNet trained on ImageNet.

    Creates a learning-rate schedule with boundaries at epochs 30, 60, 80
    and 90, then delegates model construction to ``resnet.resnet_model_fn``
    with a fixed weight decay of 1e-4 and no loss filter.

    Args:
      features: Input batch, forwarded unchanged.
      labels: Labels, forwarded unchanged.
      mode: Estimator mode, forwarded unchanged.
      params: Dict of hyper-parameters. Keys read here: ``batch_size``,
        ``resnet_size`` and ``data_format``.

    Returns:
      Whatever ``resnet.resnet_model_fn`` returns for these arguments.
    """
    schedule_kwargs = {
        'batch_size': params['batch_size'],
        'batch_denom': 256,
        'num_images': _NUM_IMAGES['train'],
        'boundary_epochs': [30, 60, 80, 90],
        'decay_rates': [1, 0.1, 0.01, 0.001, 1e-4],
    }
    learning_rate_fn = resnet.learning_rate_with_decay(**schedule_kwargs)

    model_kwargs = {
        'resnet_size': params['resnet_size'],
        'weight_decay': 1e-4,
        'learning_rate_fn': learning_rate_fn,
        'momentum': 0.9,
        'data_format': params['data_format'],
        'loss_filter_fn': None,
    }
    return resnet.resnet_model_fn(
        features, labels, mode, ImagenetModel, **model_kwargs)
def cifar_model_fn(features, labels, mode, params):
    """Build the model_fn for the CIFAR-10 residual network.

    :param features: the image data; reshaped to
        ``[-1, _WIDTH, _HEIGHT, _CHANNELS]`` before use
    :param labels: the labels of the images
    :param mode: the mode (training, evaluation or prediction)
    :param params: dict of parameters; keys read here are ``batch_size``,
        ``resnet_size``, ``data_format`` and ``multi_gpu``
    :return: whatever ``resnet_model.resnet_model_fn`` returns
    """
    # NOTE(review): the reshape lists width before height; this only matters
    # if _WIDTH != _HEIGHT -- confirm against the input pipeline.
    target_shape = [-1, _WIDTH, _HEIGHT, _CHANNELS]
    features = tf.reshape(features, target_shape)

    # Schedule that decays the learning rate over time, with boundaries at
    # epochs 100, 150 and 200.
    schedule_kwargs = {
        'num_images': _NUM_IMAGES['train'],
        'batch_size': params['batch_size'],
        'boundary_epochs': [100, 150, 200],
        'decay_rates': [1, 0.1, 0.01, 0.001],
    }
    learning_rate_fn = resnet.learning_rate_with_decay(**schedule_kwargs)

    # Everything after the model class is passed positionally, in this order:
    # resnet size, weight decay, learning-rate fn, momentum, data format,
    # and the multi-GPU setting.
    # NOTE(review): params['multi_gpu'] lands in the 10th positional slot of
    # resnet_model_fn -- confirm that slot is the multi-GPU parameter in the
    # callee's signature.
    positional_settings = (
        params['resnet_size'],
        2e-4,
        learning_rate_fn,
        0.9,
        params['data_format'],
        params['multi_gpu'],
    )
    return resnet_model.resnet_model_fn(
        features, labels, mode, Cifar10Model, *positional_settings)