Example #1
0
  def _resnet_model_fn(features, labels, mode, params):
    """Resnet model body.

    Supports single-host training on one or more GPUs (or a single CPU tower
    when `num_gpus` is 0). Parameter placement follows one of two schemes,
    selected by the enclosing `variable_strategy`:
    1. 'CPU': the CPU is the parameter server and manages gradient updates.
    2. 'GPU': parameters are spread across all GPUs by a greedy
       load-balancing strategy, and the first GPU manages gradient updates.

    NOTE(review): this copy is truncated (it stops inside the `_tower_fn`
    call) and contains an unresolved merge-conflict marker, so it is not
    valid Python as shown.

    Args:
      features: a list of tensors, one for each tower
      labels: a list of tensors, one for each tower
      mode: ModeKeys.TRAIN or EVAL
      params: Hyperparameters suitable for tuning
    Returns:
      A EstimatorSpec object.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    weight_decay = params.weight_decay
    momentum = params.momentum

    tower_features = features
    tower_labels = labels
    tower_losses = []
    tower_gradvars = []
    tower_preds = []

    # One tower per GPU; a single CPU tower when there are no GPUs.
    if num_gpus == 0:
      num_devices = 1
      device_type = 'cpu'
    else:
      num_devices = num_gpus
      device_type = 'gpu'

    for i in range(num_devices):
      worker_device = '/{}:{}'.format(device_type, i)
      # NOTE(review): if variable_strategy is neither 'CPU' nor 'GPU',
      # device_setter is never assigned and tf.device() below fails.
      if variable_strategy == 'CPU':
        device_setter = cifar10_utils.local_device_setter(
            worker_device=worker_device)
      elif variable_strategy == 'GPU':
        device_setter = cifar10_utils.local_device_setter(
            ps_device_type='gpu',
            worker_device=worker_device,
            ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                num_gpus, tf.contrib.training.byte_size_load_fn))
      # Towers after the first reuse the variables created by tower 0.
      with tf.variable_scope('resnet', reuse=bool(i != 0)):
        with tf.name_scope('tower_%d' % i) as name_scope:
          with tf.device(device_setter):
            # NOTE(review): `data_format` is never assigned in the visible
            # part of this function; presumably it must come from an
            # enclosing scope — confirm. The HEAD side of the conflict also
            # indexes params as a mapping (params['num_layers']) while the
            # lines above use attribute access (params.weight_decay);
            # confirm which access style the params object supports.
            loss, gradvars, preds = _tower_fn(
<<<<<<< HEAD
                is_training, weight_decay, tower_features[i], tower_labels[i],
                data_format, params['num_layers'], params['batch_norm_decay'],
                params['batch_norm_epsilon'])
Example #2
0
    def _resnet_model_fn(features, labels, mode, params):
        """Resnet model body.

        Supports single-host training on one or more GPUs (or a single CPU
        tower when ``num_gpus`` is 0). Parameter placement follows the
        enclosing ``variable_strategy``:
        1. 'CPU': the CPU is the parameter server and manages gradient
           updates.
        2. 'GPU': parameters are spread across all GPUs by a greedy
           load-balancing strategy, and the first GPU manages gradient
           updates.

        Args:
          features: a list of tensors, one for each tower
          labels: a list of tensors, one for each tower
          mode: ModeKeys.TRAIN or EVAL
          params: Hyperparameters suitable for tuning
        Returns:
          A EstimatorSpec object.
        Raises:
          ValueError: if ``variable_strategy`` is neither 'CPU' nor 'GPU'.
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        weight_decay = params.weight_decay
        momentum = params.momentum

        tower_features = features
        tower_labels = labels
        tower_losses = []
        tower_gradvars = []
        tower_preds = []

        # channels first (NCHW) is normally optimal on GPU and channels last
        # (NHWC) on CPU. An explicit params.data_format overrides this default.
        data_format = params.data_format
        if not data_format:
            if num_gpus == 0:
                data_format = 'channels_last'
            else:
                data_format = 'channels_first'

        # Debug output of the resolved configuration.
        print("data_format: ", data_format)
        print("num_gpus: ", num_gpus)

        # One tower per GPU; a single CPU tower when there are no GPUs.
        if num_gpus == 0:
            num_devices = 1
            device_type = 'cpu'
        else:
            num_devices = num_gpus
            device_type = 'gpu'

        for i in range(num_devices):
            worker_device = '/{}:{}'.format(device_type, i)
            if variable_strategy == 'CPU':
                device_setter = cifar10_utils.local_device_setter(
                    worker_device=worker_device)
            elif variable_strategy == 'GPU':
                device_setter = cifar10_utils.local_device_setter(
                    ps_device_type='gpu',
                    worker_device=worker_device,
                    ps_strategy=tf.contrib.training.
                    GreedyLoadBalancingStrategy(
                        num_gpus, tf.contrib.training.byte_size_load_fn))
            else:
                # Fail loudly instead of hitting an unbound device_setter below.
                raise ValueError('Unrecognized variable_strategy: {!r}'.format(
                    variable_strategy))
            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope('resnet', reuse=bool(i != 0)):
                with tf.name_scope('tower_%d' % i) as name_scope:
                    with tf.device(device_setter):
                        loss, gradvars, preds = _tower_fn(
                            is_training, weight_decay, tower_features[i],
                            tower_labels[i], data_format, params.num_layers,
                            params.batch_norm_decay, params.batch_norm_epsilon)
                        tower_losses.append(loss)
                        tower_gradvars.append(gradvars)
                        tower_preds.append(preds)
                        if i == 0:
                            # Only trigger batch_norm moving mean and variance update from
                            # the 1st tower. Ideally, we should grab the updates from all
                            # towers but these stats accumulate extremely fast so we can
                            # ignore the other stats from the other towers without
                            # significant detriment.
                            update_ops = tf.get_collection(
                                tf.GraphKeys.UPDATE_OPS, name_scope)

        # Now compute global loss and gradients.
        gradvars = []
        with tf.name_scope('gradient_averaging'):
            all_grads = {}
            for grad, var in itertools.chain(*tower_gradvars):
                if grad is not None:
                    all_grads.setdefault(var, []).append(grad)
            for var, grads in six.iteritems(all_grads):
                # Average gradients on the same device as the variables
                # to which they apply.
                with tf.device(var.device):
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads),
                                               1. / len(grads))
                gradvars.append((avg_grad, var))

        # Device that runs the ops to apply global gradient updates.
        consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
        with tf.device(consolidation_device):
            # Suggested learning rate scheduling from
            # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
            num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
                'train') // (params.train_batch_size * num_workers)
            boundaries = [
                num_batches_per_epoch * x
                for x in np.array([82, 123, 300], dtype=np.int64)
            ]
            staged_lr = [
                params.learning_rate * x for x in [1, 0.1, 0.01, 0.002]
            ]

            learning_rate = tf.train.piecewise_constant(
                tf.train.get_global_step(), boundaries, staged_lr)

            loss = tf.reduce_mean(tower_losses, name='loss')

            examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
                params.train_batch_size, every_n_steps=10)

            tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}

            logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                                      every_n_iter=100)

            train_hooks = [logging_hook, examples_sec_hook]

            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=momentum)

            if params.sync:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, replicas_to_aggregate=num_workers)
                sync_replicas_hook = optimizer.make_session_run_hook(
                    params.is_chief)
                train_hooks.append(sync_replicas_hook)

            # Create single grouped train op: apply the averaged gradients and
            # run the batch-norm updates collected from the first tower.
            train_op = [
                optimizer.apply_gradients(
                    gradvars, global_step=tf.train.get_global_step())
            ]
            train_op.extend(update_ops)
            train_op = tf.group(*train_op)

            predictions = {
                'classes':
                tf.concat([p['classes'] for p in tower_preds], axis=0),
                'probabilities':
                tf.concat([p['probabilities'] for p in tower_preds], axis=0)
            }
            stacked_labels = tf.concat(labels, axis=0)
            metrics = {
                'accuracy':
                tf.metrics.accuracy(stacked_labels, predictions['classes'])
            }

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=train_hooks,
                                          eval_metric_ops=metrics)
Example #3
0
  def _resnet_model_fn(features, labels, mode, params):
    """Build the multi-tower ResNet graph and return its EstimatorSpec.

    One replica ("tower") of the network is built per device: one per GPU
    when `num_gpus` > 0, otherwise a single CPU tower. Parameter placement
    follows the enclosing `variable_strategy`:
    1. 'CPU': the CPU is the parameter server and manages gradient updates.
    2. 'GPU': parameters are spread across the GPUs by a greedy
       load-balancing strategy, and the first GPU manages gradient updates.

    Args:
      features: a list of tensors, one for each tower
      labels: a list of tensors, one for each tower
      mode: ModeKeys.TRAIN or EVAL
      params: Hyperparameters suitable for tuning
    Returns:
      An EstimatorSpec holding the concatenated predictions, the mean tower
      loss, the grouped train op, training hooks and an accuracy metric.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    decay = params.weight_decay
    momentum = params.momentum

    losses_per_tower = []
    gradvars_per_tower = []
    preds_per_tower = []

    # NCHW ('channels_first') is normally fastest on GPU and NHWC
    # ('channels_last') on CPU; an explicit params.data_format wins.
    data_format = params.data_format or (
        'channels_last' if num_gpus == 0 else 'channels_first')

    # Without GPUs everything runs as a single CPU tower.
    num_devices, device_type = ((1, 'cpu') if num_gpus == 0 else
                                (num_gpus, 'gpu'))

    for tower_idx in range(num_devices):
      worker_device = '/{}:{}'.format(device_type, tower_idx)
      if variable_strategy == 'CPU':
        device_setter = cifar10_utils.local_device_setter(
            worker_device=worker_device)
      elif variable_strategy == 'GPU':
        balancer = tf.contrib.training.GreedyLoadBalancingStrategy(
            num_gpus, tf.contrib.training.byte_size_load_fn)
        device_setter = cifar10_utils.local_device_setter(
            ps_device_type='gpu',
            worker_device=worker_device,
            ps_strategy=balancer)
      # Towers after the first reuse the variables created by tower 0.
      with tf.variable_scope('resnet', reuse=tower_idx != 0):
        with tf.name_scope('tower_%d' % tower_idx) as scope_name:
          with tf.device(device_setter):
            tower_loss, tower_gradvars, tower_out = _tower_fn(
                training, decay, features[tower_idx], labels[tower_idx],
                data_format, params.num_layers, params.batch_norm_decay,
                params.batch_norm_epsilon)
            losses_per_tower.append(tower_loss)
            gradvars_per_tower.append(tower_gradvars)
            preds_per_tower.append(tower_out)
            if tower_idx == 0:
              # Batch-norm moving statistics are updated from the first tower
              # only; they accumulate quickly enough that skipping the other
              # towers' updates costs little.
              update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                             scope_name)

    # Average each variable's gradients across towers.
    averaged_gradvars = []
    with tf.name_scope('gradient_averaging'):
      grads_by_var = {}
      for grad, var in itertools.chain(*gradvars_per_tower):
        if grad is None:
          continue
        grads_by_var.setdefault(var, []).append(grad)
      for var, grads in six.iteritems(grads_by_var):
        # Place the averaging on the device of the variable it updates.
        with tf.device(var.device):
          if len(grads) > 1:
            avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
          else:
            avg_grad = grads[0]
        averaged_gradvars.append((avg_grad, var))

    # Device that runs the ops applying the global gradient update.
    consolidation_device = '/cpu:0'
    if variable_strategy == 'GPU':
      consolidation_device = '/gpu:0'
    with tf.device(consolidation_device):
      global_step = tf.train.get_global_step()
      # Suggested learning rate scheduling from
      # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
      steps_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
          'train') // (params.train_batch_size * num_workers)
      boundaries = [steps_per_epoch * epoch
                    for epoch in np.array([82, 123, 300], dtype=np.int64)]
      staged_lr = [params.learning_rate * factor
                   for factor in (1, 0.1, 0.01, 0.002)]
      learning_rate = tf.train.piecewise_constant(global_step, boundaries,
                                                  staged_lr)

      loss = tf.reduce_mean(losses_per_tower, name='loss')

      examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
          params.train_batch_size, every_n_steps=10)
      logging_hook = tf.train.LoggingTensorHook(
          tensors={'learning_rate': learning_rate, 'loss': loss},
          every_n_iter=100)
      train_hooks = [logging_hook, examples_sec_hook]

      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=momentum)
      if params.sync:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer, replicas_to_aggregate=num_workers)
        train_hooks.append(optimizer.make_session_run_hook(params.is_chief))

      # A single op that applies the averaged gradients and runs the
      # batch-norm updates collected from the first tower.
      apply_op = optimizer.apply_gradients(averaged_gradvars,
                                           global_step=global_step)
      train_op = tf.group(apply_op, *update_ops)

      predictions = {
          'classes':
              tf.concat([out['classes'] for out in preds_per_tower], axis=0),
          'probabilities':
              tf.concat([out['probabilities'] for out in preds_per_tower],
                        axis=0),
      }
      metrics = {
          'accuracy':
              tf.metrics.accuracy(tf.concat(labels, axis=0),
                                  predictions['classes'])
      }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=train_hooks,
        eval_metric_ops=metrics)
    def _resnet_model_fn(features, labels, mode, params):
        """Resnet model body (asynchronous-training variant).

        Supports single-host training on one or more GPUs (or a single CPU
        tower when ``num_gpus`` is 0). Each tower is built by ``_tower_fn``,
        which here also receives the optimizer and ``gradient_scale``. The
        per-tower gradients are then averaged and applied by one Momentum
        optimizer.

        The learning-rate schedule is derived from the combined batch size of
        all workers, read from the custom ``batch_size_list`` entry of the
        ``TF_CONFIG`` environment variable.

        Args:
          features: a list of tensors, one for each tower
          labels: a list of tensors, one for each tower
          mode: ModeKeys.TRAIN or EVAL
          params: Hyperparameters suitable for tuning
        Returns:
          An EstimatorSpec configured for ASP training. It passes extra
          keyword arguments (``sync_mode``, ``staleness``, ``window_size``,
          ...) that, presumably, only a patched TensorFlow build accepts —
          confirm against the TensorFlow version in use.
        Raises:
          KeyError: if ``TF_CONFIG`` is unset or has no ``batch_size_list``.
          ValueError: if ``batch_size_list`` does not sum to a positive value,
            or ``variable_strategy`` is neither 'CPU' nor 'GPU'.
        """
        tf.logging.info(
            'check to see if model fn not called on every switch input fn!!!!!!!!!'
        )
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        weight_decay = params.weight_decay
        momentum = params.momentum

        tower_features = features
        tower_labels = labels
        tower_losses = []
        tower_gradvars = []
        tower_preds = []

        # Tensor names exported through the environment; presumably read by
        # the patched runtime — verify where these variables are consumed.
        os.environ['tensor_for_variance'] = 'gradient_variance:0'
        os.environ['tensor_for_b_simple'] = 'b_simple:0'

        # channels first (NCHW) is normally optimal on GPU and channels last
        # (NHWC) on CPU. An explicit params.data_format overrides this default.
        data_format = params.data_format
        if not data_format:
            if num_gpus == 0:
                data_format = 'channels_last'
            else:
                data_format = 'channels_first'

        # One tower per GPU; a single CPU tower when there are no GPUs.
        if num_gpus == 0:
            num_devices = 1
            device_type = 'cpu'
        else:
            num_devices = num_gpus
            device_type = 'gpu'

        # Global batch size across all workers. The per-epoch step count, and
        # therefore the learning-rate boundaries, is computed from it.
        tf_config = json.loads(os.environ['TF_CONFIG'])
        batch_size_list = tf_config['batch_size_list']
        combined_batch_size = sum(batch_size_list)
        if combined_batch_size <= 0:
            # Would otherwise surface as a ZeroDivisionError further down.
            raise ValueError(
                'TF_CONFIG batch_size_list must sum to a positive value, '
                'got {!r}'.format(batch_size_list))

        # Device that runs the learning-rate schedule, the optimizer and the
        # ops that apply the global gradient update.
        consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
        with tf.device(consolidation_device):
            num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
                'train') // combined_batch_size
            boundaries = [
                num_batches_per_epoch * x
                for x in np.array([82, 123, 300], dtype=np.int64)
            ]
            staged_lr = [
                params.learning_rate * x for x in [1, 0.1, 0.01, 0.002]
            ]

            learning_rate = tf.train.piecewise_constant(
                tf.train.get_global_step(), boundaries, staged_lr)

            # The optimizer is created before the towers because _tower_fn
            # takes it as an argument.
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=momentum,
                                                   use_locking=True)
            # NOTE(review): params.sync is not consulted in this variant; no
            # SyncReplicasOptimizer or sync hook is created.
            train_hooks = []

        for i in range(num_devices):
            worker_device = '/{}:{}'.format(device_type, i)
            if variable_strategy == 'CPU':
                device_setter = cifar10_utils.local_device_setter(
                    worker_device=worker_device)
            elif variable_strategy == 'GPU':
                # NOTE(review): identical to the 'CPU' branch — no
                # ps_device_type='gpu' or load-balancing ps_strategy is
                # passed here. Confirm this is intentional.
                device_setter = cifar10_utils.local_device_setter(
                    worker_device=worker_device)
            else:
                # Fail loudly instead of hitting an unbound device_setter below.
                raise ValueError('Unrecognized variable_strategy: {!r}'.format(
                    variable_strategy))

            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope('resnet', reuse=bool(i != 0)):
                with tf.name_scope('tower_%d' % i) as name_scope:
                    with tf.device(device_setter):
                        loss, gradvars, preds, compgrad_op = _tower_fn(
                            is_training, weight_decay, tower_features[i],
                            tower_labels[i], data_format, params.num_layers,
                            params.batch_norm_decay, params.batch_norm_epsilon,
                            optimizer, gradient_scale)

                        tower_losses.append(loss)
                        tower_gradvars.append(gradvars)
                        tower_preds.append(preds)
                        if i == 0:
                            # Only trigger batch_norm moving mean and variance update from
                            # the 1st tower. Ideally, we should grab the updates from all
                            # towers but these stats accumulate extremely fast so we can
                            # ignore the other stats from the other towers without
                            # significant detriment.
                            update_ops = tf.get_collection(
                                tf.GraphKeys.UPDATE_OPS, name_scope)

        # Now compute global loss and gradients.
        gradvars = []
        with tf.name_scope('gradient_averaging'):
            all_grads = {}
            for grad, var in itertools.chain(*tower_gradvars):
                if grad is not None:
                    all_grads.setdefault(var, []).append(grad)
            for var, grads in six.iteritems(all_grads):
                # Average gradients on the same device as the variables
                # to which they apply.
                with tf.device(var.device):
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads),
                                               1. / len(grads))
                gradvars.append((avg_grad, var))

        with tf.device(consolidation_device):
            loss = tf.reduce_mean(tower_losses, name='loss')

            examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
                params.train_batch_size, every_n_steps=10)

            train_hooks.append(examples_sec_hook)

            # Create single grouped train op: apply the averaged gradients and
            # run the batch-norm updates collected from the first tower.
            train_op = [
                optimizer.apply_gradients(
                    gradvars, global_step=tf.train.get_global_step())
            ]

            train_op.extend(update_ops)
            train_op = tf.group(*train_op)
            # NOTE(review): compgrad_op is rebound on every tower iteration, so
            # this groups only the last tower's ops, and the group is not
            # passed to the EstimatorSpec below — confirm it is still needed.
            compgrad_op = tf.group(*compgrad_op)

            predictions = {
                'classes':
                tf.concat([p['classes'] for p in tower_preds], axis=0),
                'probabilities':
                tf.concat([p['probabilities'] for p in tower_preds], axis=0)
            }
            stacked_labels = tf.concat(labels, axis=0)
            accuracy = tf.metrics.accuracy(stacked_labels,
                                           predictions['classes'])
            metrics = {'accuracy': accuracy}

            # accuracy[1] is the update op returned by tf.metrics.accuracy.
            tensors_to_log = {
                'worker_train_accuracy': accuracy[1],
                'learning_rate': learning_rate,
                'loss': loss,
                'global_step': tf.train.get_global_step()
            }

            logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                                      every_n_iter=50)

            train_hooks.append(logging_hook)

        # ASYNCHRONOUS TRAINING
        # asp_adjust_strategy can be 'staleness' or 'iteration_time'.
        # For synchronous training this code was previously configured with
        # sync_mode='BSP', staleness=5, window_size=20 and
        # reactive_adjustment_threshold=0.03, without the ASP-only arguments
        # (mini_batchsize_threshold, global_batch_size_value,
        # asp_adjust_strategy).
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            reactive_adjustment_threshold=0.08,
            namescope='gradients',
            window_size=250,
            sync_mode='ASP',
            staleness=1000,
            mini_batchsize_threshold=16,
            global_batch_size_value=512,
            asp_adjust_strategy='staleness',
            adjustment_mode='exponential_smoothing',
            training_hooks=train_hooks,
            eval_metric_ops=metrics)
Example #5
0
    def _resnet_model_fn(features, labels, mode, params):
        """Resnet model body.

        Supports single-host training on one or more GPUs (or a single CPU
        tower when ``num_gpus`` is 0). Parameter placement follows the
        enclosing ``variable_strategy``:

        1. "CPU": the CPU is the parameter server and manages gradient
           updates.
        2. "GPU": parameters are spread across all GPUs by a greedy
           load-balancing strategy, and the first GPU manages gradient
           updates.

        Args:
          features: a list of tensors, one for each tower.
          labels: a list of tensors, one for each tower.
          mode: ModeKeys.TRAIN or EVAL.
          params: Hyperparameters suitable for tuning. ``params.optimizer``
            selects one of "momentum", "adam", "adagrad", "adadelta", "sgd"
            or "rmsprop".

        Returns:
          A EstimatorSpec object.

        Raises:
          ValueError: if ``variable_strategy`` is neither "CPU" nor "GPU", or
            ``params.optimizer`` is not a recognized optimizer name.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        weight_decay = params.weight_decay
        momentum = params.momentum

        tower_features = features
        tower_labels = labels
        tower_losses = []
        tower_gradvars = []
        tower_preds = []

        # channels first (NCHW) is normally optimal on GPU and channels last
        # (NHWC) on CPU. An explicit params.data_format overrides this default.
        data_format = params.data_format
        if not data_format:
            if num_gpus == 0:
                data_format = "channels_last"
            else:
                data_format = "channels_first"

        # One tower per GPU; a single CPU tower when there are no GPUs.
        if num_gpus == 0:
            num_devices = 1
            device_type = "cpu"
        else:
            num_devices = num_gpus
            device_type = "gpu"

        for i in range(num_devices):
            worker_device = "/{}:{}".format(device_type, i)
            if variable_strategy == "CPU":
                device_setter = cifar10_utils.local_device_setter(
                    worker_device=worker_device)
            elif variable_strategy == "GPU":
                device_setter = cifar10_utils.local_device_setter(
                    ps_device_type="gpu",
                    worker_device=worker_device,
                    ps_strategy=tf.contrib.training.
                    GreedyLoadBalancingStrategy(
                        num_gpus, tf.contrib.training.byte_size_load_fn),
                )
            else:
                # Fail loudly instead of hitting an unbound device_setter below.
                raise ValueError("unrecognized variable_strategy: {!r}".format(
                    variable_strategy))
            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope("resnet", reuse=bool(i != 0)):
                with tf.name_scope("tower_%d" % i) as name_scope:
                    with tf.device(device_setter):
                        loss, gradvars, preds = _tower_fn(
                            is_training,
                            weight_decay,
                            tower_features[i],
                            tower_labels[i],
                            data_format,
                            params.num_layers,
                            params.batch_norm_decay,
                            params.batch_norm_epsilon,
                        )
                        tower_losses.append(loss)
                        tower_gradvars.append(gradvars)
                        tower_preds.append(preds)
                        if i == 0:
                            # Only trigger batch_norm moving mean and variance update from
                            # the 1st tower. Ideally, we should grab the updates from all
                            # towers but these stats accumulate extremely fast so we can
                            # ignore the other stats from the other towers without
                            # significant detriment.
                            update_ops = tf.get_collection(
                                tf.GraphKeys.UPDATE_OPS, name_scope)

        # Now compute global loss and gradients.
        gradvars = []
        with tf.name_scope("gradient_averaging"):
            all_grads = {}
            for grad, var in itertools.chain(*tower_gradvars):
                if grad is not None:
                    all_grads.setdefault(var, []).append(grad)
            for var, grads in six.iteritems(all_grads):
                # Average gradients on the same device as the variables
                # to which they apply.
                with tf.device(var.device):
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads),
                                               1.0 / len(grads))
                gradvars.append((avg_grad, var))

        # Device that runs the ops to apply global gradient updates.
        consolidation_device = ("/gpu:0"
                                if variable_strategy == "GPU" else "/cpu:0")
        with tf.device(consolidation_device):
            # Suggested learning rate scheduling from
            # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
            num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
                "train") // (params.train_batch_size * num_workers)
            boundaries = [
                num_batches_per_epoch * x
                for x in np.array([80, 120, 160], dtype=np.int64)
            ]
            staged_lr = [
                params.learning_rate * x for x in [1, 0.1, 0.01, 0.001]
            ]

            learning_rate = tf.train.piecewise_constant(
                tf.train.get_global_step(), boundaries, staged_lr)

            loss = tf.reduce_mean(tower_losses, name="loss")

            # No logging / examples-per-second hooks are attached by default.
            train_hooks = []

            # Hyper-parameter "momentum" is only used for the Momentum Optimizer.
            # Other optimizers use their default parameters. Each entry is a
            # factory so that only the selected optimizer is constructed.
            optimizer_factories = {
                "momentum": lambda: tf.train.MomentumOptimizer(
                    learning_rate=learning_rate, momentum=momentum),
                "adam": lambda: tf.train.AdamOptimizer(
                    learning_rate=learning_rate),
                "adagrad": lambda: tf.train.AdagradOptimizer(
                    learning_rate=learning_rate),
                "adadelta": lambda: tf.train.AdadeltaOptimizer(
                    learning_rate=learning_rate),
                "sgd": lambda: tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate),
                "rmsprop": lambda: tf.train.RMSPropOptimizer(
                    learning_rate=learning_rate),
            }
            make_optimizer = optimizer_factories.get(params.optimizer)
            if make_optimizer is None:
                raise ValueError("unrecognized optimizer name: {!r}".format(
                    params.optimizer))
            optimizer = make_optimizer()
            # TODO: RAdam (tfa.optimizers.RectifiedAdam) is implemented in
            #       tensorflow-addons v0.6, which requires tf 2.0. Upgrade code
            #       by removing tf.contrib modules.

            if params.sync:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, replicas_to_aggregate=num_workers)
                sync_replicas_hook = optimizer.make_session_run_hook(
                    params.is_chief)
                train_hooks.append(sync_replicas_hook)

            # Create single grouped train op: apply the averaged gradients and
            # run the batch-norm updates collected from the first tower.
            train_op = [
                optimizer.apply_gradients(
                    gradvars, global_step=tf.train.get_global_step())
            ]
            train_op.extend(update_ops)
            train_op = tf.group(*train_op)

            predictions = {
                "classes":
                tf.concat([p["classes"] for p in tower_preds], axis=0),
                "probabilities":
                tf.concat([p["probabilities"] for p in tower_preds], axis=0),
            }
            stacked_labels = tf.concat(labels, axis=0)
            metrics = {
                "accuracy":
                tf.metrics.accuracy(stacked_labels, predictions["classes"])
            }

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            training_hooks=train_hooks,
            eval_metric_ops=metrics,
        )
      if num_gpus == 0:
        data_format = 'channels_last'
      else:
        data_format = 'channels_first'

    if num_gpus == 0:
      num_devices = 1
      device_type = 'cpu'
    else:
      num_devices = num_gpus
      device_type = 'gpu'

    for i in range(num_devices):
      worker_device = '/{}:{}'.format(device_type, i)
      if variable_strategy == 'CPU':
        device_setter = cifar10_utils.local_device_setter(
            worker_device=worker_device)
      elif variable_strategy == 'GPU':
        device_setter = cifar10_utils.local_device_setter(
            ps_device_type='gpu',
            worker_device=worker_device,
            ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                num_gpus, tf.contrib.training.byte_size_load_fn))
      with tf.variable_scope('resnet', reuse=bool(i != 0)):
        with tf.name_scope('tower_%d' % i) as name_scope:
          with tf.device(device_setter):
            loss, gradvars, preds = _tower_fn(
                is_training, weight_decay, tower_features[i], tower_labels[i],
                data_format, params.num_layers, params.batch_norm_decay,
                params.batch_norm_epsilon)
            tower_losses.append(loss)
            tower_gradvars.append(gradvars)
    def _resnet_model_fn(features, labels, mode, params):
        """Resnet model body.

        Builds one tower per device (one CPU tower when ``num_gpus == 0``,
        otherwise one tower per GPU), averages the towers' gradients and
        applies them on a single consolidation device ('/gpu:0' when
        ``variable_strategy == 'GPU'``, otherwise '/cpu:0').

        The learning-rate schedule converts epochs to steps using the sum of
        the ``batch_size_list`` entry read from the ``TF_CONFIG`` environment
        variable, and this worker's name ('<task type>-<task index>') is
        forwarded to ``_tower_fn``.

        Args:
          features: a list of tensors, one for each tower
          labels: a list of tensors, one for each tower
          mode: ModeKeys.TRAIN or EVAL
          params: Hyperparameters suitable for tuning
        Returns:
          A EstimatorSpec object.
        Raises:
          ValueError: if ``variable_strategy`` is neither 'CPU' nor 'GPU', or
            if the batch sizes in TF_CONFIG's ``batch_size_list`` do not sum to
            a positive value.
          KeyError: if TF_CONFIG is unset or lacks the keys read below.
        """
        tf.set_random_seed(7)

        # Fail fast: an unknown strategy used to leave ``device_setter``
        # unbound and surface later as an UnboundLocalError.
        if variable_strategy not in ('CPU', 'GPU'):
            raise ValueError(
                "variable_strategy must be 'CPU' or 'GPU', got {!r}".format(
                    variable_strategy))

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        weight_decay = params.weight_decay
        momentum = params.momentum

        tower_features = features
        tower_labels = labels
        tower_losses = []
        tower_gradvars = []
        tower_preds = []

        # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
        # on CPU. The exception is Intel MKL on CPU which is optimal with
        # channels_last.
        data_format = params.data_format
        if not data_format:
            if num_gpus == 0:
                data_format = 'channels_last'
            else:
                data_format = 'channels_first'

        if num_gpus == 0:
            num_devices = 1
            device_type = 'cpu'
        else:
            num_devices = num_gpus
            device_type = 'gpu'

        # 'batch_size_list' is a custom TF_CONFIG key, not part of the standard
        # schema; presumably one batch size per worker — confirm with launcher.
        tf_config = json.loads(os.environ['TF_CONFIG'])
        batchlist = tf_config['batch_size_list']
        tasktype = tf_config['task']['type']
        index = tf_config['task']['index']
        w_name = tasktype + '-' + str(index)

        # Total batch size over every entry of batch_size_list; it is the
        # divisor when converting epochs to steps below, so it must be > 0.
        combined_batch_size = sum(batchlist)
        if combined_batch_size <= 0:
            raise ValueError(
                'TF_CONFIG batch_size_list must sum to a positive value, '
                'got {!r}'.format(batchlist))

        # Device that holds the LR schedule and runs the ops that apply the
        # averaged gradients.
        consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
        with tf.device(consolidation_device):
            num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
                'train') // combined_batch_size
            # LR drops at epochs 82, 123 and 300 (expressed in global steps).
            boundaries = [
                num_batches_per_epoch * x
                for x in np.array([82, 123, 300], dtype=np.int64)
            ]
            staged_lr = [
                params.learning_rate * x for x in [1, 0.1, 0.01, 0.002]
            ]

            learning_rate = tf.train.piecewise_constant(
                tf.train.get_global_step(), boundaries, staged_lr)

            # Built before the towers because _tower_fn receives the optimizer.
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=momentum)
            train_hooks = []

        for i in range(num_devices):
            worker_device = '/{}:{}'.format(device_type, i)
            # NOTE(review): 'CPU' and 'GPU' strategies originally had identical
            # setter branches (no GPU ps_device_type / load-balancing strategy);
            # kept identical here — confirm this is intended.
            device_setter = cifar10_utils.local_device_setter(
                worker_device=worker_device)

            # Towers after the first reuse the variables created by tower 0.
            with tf.variable_scope('resnet', reuse=bool(i != 0)):
                with tf.name_scope('tower_%d' % i) as name_scope:
                    with tf.device(device_setter):
                        loss, gradvars, preds, compgrad_op = _tower_fn(
                            is_training,
                            weight_decay,
                            tower_features[i],
                            tower_labels[i],
                            data_format,
                            params.num_layers,
                            params.batch_norm_decay,
                            params.batch_norm_epsilon,
                            optimizer,
                            gradient_scale,
                            w_name=w_name)

                        tower_losses.append(loss)
                        tower_gradvars.append(gradvars)
                        tower_preds.append(preds)
                        if i == 0:
                            # Only trigger batch_norm moving mean and variance update from
                            # the 1st tower. Ideally, we should grab the updates from all
                            # towers but these stats accumulate extremely fast so we can
                            # ignore the other stats from the other towers without
                            # significant detriment.
                            update_ops = tf.get_collection(
                                tf.GraphKeys.UPDATE_OPS, name_scope)

        # Now compute global loss and gradients.
        gradvars = []
        with tf.name_scope('gradient_averaging'):
            all_grads = {}
            for grad, var in itertools.chain(*tower_gradvars):
                if grad is not None:
                    all_grads.setdefault(var, []).append(grad)
            for var, grads in six.iteritems(all_grads):
                # Average gradients on the same device as the variables
                # to which they apply.
                with tf.device(var.device):
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads),
                                               1. / len(grads))
                gradvars.append((avg_grad, var))

        with tf.device(consolidation_device):
            loss = tf.reduce_mean(tower_losses, name='loss')

            # NOTE(review): throughput is counted with params.train_batch_size,
            # not this worker's entry in batch_size_list — confirm they agree.
            examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
                params.train_batch_size, every_n_steps=10)

            tensors_to_log = {'learning_rate': learning_rate, 'loss': loss}

            logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                                      every_n_iter=100)

            train_hooks.append(logging_hook)
            train_hooks.append(examples_sec_hook)

            # Create single grouped train op
            train_op = [
                optimizer.apply_gradients(
                    gradvars, global_step=tf.train.get_global_step())
            ]

            train_op.extend(update_ops)
            train_op = tf.group(*train_op)
            # NOTE(review): compgrad_op holds only the LAST tower's ops (it is
            # overwritten each loop iteration), and the grouped op below is not
            # passed to the EstimatorSpec — confirm whether it is still needed.
            compgrad_op = tf.group(*compgrad_op)

            predictions = {
                'classes':
                tf.concat([p['classes'] for p in tower_preds], axis=0),
                'probabilities':
                tf.concat([p['probabilities'] for p in tower_preds], axis=0)
            }
            stacked_labels = tf.concat(labels, axis=0)
            metrics = {
                'accuracy':
                tf.metrics.accuracy(stacked_labels, predictions['classes'])
            }

        # NOTE(review): namescope / window_size / sync_mode are not arguments of
        # the stock tf.estimator.EstimatorSpec; this presumably targets a
        # patched TensorFlow build.
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          train_op=train_op,
                                          namescope='gradients',
                                          window_size=None,
                                          sync_mode=None,
                                          training_hooks=train_hooks,
                                          eval_metric_ops=metrics)