Example #1
0
def cifar_10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10, delegating to `resnet_run_loop.resnet_model_fn`.

    Args:
        features: Input images; reshaped to
            `[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]` before use.
        labels: Forwarded unchanged to `resnet_model_fn`.
        mode: Forwarded unchanged to `resnet_model_fn`.
        params: Dict read for 'batch_size', 'resnet_size', 'data_format',
            'resnet_version', 'loss_scale' and 'dtype'.

    Returns:
        Whatever `resnet_run_loop.resnet_model_fn` returns.
    """
    features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

    # Step schedule with boundaries at epochs 100, 150 and 200; decay_rates
    # presumably are the LR multipliers for each segment -- confirm in
    # resnet_run_loop.learning_rate_with_decay.
    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=128,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=[100, 150, 200],
        decay_rates=[1, 0.1, 0.01, 0.001])
    weight_decay = 2e-4

    # Accepts every variable name, so (presumably) all variables, including
    # batch-norm ones, contribute to the regularization loss.
    def loss_filter_fn(_):
        return True

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=Cifar10Model,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=loss_filter_fn,
        dtype=params['dtype'])
Example #2
0
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  Reshapes `features` into image tensors, builds a stepwise learning-rate
  schedule and hands everything to `resnet_run_loop.resnet_model_fn`.
  `params` is read for 'batch_size', 'resnet_size', 'data_format',
  'version', 'multi_gpu' and 'select_device'.
  """
  images = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  # Boundaries at epochs 100 / 150 / 200 with decreasing multipliers.
  schedule = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # 2e-4 was found to work better than the originally suggested 1e-4.
  l2_decay = 2e-4

  # Regularize every variable (batch-norm ones included): empirically this
  # improved validation accuracy on the small CIFAR-10 training set.
  def include_every_var(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      images, labels, mode, Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=l2_decay,
      learning_rate_fn=schedule,
      momentum=0.9,
      data_format=params['data_format'],
      version=params['version'],
      loss_filter_fn=include_every_var,
      multi_gpu=params['multi_gpu'],
      select_device=params['select_device'])
Example #3
0
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator.

    Chooses a learning-rate schedule depending on whether we fine-tune
    (`params['fine_tune']`) and, when fine-tuning, on
    `params['no_dense_init']`, then delegates to
    `resnet_run_loop.resnet_model_fn` with the remaining `params` entries.
    """
    fine_tune = params['fine_tune']
    if fine_tune:
        # Warmup and a high LR may not be valid when fine-tuning with small
        # batches and fewer training images.
        warmup = False
        base_lr = 0.1 if params['no_dense_init'] else 0.01
        boundary_ep = [4, 8, 10, 12]
    else:
        warmup = True
        base_lr = .128
        boundary_ep = [30, 60, 80, 90]

    lr_schedule = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=256,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=boundary_ep,
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
        warmup=warmup,
        base_lr=base_lr)

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=1e-4,
        learning_rate_fn=lr_schedule,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=None,
        dtype=params['dtype'],
        fine_tune=fine_tune,
        reconst_loss_scale=params['reconst_loss_scale'],
        use_ce=params['use_ce'],
        opt_chos=params['optimizer'],
        clip_grad=params['clip_grad'],
        spectral_norm=params['spectral_norm'],
        ce_scale=params['ce_scale'],
        # NOTE: 'sep_grad_nrom' (sic) is the spelling both the params key and
        # the resnet_model_fn keyword use; do not "fix" one side only.
        sep_grad_nrom=params['sep_grad_nrom'],
        norm_teach_feature=params['norm_teach_feature'],
        compress_ratio=params['compress_ratio'],
    )
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  The learning-rate decay schedule is driven by two flags:
    * `flags.FLAGS.decay_steps`: spacing between decays, as a percentage of
      the 250-epoch run. 0 or 100 disables decay (the default schedule keeps
      a constant multiplier of 1).
    * `flags.FLAGS.decay_rate`: multiplier applied cumulatively at each
      boundary.

  Args:
    features: Input images; reshaped to `[-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]`.
    labels: Forwarded unchanged to `resnet_model_fn`.
    mode: Forwarded unchanged to `resnet_model_fn`.
    params: Dict read for 'batch_size', 'resnet_size', 'data_format',
      'resnet_version', 'loss_scale' and 'dtype'.

  Returns:
    Whatever `resnet_run_loop.resnet_model_fn` returns.
  """
  import math

  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

  total_epochs = 250  # Default training length the schedule is built for.

  # Default: no decay -- one boundary at the end and a multiplier of 1 on
  # both sides of it.
  boundary_epochs = [total_epochs]
  decay_rates = [1, 1]

  decay_steps = flags.FLAGS.decay_steps
  if decay_steps != 0 and decay_steps != 100:  # Overwrite the default.
    # 100.0 forces float division: with an integer flag under Python 2,
    # `100 / decay_steps` floors before `ceil`, which silently dropped the
    # last boundary (e.g. decay_steps=30 gave 2 boundaries instead of 3).
    num_boundaries = int(math.ceil(100.0 / decay_steps)) - 1
    decay_epochs = total_epochs * decay_steps / 100.0
    boundary_epochs = []
    decay_rates = [1]
    for i in range(num_boundaries):
      decay_rates.append(flags.FLAGS.decay_rate * decay_rates[i])
      boundary_epochs.append(decay_epochs * (i + 1))

  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'], boundary_epochs=boundary_epochs,
      decay_rates=decay_rates)

  # NOTE(review): an earlier comment said the effective weight decay is
  # modified inside resnet_model.py -- confirm this value is actually used.
  weight_decay = 2e-4

  # Regularize every variable (batch-norm ones included): empirically this
  # helped validation accuracy on the small CIFAR-10 data set.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype']
  )
Example #5
0
def model_fn(features, labels, mode, params):
    """Estimator model_fn that delegates to `resnet_run_loop.resnet_model_fn`.

    Uses a fixed step schedule (boundaries at epochs 30/60/80/90) and a
    weight decay of 1e-4. `params` is read for 'batch_size', 'resnet_size',
    'data_format', 'version' and 'multi_gpu'.
    """
    boundaries = [30, 60, 80, 90]
    multipliers = [1, 0.1, 0.01, 0.001, 1e-4]
    lr_schedule = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=256,
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=boundaries,
        decay_rates=multipliers)

    return resnet_run_loop.resnet_model_fn(
        features, labels, mode, Model,
        resnet_size=params['resnet_size'],
        weight_decay=1e-4,
        learning_rate_fn=lr_schedule,
        momentum=0.9,
        data_format=params['data_format'],
        version=params['version'],
        loss_filter_fn=None,
        multi_gpu=params['multi_gpu'])
Example #6
0
def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator.

  Builds a step learning-rate schedule (boundaries at epochs 30/60/80/90)
  and delegates to `resnet_run_loop.resnet_model_fn`. `params` is read for
  'batch_size', 'resnet_size', 'data_format', 'resnet_version',
  'loss_scale' and 'dtype'.
  """
  lr_schedule = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  resnet_kwargs = dict(
      features=features,
      labels=labels,
      mode=mode,
      model_class=ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=lr_schedule,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=None,
      dtype=params['dtype'],
  )
  return resnet_run_loop.resnet_model_fn(**resnet_kwargs)
Example #7
0
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10.

    Reshapes `features` to `[-1, HEIGHT, WIDTH, NUM_CHANNELS]`, builds the
    learning-rate schedule and delegates to `resnet_run_loop.resnet_model_fn`.
    Weight decay and `size_factor` come from `params`.
    """
    images = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])

    # Schedule follows arXiv:1512.03385 (ResNet-56 and smaller). The batch
    # size is multiplied by the number of workers (default 1).
    global_batch = params['batch_size'] * params.get('num_workers', 1)
    lr_schedule = resnet_run_loop.learning_rate_with_decay(
        batch_size=global_batch,
        batch_denom=128,
        num_images=NUM_IMAGES['train'],
        boundary_epochs=[91, 136, 182],
        decay_rates=[1, 0.1, 0.01, 0.001])

    # Regularize every variable (batch-norm ones included): empirically this
    # helped validation accuracy on the small CIFAR-10 data set.
    def include_every_var(_):
        return True

    return resnet_run_loop.resnet_model_fn(
        features=images,
        labels=labels,
        mode=mode,
        size_factor=params['size_factor'],  # *SC*
        model_class=Cifar10Model,
        resnet_size=params['resnet_size'],
        # Weight decay is configurable via params (upstream hard-coded 2e-4).
        weight_decay=params['weight_decay'],
        learning_rate_fn=lr_schedule,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=include_every_var,
        dtype=params['dtype'],
        fine_tune=params['fine_tune'])
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator.

    Warmup is enabled only when not fine-tuning. Weight decay and label
    smoothing are read from `flags.FLAGS`; everything else comes from `params`.
    """
    fine_tune = params['fine_tune']
    warmup = not fine_tune
    # NOTE(review): fine-tuning gets the *larger* base LR (.1) while training
    # from scratch gets 1e-4, which looks inverted relative to the usual
    # "lower LR for fine-tuning" practice -- confirm this is intentional.
    base_lr = .1 if fine_tune else 1e-4

    global_batch = params['batch_size'] * params.get('num_workers', 1)
    lr_schedule = resnet_run_loop.learning_rate_with_decay(
        batch_size=global_batch,
        batch_denom=256,
        num_images=NUM_IMAGES['train'],
        boundary_epochs=[10, 20, 25, 50],
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
        warmup=warmup,
        base_lr=base_lr)

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModel,
        resnet_size=params['resnet_size'],
        weight_decay=flags.FLAGS.weight_decay,
        learning_rate_fn=lr_schedule,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=None,
        dtype=params['dtype'],
        fine_tune=params['fine_tune'],
        label_smoothing=flags.FLAGS.label_smoothing)
Example #9
0
def imagenet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator.

    Builds the learning-rate schedule (shortened when online quantization is
    enabled), resolves the teacher/pickle checkpoint path, and delegates to
    `resnet_run_loop.resnet_model_fn`.

    `params` is only read, never modified. Previously the rewritten
    'pickle_model' path was written back into the caller's dict.
    """
    # Base LR is tuned per batch size, e.g. 64 -> 0.128, 100 -> 0.15,
    # 128 -> 0.18 (batch_denom below equals the batch size, so the LR is not
    # rescaled by batch size).
    base_lr = params['learn_rate']

    if params['enable_quantize'] and params['online_quantize']:
        # Online quantization uses a shorter two-boundary schedule.
        boundary_epochs = [10, 20]
        decay_rates = [1, 0.1, 1e-2]
    else:
        boundary_epochs = [30, 60, 80, 90]
        decay_rates = [1, 0.1, 0.01, 0.001, 1e-4]

    # For "oss://" paths keep only the file name and rebuild the path from
    # the module-level `oss_bucket_root` and `_BUCKET_DIR`.
    pickle_model = params['pickle_model']
    if pickle_model.startswith("oss://"):
        pickle_model = (oss_bucket_root + _BUCKET_DIR + '/' +
                        os.path.basename(pickle_model))

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'],
        batch_denom=params['batch_size'],
        num_images=_NUM_IMAGES['train'],
        boundary_epochs=boundary_epochs,
        decay_rates=decay_rates,
        base_lr=base_lr,
        train_epochs=params['train_epochs'],
        enable_lars=params['enable_lars'],
        enable_cos=params['enable_cos'],
        cos_alpha=params['cos_alpha'],
        warm_up=params['warm_up'])

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=ImagenetModelGap if params['gap_ft'] else ImagenetModel,
        model_teach=ImagenetModelTeach,
        resnet_size=params['resnet_size'],
        random_init=params['random_init'],
        final_size=params['final_size'],
        pickle_model=pickle_model,
        weight_decay=params['weight_decay'],
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        version=params['version'],
        version_t=params['version_t'],
        loss_scale=params['loss_scale'],
        gap_train=params['gap_train'],
        gap_lambda=params['gap_lambda'],
        gap_ft=params['gap_ft'],
        loss_filter_fn=None,
        dtype=params['dtype'],
        label_smoothing=params['label_smoothing'],
        enable_lars=params['enable_lars'],
        enable_kd=params['enable_kd'],
        kd_size=params['kd_size'],
        temp_dst=params['temp_dst'],
        w_dst=params['w_dst'],
        mix_up=params['mix_up'],
        mx_mode=params['mx_mode'],
        enable_at=params['enable_at'],
        w_at=params['w_at'])