Esempio n. 1
0
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10.

    Reshapes the incoming features to ``[-1, HEIGHT, WIDTH, NUM_CHANNELS]``
    and delegates to ``resnet_run_loop.resnet_model_fn`` with the CIFAR-10
    learning-rate schedule, a fixed weight decay of 2e-4, momentum 0.9, and a
    loss filter that keeps every variable.

    Args:
        features: Input images; reshaped here to
            ``[-1, HEIGHT, WIDTH, NUM_CHANNELS]`` before use.
        labels: Forwarded unchanged to ``resnet_model_fn``.
        mode: Forwarded unchanged to ``resnet_model_fn``.
        params: Mapping of hyperparameters. Reads ``batch_size``,
            ``resnet_size``, ``data_format``, ``resnet_version``,
            ``loss_scale``, ``dtype`` and ``fine_tune``; ``num_workers`` is
            optional and defaults to 1.

    Returns:
        Whatever ``resnet_run_loop.resnet_model_fn`` returns (presumably an
        Estimator spec -- not visible from here, TODO confirm).
    """
    images = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])

    # The schedule is sized by the global batch: the per-worker batch size
    # multiplied by the number of workers (1 when not specified).
    global_batch_size = params['batch_size'] * params.get('num_workers', 1)

    # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=global_batch_size,
        batch_denom=128,
        num_images=NUM_IMAGES['train'],
        boundary_epochs=[91, 136, 182],
        decay_rates=[1, 0.1, 0.01, 0.001],
    )

    # 2e-4 deliberately differs from the 1e-4 used in the ResNet paper; it
    # seemed more stable in testing and the difference was nominal for
    # ResNet-56.
    l2_weight_decay = 2e-4

    def _keep_every_variable(_):
        # Batch-normalization variables are regularized too: empirically this
        # helped validation accuracy on CIFAR-10, perhaps by curbing
        # overfitting on the small data set.
        return True

    return resnet_run_loop.resnet_model_fn(
        features=images,
        labels=labels,
        mode=mode,
        model_class=Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=l2_weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=_keep_every_variable,
        dtype=params['dtype'],
        fine_tune=params['fine_tune'],
    )
Esempio n. 2
0
def imagenet_model_fn(features, labels, mode, params):
    """Estimator model_fn for ResNet trained on ImageNet.

    Builds an ImageNet learning-rate schedule and delegates to
    ``resnet_run_loop.resnet_model_fn``. Weight decay and label smoothing
    come from the global ``flags.FLAGS`` rather than from ``params``.

    Args:
        features: Forwarded unchanged to ``resnet_model_fn``.
        labels: Forwarded unchanged to ``resnet_model_fn``.
        mode: Forwarded unchanged to ``resnet_model_fn``.
        params: Mapping of hyperparameters. Reads ``fine_tune``,
            ``batch_size``, ``resnet_size``, ``data_format``,
            ``resnet_version``, ``loss_scale`` and ``dtype``;
            ``num_workers`` is optional and defaults to 1.

    Returns:
        Whatever ``resnet_run_loop.resnet_model_fn`` returns (presumably an
        Estimator spec -- not visible from here, TODO confirm).
    """
    fine_tune = params['fine_tune']

    # Warmup and the higher base learning rate may not be valid when fine
    # tuning with small batches and fewer training images, so fine tuning
    # disables warmup and uses the lower rate.
    warmup, base_lr = (False, 0.1) if fine_tune else (True, 0.128)

    # The schedule is sized by the global batch: the per-worker batch size
    # multiplied by the number of workers (1 when not specified).
    global_batch_size = params['batch_size'] * params.get('num_workers', 1)

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=global_batch_size,
        batch_denom=256,
        num_images=NUM_IMAGES['train'],
        boundary_epochs=[30, 60, 80, 90],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
        warmup=warmup,
        base_lr=base_lr,
    )

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=flags.FLAGS.weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        # None is passed through as-is; presumably resnet_model_fn then applies
        # its own default variable filter -- TODO confirm.
        loss_filter_fn=None,
        dtype=params['dtype'],
        fine_tune=fine_tune,
        label_smoothing=flags.FLAGS.label_smoothing,
    )