def imagenet_model_fn(features, labels, mode, params):
  """Construct the ResNet ImageNet model function consumed by the Estimator.

  Args:
    features: Input batch, forwarded unchanged to the shared model_fn.
    labels: Labels, forwarded unchanged to the shared model_fn.
    mode: Estimator mode, forwarded unchanged to the shared model_fn.
    params: Hyperparameter mapping. Keys read here: 'fine_tune',
      'batch_size', 'resnet_size', 'data_format', 'resnet_version',
      'loss_scale', 'dtype'.

  Returns:
    Whatever `resnet_run_loop.resnet_model_fn` returns for these arguments.
  """
  is_fine_tuning = params['fine_tune']

  # Fine-tuning usually means small batches and few training images, where
  # learning-rate warmup and the larger base rate may not be appropriate.
  use_warmup = not is_fine_tuning
  initial_lr = .1 if is_fine_tuning else .128

  schedule = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
      warmup=use_warmup,
      base_lr=initial_lr)

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=schedule,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=None,
      dtype=params['dtype'],
      fine_tune=is_fine_tuning,
  )
示例#2
0
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10.

    Reshapes the incoming features, builds a stepwise learning-rate schedule
    and delegates to ``resnet_run_loop.resnet_model_fn``.

    Args:
        features: Raw input batch; reshaped here to
            ``[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]``.
        labels: Labels, passed through unchanged.
        mode: Estimator mode, passed through unchanged.
        params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
            'data_format', 'version' and 'multi_gpu'.

    Returns:
        The result of ``resnet_run_loop.resnet_model_fn``.
    """
    image_shape = [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]
    images = tf.reshape(features, image_shape)

    lr_schedule = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=128,
        num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
        decay_rates=[1, 0.1, 0.01, 0.001])

    def regularize_every_variable(_):
        # Batch-norm variables are deliberately kept in the regularized loss:
        # empirical testing showed better CIFAR-10 validation accuracy,
        # perhaps because it limits overfitting on the small dataset.
        return True

    return resnet_run_loop.resnet_model_fn(
        images, labels, mode, Cifar10Model,
        resnet_size=params['resnet_size'],
        # 2e-4 performed better than the originally suggested 1e-4.
        weight_decay=2e-4,
        learning_rate_fn=lr_schedule,
        momentum=0.9,
        data_format=params['data_format'],
        version=params['version'],
        loss_filter_fn=regularize_every_variable,
        multi_gpu=params['multi_gpu'])
示例#3
0
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  Args:
    features: Raw input batch; reshaped here to
      [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS] before use.
    labels: Labels passed through to the shared ResNet model_fn.
    mode: Estimator mode passed through unchanged.
    params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
      'data_format', 'version' and 'multi_gpu'.

  Returns:
    The result of `resnet_run_loop.resnet_model_fn`.
  """
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  boundaries = [100, 150, 200]
  multipliers = [1, 0.1, 0.01, 0.001]
  lr_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'], boundary_epochs=boundaries,
      decay_rates=multipliers)

  # A weight decay of 2e-4 performs better than the originally suggested 1e-4.
  decay = 2e-4

  def keep_all_variables(_):
    # Include batch-norm variables in the regularized loss too; empirically
    # this helped CIFAR-10 validation accuracy (likely less overfitting on
    # the small dataset).
    return True

  return resnet_run_loop.resnet_model_fn(
      inputs, labels, mode, Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=decay,
      learning_rate_fn=lr_fn,
      momentum=0.9,
      data_format=params['data_format'],
      version=params['version'],
      loss_filter_fn=keep_all_variables,
      multi_gpu=params['multi_gpu'])
示例#4
0
def imagenet_model_fn(features, labels, mode, params):
  """Build the ResNet ImageNet model_fn used by the Estimator.

  Args:
    features: Input batch, forwarded unchanged.
    labels: Labels, forwarded unchanged.
    mode: Estimator mode, forwarded unchanged.
    params: Hyperparameter mapping; reads 'fine_tune', 'batch_size',
      optional 'num_workers' (default 1), 'resnet_size', 'data_format',
      'resnet_version', 'loss_scale' and 'dtype'. Weight decay and label
      smoothing come from `flags.FLAGS` instead.

  Returns:
    The result of `resnet_run_loop.resnet_model_fn`.
  """
  fine_tune = params['fine_tune']

  # Fine-tuning typically uses small batches and fewer images, where warmup
  # and the higher base learning rate may not be valid.
  warmup, base_lr = (False, .1) if fine_tune else (True, .128)

  # The per-replica batch size is multiplied by the worker count before it
  # reaches the schedule (presumably to reflect the global batch size).
  total_batch = params['batch_size'] * params.get('num_workers', 1)
  lr_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=total_batch,
      batch_denom=256,
      num_images=NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
      warmup=warmup,
      base_lr=base_lr)

  cli_flags = flags.FLAGS
  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=cli_flags.weight_decay,
      learning_rate_fn=lr_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=None,
      dtype=params['dtype'],
      fine_tune=fine_tune,
      label_smoothing=cli_flags.label_smoothing,
  )
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator.

    When ``params['mlcomm'] == 1`` the learning rate comes from
    ``resnet_run_loop.learning_rate_warmup_poly_decay``. It starts at
    ``params['init_lr']``, reaches ``params['base_lr']`` after
    ``params['warmup_epochs']`` epochs, and is then decayed over the remaining
    ``train_epochs - warmup_epochs`` epochs (presumably polynomially, per the
    helper's name). Otherwise the piecewise schedule from
    ``resnet_run_loop.learning_rate_with_decay`` is used.

    Args:
        features: Input batch, forwarded unchanged.
        labels: Labels, forwarded unchanged.
        mode: Estimator mode, forwarded unchanged.
        params: Hyperparameter mapping. Always read: 'mlcomm',
            'weight_decay', 'batch_size', 'resnet_size', 'data_format',
            'version', 'log_freq', 'loss_scale', 'multi_gpu', 'dtype'.
            Read only when mlcomm == 1: 'base_lr', 'init_lr',
            'warmup_epochs', 'train_epochs'.

    Returns:
        The result of ``resnet_run_loop.resnet_model_fn``.
    """
    mlcomm = params['mlcomm']
    wd = params['weight_decay']

    # CRAY MODIFIED
    if mlcomm == 1:
        # The warmup hyperparameters are only needed on this path, so they are
        # read here; the non-mlcomm path does not require them in params.
        base_lr = params['base_lr']
        # Learning rate at the beginning of warmup; reaches base_lr after
        # warmup_epochs.
        init_lr = params['init_lr']
        nwarmup = params['warmup_epochs']
        ndecay = params['train_epochs'] - nwarmup

        # Earlier tuning notes for learning_rate_base by global batch size
        # (gbs): 256 -> 4, 512 -> 5, 1024 -> 9, 4096 -> 14,
        # 32k (1024 workers) -> 40.
        learning_rate_fn = resnet_run_loop.learning_rate_warmup_poly_decay(
            batch_size=params['batch_size'],
            num_images=_NUM_IMAGES['train'],
            learning_rate_0=init_lr,
            learning_rate_base=base_lr,
            decay_epochs=ndecay,
            warmup_epochs=nwarmup,
            mlcomm=mlcomm)
    # END CRAY MODIFIED
    else:
        learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
            batch_size=params['batch_size'],
            batch_denom=256,
            num_images=_NUM_IMAGES['train'],
            boundary_epochs=[30, 60, 80, 90],
            decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=wd,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        version=params['version'],
        batch_size=params['batch_size'],
        log_freq=params['log_freq'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=None,
        multi_gpu=params['multi_gpu'],
        dtype=params['dtype'],
        mlcomm=mlcomm,
    )
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10 with a modified learning-rate schedule.

    The schedule no longer follows arXiv:1512.03385. The upstream settings
    were batch_denom=128, boundary_epochs=[91, 136, 182] and
    decay_rates=[1, 0.1, 0.01, 0.001]. They were replaced by batch_denom=1280,
    boundary_epochs=[183, 184, 185] and decay_rates=[1, 1, 1, 1]. Because every
    decay factor is 1, the rate should stay constant throughout training
    (assuming learning_rate_with_decay multiplies the base rate by these
    factors).

    Args:
        features: Raw input batch; reshaped to
            ``[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]``.
        labels: Labels, passed through unchanged.
        mode: Estimator mode, passed through unchanged.
        params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
            'data_format', 'resnet_version', 'loss_scale', 'dtype',
            'fine_tune'.

    Returns:
        The result of ``resnet_run_loop.resnet_model_fn``.
    """
    images = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

    lr_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=1280,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[183, 184, 185],
        decay_rates=[1, 1, 1, 1],
    )

    # 2e-4 rather than the ResNet paper's 1e-4: it seemed more stable in
    # testing, and the difference was nominal for ResNet-56.
    decay_strength = 2e-4

    def include_everything(_):
        # Regularize all variables, batch-norm ones included; empirically this
        # helped CIFAR-10 validation accuracy, likely by curbing overfitting.
        return True

    return resnet_run_loop.resnet_model_fn(
        features=images,
        labels=labels,
        mode=mode,
        model_class=Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=decay_strength,
        learning_rate_fn=lr_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=include_everything,
        dtype=params['dtype'],
        fine_tune=params['fine_tune'],
    )
示例#7
0
def imagenet_model_fn(features, labels, mode, params):
  """Build the ImageNet ResNet model_fn for the Estimator.

  Args:
    features: Input batch, forwarded unchanged.
    labels: Labels, forwarded unchanged.
    mode: Estimator mode, forwarded unchanged.
    params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
      'data_format', 'version' and 'multi_gpu'.

  Returns:
    The result of `resnet_run_loop.resnet_model_fn`.
  """
  schedule_kwargs = dict(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
  lr_fn = resnet_run_loop.learning_rate_with_decay(**schedule_kwargs)

  return resnet_run_loop.resnet_model_fn(
      features, labels, mode, ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=lr_fn,
      momentum=0.9,
      data_format=params['data_format'],
      version=params['version'],
      loss_filter_fn=None,
      multi_gpu=params['multi_gpu'])
示例#8
0
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator.

    Args:
        features: Input batch, forwarded unchanged.
        labels: Labels, forwarded unchanged.
        mode: Estimator mode, forwarded unchanged.
        params: Hyperparameter mapping; reads 'fine_tune', 'batch_size',
            'enable_lars', 'resnet_size', 'weight_decay', 'data_format',
            'version', 'loss_scale', 'dtype' and 'label_smoothing'.

    Returns:
        The result of ``resnet_run_loop.resnet_model_fn``.
    """
    print("SSY imagenet_model_fn")

    # A higher lr may not be valid for fine tuning with small batches and
    # smaller numbers of training images, so fine-tuning uses 0.1 (as the
    # upstream model_fn does). Regular training keeps the module-level
    # _BASE_LR. Previously this value was computed but never passed, so
    # fine_tune had no effect on the rate. Warmup is not passed here, so
    # learning_rate_with_decay's own default applies.
    base_lr = .1 if params['fine_tune'] else _BASE_LR

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=256,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[30, 60, 80, 90],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
        base_lr=base_lr,
        enable_lars=params['enable_lars'])
    print("SSY creating class")
    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=params['weight_decay'],
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        version=params['version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=None,
        dtype=params['dtype'],
        label_smoothing=params['label_smoothing'],
        enable_lars=params['enable_lars'])
示例#9
0
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  The learning-rate schedule follows arXiv:1512.03385 for ResNet-56 and
  smaller: boundaries at epochs 91, 136 and 182.

  Args:
    features: Raw input batch; reshaped to [-1, HEIGHT, WIDTH, NUM_CHANNELS].
    labels: Labels, passed through unchanged.
    mode: Estimator mode, passed through unchanged.
    params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
      'data_format', 'resnet_version', 'loss_scale', 'dtype', 'fine_tune'.

  Returns:
    The result of `resnet_run_loop.resnet_model_fn`.
  """
  reshaped = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])

  schedule = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=128,
      num_images=NUM_IMAGES['train'],
      boundary_epochs=[91, 136, 182],
      decay_rates=[1, 0.1, 0.01, 0.001])

  def filter_nothing(_):
    # Every variable, batch-norm included, is regularized; empirically this
    # improved CIFAR-10 validation accuracy on the small dataset.
    return True

  return resnet_run_loop.resnet_model_fn(
      features=reshaped,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      # 2e-4 diverges from the paper's 1e-4 but seemed more stable in
      # testing; the difference was nominal for ResNet-56.
      weight_decay=2e-4,
      learning_rate_fn=schedule,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=filter_nothing,
      dtype=params['dtype'],
      fine_tune=params['fine_tune'],
  )
示例#10
0
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10 with a scheduled dropout probability.

    The dropout schedule is chosen by the second command-line argument
    (``sys.argv[2]``). ``'linear'`` selects
    ``resnet_run_loop.dropout_prob_with_decay_linear``. Any other value,
    including a missing argument, selects
    ``resnet_run_loop.dropout_prob_with_decay_piece_wise``.

    Args:
        features: Raw input batch; reshaped to
            ``[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]``.
        labels: Labels, passed through unchanged.
        mode: Estimator mode, passed through unchanged.
        params: Hyperparameter mapping; reads 'batch_size', 'resnet_size',
            'data_format', 'resnet_version', 'loss_scale', 'dtype',
            'fine_tune' and 'dropout'.

    Returns:
        The result of ``resnet_run_loop.resnet_model_fn``.
    """
    features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
    print("BATCH SIZE", params['batch_size'])

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=BATCH_SIZE,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[100, 150, 200],
        decay_rates=[1, 0.1, 0.01, 0.001])

    # Guard the index: sys.argv[2] raised IndexError when the script was
    # launched with fewer than two arguments. A missing argument falls through
    # to the piecewise schedule, just like any non-'linear' value.
    schedule_name = sys.argv[2] if len(sys.argv) > 2 else None

    if schedule_name == 'linear':
        # Goes from base_prob=1 to end_prob=0.95 over 200 epochs.
        dropout_prob_fn = resnet_run_loop.dropout_prob_with_decay_linear(
            batch_size=params['batch_size'],
            batch_denom=BATCH_SIZE,
            num_images=_NUM_IMAGES['train'],
            decay_epochs=200,
            base_prob=1,
            end_prob=0.95)
    else:
        # Factors [1, 0.98, 0.96, 0.94, 0.92, 0.90] applied to base_prob=1,
        # switching at epochs [40, 70, 100, 150, 200].
        dropout_prob_fn = resnet_run_loop.dropout_prob_with_decay_piece_wise(
            batch_size=params['batch_size'],
            batch_denom=BATCH_SIZE,
            num_images=_NUM_IMAGES['train'],
            boundary_epochs=[40, 70, 100, 150, 200],
            base_prob=1,
            decay_rates=[1, 0.98, 0.96, 0.94, 0.92, 0.90])

    print(dropout_prob_fn)

    # We use a weight decay of 0.0002, which performs better
    # than the 0.0001 that was originally suggested.
    weight_decay = 2e-4

    # Empirical testing showed that including batch_normalization variables
    # in the calculation of regularized loss helped validation accuracy
    # for the CIFAR-10 dataset, perhaps because the regularization prevents
    # overfitting on the small data set. We therefore include all vars when
    # regularizing and computing loss during training.
    def loss_filter_fn(_):
        return True

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        learning_rate_fn=learning_rate_fn,
        dropout_prob_fn=dropout_prob_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=loss_filter_fn,
        dtype=params['dtype'],
        fine_tune=params['fine_tune'],
        dropout=params['dropout'])