Example 1
def model_fn(features, labels, mode, params):
  """Defines how to train, evaluate and predict from the transformer model."""
  with tf.variable_scope("model"):
    inputs, targets = features, labels

    # Create model and get output logits.
    model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)

    output = model(inputs, targets)

    # When in prediction mode, the labels/targets are None, and the model
    # output is the prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions=output)

    logits = output

    # Calculate model loss.
    xentropy, weights = metrics.padded_cross_entropy_loss(
        logits, targets, params.label_smoothing, params.vocab_size)
    loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)

    if mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss, predictions={"predictions": logits},
          eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
    else:
      train_op = get_train_op(loss, params)
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
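The model_fn above is not run directly; it is handed to an Estimator. A minimal wiring sketch, in which the hparams object, input functions and model directory (params, train_input_fn, eval_input_fn, "/tmp/transformer") are hypothetical placeholders, not part of the snippet above:

import tensorflow as tf

estimator = tf.estimator.Estimator(
    model_fn=model_fn,             # the function defined above
    model_dir="/tmp/transformer",  # hypothetical checkpoint directory
    params=params)                 # same hparams object read by model_fn
estimator.train(input_fn=train_input_fn, max_steps=10000)
print(estimator.evaluate(input_fn=eval_input_fn))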
Example 2
def model_fn(features, labels, mode, params):
    """Defines how to train, evaluate and predict from the transformer model."""
    with tf.variable_scope("model"):
        inputs, targets = features, labels

        # Create model and get output logits.
        model = transformer.Transformer(params,
                                        mode == tf.estimator.ModeKeys.TRAIN)

        output = model(inputs, targets)

        # When in prediction mode, the labels/targets are None, and the model
        # output is the prediction.
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(tf.estimator.ModeKeys.PREDICT,
                                              predictions=output)

        logits = output

        # Calculate model loss.
        # xentropy contains the cross entropy loss of every nonpadding token in the
        # targets.
        xentropy, weights = metrics.padded_cross_entropy_loss(
            logits, targets, params.label_smoothing, params.vocab_size)
        # Compute the weighted mean of the cross entropy losses
        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

        # Save loss as named tensor that will be logged with the logging hook.
        tf.identity(loss, "cross_entropy")

        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                predictions={"predictions": logits},
                eval_metric_ops=metrics.get_eval_metrics(
                    logits, labels, params))
        else:
            train_op = get_train_op(loss, params)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op)
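The snippet above names the loss tensor "cross_entropy" so a logging hook can report it during training. A minimal sketch of attaching such a hook; the scoped name follows from the variable_scope("model") that wraps the graph, and train_input_fn is a placeholder:

logging_hook = tf.train.LoggingTensorHook(
    tensors={"loss": "model/cross_entropy"}, every_n_iter=100)
estimator.train(input_fn=train_input_fn, hooks=[logging_hook])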
Example 3
def transformer_encoder(inputs, lengths):
    # Set up estimator params. Copy BASE_PARAMS so the module-level defaults
    # are not mutated in place. (`lengths` is currently unused.)
    params = model_params.BASE_PARAMS.copy()
    params["default_batch_size"] = K
    params["max_length"] = 500
    params["vocab_size"] = VOCABULARY_SIZE + 1
    params["filter_size"] = 256
    params["num_hidden_layers"] = 2
    params["num_heads"] = 2
    params["hidden_size"] = EMBEDDING_SIZE

    # The Transformer constructor expects a boolean `train` flag, not a mode
    # string (the string only happens to be truthy).
    model = transformer.Transformer(params, train=True)
    initializer = tf.variance_scaling_initializer(
        model.params["initializer_gain"],
        mode="fan_avg",
        distribution="uniform")
    with tf.variable_scope("Transformer", initializer=initializer):
        # Calculate attention bias for encoder self-attention and decoder
        # multi-headed attention layers.
        attention_bias = model_utils.get_padding_bias(inputs)
        # Run the inputs through the encoder layer to map the symbol
        # representations to continuous representations.
        encoder = model.encode(inputs, attention_bias)
        return tf.reduce_mean(encoder, 1)
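A minimal usage sketch for the encoder above, assuming the K, VOCABULARY_SIZE and EMBEDDING_SIZE constants it references are defined and that inputs are batches of token ids:

token_ids = tf.placeholder(tf.int64, shape=[None, None], name="token_ids")
seq_lengths = tf.placeholder(tf.int32, shape=[None], name="seq_lengths")
# One fixed-size vector per sequence: [batch, EMBEDDING_SIZE].
sentence_vectors = transformer_encoder(token_ids, seq_lengths)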
Example 4
def model_fn(features, labels, mode, params):
    """Defines how to train, evaluate and predict from the transformer model."""
    with tf.variable_scope("model"):
        inputs, targets = features, labels

        # Create model and get output logits.
        model = transformer.Transformer(params,
                                        mode == tf.estimator.ModeKeys.TRAIN)

        logits = model(inputs, targets)

        # When in prediction mode, the labels/targets are None, and the model
        # output is the prediction.
        if mode == tf.estimator.ModeKeys.PREDICT:
            if params["use_tpu"]:
                raise NotImplementedError(
                    "Prediction is not yet supported on TPUs.")
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.PREDICT,
                predictions=logits,
                export_outputs={
                    "translate": tf.estimator.export.PredictOutput(logits)
                })

        # Explicitly set the shape of the logits for XLA (TPU). This is needed
        # because the logits are passed back to the host VM CPU for metric
        # evaluation, and the shape of [?, ?, vocab_size] is too vague. However
        # it is known from Transformer that the first two dimensions of logits
        # are the dimensions of targets. Note that the ambiguous shape of logits is
        # not a problem when computing xentropy, because padded_cross_entropy_loss
        # resolves the shape on the TPU.
        logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])

        # Calculate model loss.
        # xentropy contains the cross entropy loss of every nonpadding token in the
        # targets.
        xentropy, weights = metrics.padded_cross_entropy_loss(
            logits, targets, params["label_smoothing"], params["vocab_size"])
        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

        # Save loss as named tensor that will be logged with the logging hook.
        tf.identity(loss, "cross_entropy")

        if mode == tf.estimator.ModeKeys.EVAL:
            if params["use_tpu"]:
                # host call functions should only have tensors as arguments.
                # This lambda pre-populates params so that metric_fn is
                # TPUEstimator compliant.
                metric_fn = lambda logits, labels: (metrics.get_eval_metrics(
                    logits, labels, params=params))
                eval_metrics = (metric_fn, [logits, labels])
                return tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    predictions={"predictions": logits},
                    eval_metrics=eval_metrics)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                predictions={"predictions": logits},
                eval_metric_ops=metrics.get_eval_metrics(
                    logits, labels, params))
        else:
            train_op, metric_dict = get_train_op_and_metrics(loss, params)

            # Epochs can be quite long. This gives some intermediate information
            # in TensorBoard.
            metric_dict["minibatch_loss"] = loss
            if params["use_tpu"]:
                return tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    host_call=tpu_util.construct_scalar_host_call(
                        metric_dict=metric_dict,
                        model_dir=params["model_dir"],
                        prefix="training/"))
            record_scalars(metric_dict)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op)
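record_scalars is not shown on this page. A plausible definition, matching how metric_dict is consumed in the official Transformer code these snippets derive from (a sketch, not necessarily the exact helper used here):

def record_scalars(metric_dict):
    for key, value in metric_dict.items():
        tf.contrib.summary.scalar(name=key, tensor=value)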
Example 5
    def model_fn(features, labels, mode, params):
        """Defines how to train, evaluate and predict from the transformer model."""
        num_devices = flags_core.get_num_gpus(flags_obj)
        consolidation_device = 'gpu:0'

        tower_losses = []
        tower_gradvars = []
        tower_preds = []
        for i in range(num_devices):
            worker_device = '/{}:{}'.format('gpu', i)
            device_setter = local_device_setter(
                ps_device_type='gpu',
                worker_device=worker_device,
                ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                    num_devices, tf.contrib.training.byte_size_load_fn))
            with tf.variable_scope('model', reuse=bool(i != 0)):
                with tf.name_scope('tower_%d' % i) as name_scope:
                    with tf.device(device_setter):
                        # Create model and get output logits.
                        model = transformer.Transformer(
                            params, mode == tf.estimator.ModeKeys.TRAIN)
                        loss, gradvars, preds = _tower_fn(model,
                                                          features,
                                                          labels,
                                                          params=params)
                        tower_losses.append(loss)
                        tower_gradvars.append(gradvars)
                        tower_preds.append(preds)

        # Compute global loss and gradients
        gradvars = []
        with tf.name_scope('gradient_averaging'):
            all_grads = {}
            for grad, var in itertools.chain(*tower_gradvars):
                if grad is not None:
                    all_grads.setdefault(var, []).append(grad)
            for var, grads in six.iteritems(all_grads):
                with tf.device(var.device):
                    if len(grads) == 1:
                        avg_grad = grads[0]
                    else:
                        avg_grad = tf.multiply(tf.add_n(grads),
                                               1. / len(grads))
                    gradvars.append((avg_grad, var))

        with tf.device(consolidation_device):
            loss = tf.reduce_mean(tower_losses, name='loss')
            tf.identity(loss, "cross_entropy")
            logits = tf.reduce_mean(tower_preds, axis=0)
            if mode == tf.estimator.ModeKeys.PREDICT:
                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.PREDICT,
                    predictions=logits,
                    export_outputs={
                        "translate": tf.estimator.export.PredictOutput(logits)
                    })

            if mode == tf.estimator.ModeKeys.TRAIN:
                with tf.variable_scope("get_train_op"):
                    print("in get_train_op")
                    learning_rate = get_learning_rate(
                        learning_rate=params["learning_rate"],
                        hidden_size=params["hidden_size"],
                        learning_rate_warmup_steps=params[
                            "learning_rate_warmup_steps"])
                    optimizer = tf.contrib.opt.LazyAdamOptimizer(
                        learning_rate,
                        beta1=params["optimizer_adam_beta1"],
                        beta2=params["optimizer_adam_beta2"],
                        epsilon=params["optimizer_adam_epsilon"])
                    optimizer = tf.train.SyncReplicasOptimizer(
                        optimizer, replicas_to_aggregate=num_devices)
                    sync_hook = optimizer.make_session_run_hook(is_chief)
                    # apply_gradients already increments the global step; a
                    # separate assign here would double-count steps.
                    train_op = optimizer.apply_gradients(
                        gradvars, global_step=tf.train.get_global_step())
                    metric_dict = {"learning_rate": learning_rate}
                    metric_dict["minibatch_loss"] = loss
                    record_scalars(metric_dict)
                    return tf.estimator.EstimatorSpec(
                        mode=mode,
                        loss=loss,
                        training_hooks=[sync_hook],
                        train_op=train_op)
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    predictions={"predictions": logits},
                    eval_metric_ops=metrics.get_eval_metrics(
                        logits, labels, params))
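local_device_setter, used by the towers above, is not defined on this page. A sketch based on the helper of the same name from the TensorFlow CIFAR-10 Estimator tutorial, which matches the keyword arguments used here (a later snippet passes a cluster positionally, which would need a different signature):

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.training import device_setter


def local_device_setter(num_devices=1, ps_device_type='cpu',
                        worker_device='/cpu:0', ps_ops=None,
                        ps_strategy=None):
    if ps_ops is None:
        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")
        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            # Variables go to the parameter device chosen by ps_strategy.
            ps_device_spec = pydev.DeviceSpec.from_string(
                '/{}:{}'.format(ps_device_type, ps_strategy(op)))
            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()
        # Everything else stays on the tower's worker device.
        worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
        worker_device_spec.merge_from(current_device)
        return worker_device_spec.to_string()

    return _local_device_chooser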
    def model_fn(features, labels, mode, params):
        """Defines how to train, evaluate and predict from the transformer model."""
        num_gpus = flags_core.get_num_gpus(flags_obj)
        print("num_gpus: ", num_gpus)
        #    num_gpus=params["num_gpus"]

        learning_rate = get_learning_rate(
            learning_rate=params["learning_rate"],
            hidden_size=params["hidden_size"],
            learning_rate_warmup_steps=params["learning_rate_warmup_steps"])
        optimizers = [
            tf.contrib.opt.LazyAdamOptimizer(
                learning_rate,
                beta1=params["optimizer_adam_beta1"],
                beta2=params["optimizer_adam_beta2"],
                epsilon=params["optimizer_adam_epsilon"])
            for _ in range(num_gpus)
        ]

        if params["dtype"] == "fp16":
            optimizers = [
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer) for optimizer in optimizers
            ]


        model = transformer.Transformer(params,
                                        mode == tf.estimator.ModeKeys.TRAIN)
        grad_list = []
        losses = []
        logits = []
        for gpu_idx in range(num_gpus):
            device_setter = local_device_setter(
                ps_device_type='cpu', worker_device='/gpu:{}'.format(gpu_idx))
            with tf.device(device_setter):
                logit, loss = create_tower_network(model, params, features,
                                                   labels)
                logits.append(logit)
                losses.append(loss)
                grad_list.append([
                    x for x in optimizers[gpu_idx].compute_gradients(loss)
                    if x[0] is not None
                ])

        output_train = tf.reduce_mean(logits, axis=0)
        loss_train = tf.reduce_mean(losses, name='loss')

        sparse_grads = []
        sparse_vars = []
        dense_grads = []
        dense_vars = []
        for tower in grad_list:
            sp_grad = []
            sp_var = []
            dn_grad = []
            dn_var = []
            for x in tower:
                # Sparse (e.g. embedding) gradients come back as
                # IndexedSlices; the gradient is x[0], the variable x[1].
                if isinstance(x[0], ops.IndexedSlices):
                    sp_grad.append(x[0])
                    sp_var.append(x[1])
                else:
                    dn_grad.append(x[0])
                    dn_var.append(x[1])

            if sp_var:
                sparse_grads.append(sp_grad)
                sparse_vars.append(sp_var)
            if dn_var:
                dense_grads.append(dn_grad)
                dense_vars.append(dn_var)

        # SPARSE: average each sparse gradient across towers. The
        # IndexedSlices are densified so that add_n/multiply apply, and every
        # tower then shares the same averaged gradient.
        gradvars = []
        if len(sparse_vars) > 0:
            if num_gpus == 1:
                reduced_grad = sparse_grads
            else:
                new_all_grads = []
                for per_var_grads in zip(*sparse_grads):
                    summed = tf.add_n(
                        [tf.convert_to_tensor(g) for g in per_var_grads])
                    avg_grad = tf.multiply(summed,
                                           1.0 / num_gpus,
                                           name='sparse_allreduce_avg')
                    new_all_grads.append([avg_grad] * num_gpus)
                reduced_grad = list(zip(*new_all_grads))
            gradvars = [
                list(zip(gs, vs)) for gs, vs in zip(reduced_grad, sparse_vars)
            ]

        #DENSE
        reduced_grad = []
        from tensorflow.python.ops import nccl_ops
        if num_gpus == 1:
            reduced_grad = dense_grads
        else:
            new_all_grads = []
            # all_sum needs, for each variable, that variable's gradient from
            # every device, so iterate over the transposed per-tower lists.
            for grad in zip(*dense_grads):
                summed = nccl_ops.all_sum(list(grad))
                grads_for_devices = []
                for g in summed:
                    with tf.device(g.device):
                        g = tf.multiply(g,
                                        1.0 / num_gpus,
                                        name='allreduce_avg')
                    grads_for_devices.append(g)
                new_all_grads.append(grads_for_devices)
            reduced_grad = list(zip(*new_all_grads))

        grads = [list(zip(gs, vs)) for gs, vs in zip(reduced_grad, dense_vars)]

        # Apply gradients on each GPU by broadcasting the summed gradient.
        # The global step is incremented once, outside the per-device loop,
        # since apply_gradients is called here without a global_step.
        global_step = tf.train.get_global_step()
        update_global_step = tf.assign(global_step,
                                       global_step + 1,
                                       name='update_global_step')
        train_ops = []
        for idx, grad_and_vars in enumerate(grads):
            with tf.name_scope('apply_gradients'), tf.device(
                    tf.DeviceSpec(device_type="GPU", device_index=idx)):
                train_ops.append(optimizers[idx].apply_gradients(
                    grad_and_vars, name='apply_grad_{}'.format(idx)))

                # SPARSE: apply the averaged sparse gradients once, on the
                # first device, via a separate synchronized optimizer.
                if idx == 0 and len(sparse_vars) > 0:
                    optimizer = tf.contrib.opt.LazyAdamOptimizer(
                        learning_rate,
                        beta1=params["optimizer_adam_beta1"],
                        beta2=params["optimizer_adam_beta2"],
                        epsilon=params["optimizer_adam_epsilon"])
                    optimizer = tf.train.SyncReplicasOptimizer(
                        optimizer, replicas_to_aggregate=num_gpus)
                    sync_hook = optimizer.make_session_run_hook(is_chief)
                    # The step counter is already handled above, so no
                    # global_step is passed here either.
                    train_ops.append(optimizer.apply_gradients(gradvars[0]))

        optimize_op = tf.group(update_global_step, *train_ops,
                               name='train_op')
        train_metrics = {"learning_rate": learning_rate}

        tf.identity(loss_train, "cross_entropy")

        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss_train,
                                              train_op=optimize_op)
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss_train,
                predictions={"predictions": output_train},
                eval_metric_ops=metrics.get_eval_metrics(
                    output_train, labels, params))
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=output_train,
                export_outputs={
                    "translate":
                    tf.estimator.export.PredictOutput(output_train)
                })
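create_tower_network is likewise not defined on this page. A minimal sketch of what it plausibly does, namely run the shared model on one tower and return that tower's logits and loss (the name and exact behavior are assumptions):

def create_tower_network(model, params, features, labels):
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        logits = model(features, labels)
        xentropy, weights = metrics.padded_cross_entropy_loss(
            logits, labels, params["label_smoothing"], params["vocab_size"])
        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
    return logits, loss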
Example 7
def model_fn(features, labels, mode, params):
  """Defines how to train, evaluate and predict from the transformer model."""
  #tf.set_random_seed(1367)
  with tf.variable_scope("model"):
    inputs, targets = features, labels
    concrete_loss = tf.constant(0)
    total_loss = tf.constant(0)
    concrete_reg = tf.constant(0)
    sparsity_rate = tf.constant(0)
    gate_values = tf.constant(0)
    # =================== For concrete gates ==================================
    print("**** concrete heads has this : {} ****".format(params["concrete_heads"]))
    if not params["concrete_coef"] == 0:
        tf.get_default_graph().clear_collection("CONCRETE")
        tf.get_default_graph().clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # =========================================================================

    # Create model and get output logits.
    model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)

    logits = model(inputs, targets)

    # When in prediction mode, the labels/targets are None, and the model
    # output is the prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:
      if params["use_tpu"]:
        raise NotImplementedError("Prediction is not yet supported on TPUs.")
      print ("Logits", logits)
      #print (logits["attn_weights"], tf.transpose(tf.stack(logits["attn_weights"]).get_shape(), perm=[1,0,2,3,4]))
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions={"outputs": logits["outputs"], "scores": logits["scores"]})

    # Explicitly set the shape of the logits for XLA (TPU). This is needed
    # because the logits are passed back to the host VM CPU for metric
    # evaluation, and the shape of [?, ?, vocab_size] is too vague. However
    # it is known from Transformer that the first two dimensions of logits
    # are the dimensions of targets. Note that the ambiguous shape of logits is
    # not a problem when computing xentropy, because padded_cross_entropy_loss
    # resolves the shape on the TPU.
    logits.set_shape(targets.shape.as_list() + logits.shape.as_list()[2:])

    # Calculate model loss.
    # xentropy contains the cross entropy loss of every nonpadding token in the
    # targets.
    xentropy, weights = metrics.padded_cross_entropy_loss(
        logits, targets, params["label_smoothing"], params["vocab_size"])
    loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

    # Save loss as named tensor that will be logged with the logging hook.
    tf.identity(loss, "cross_entropy")

    # ============ Loss for concrete gates =================
    if params["concrete_coef"] != 0:
        concrete_coef = params["concrete_coef"]
        sparsity_rate = tf.reduce_mean(tf.get_collection("CONCRETE"))
        concrete_reg = tf.reduce_mean(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        # concrete_reg is already a scalar; no further reduction is needed.
        concrete_loss = concrete_coef * concrete_reg
        total_loss = loss + concrete_loss
        gate_values = tf.get_collection("GATEVALUES")
        loss = total_loss

    # Save the gate diagnostics as named tensors for the logging hook
    # (the constants defined above are used when gates are disabled).
    tf.identity(concrete_loss, "concrete_loss")
    tf.identity(total_loss, "total_loss")
    tf.identity(concrete_reg, "concrete_reg")
    tf.identity(sparsity_rate, "sparsity_rate")
    tf.identity(gate_values, "gate_values")
    # =======================================================
    if mode == tf.estimator.ModeKeys.EVAL:
      if params["use_tpu"]:
        # host call functions should only have tensors as arguments.
        # This lambda pre-populates params so that metric_fn is
        # TPUEstimator compliant.
        metric_fn = lambda logits, labels: (
            metrics.get_eval_metrics(logits, labels, params=params))
        eval_metrics = (metric_fn, [logits, labels])
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, predictions={"predictions": logits},
            eval_metrics=eval_metrics)
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss, predictions={"predictions": logits},
          eval_metric_ops=metrics.get_eval_metrics(logits, labels, params))
    else:
      train_op, metric_dict = get_train_op_and_metrics(loss, params)

      # Epochs can be quite long. This gives some intermediate information
      # in TensorBoard.
      metric_dict["minibatch_loss"] = loss
      if params["use_tpu"]:
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            host_call=tpu_util.construct_scalar_host_call(
                metric_dict=metric_dict, model_dir=params["model_dir"],
                prefix="training/")
        )
      record_scalars(metric_dict)
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
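The CONCRETE, GATEVALUES and regularization collections consumed above are populated inside the model's attention layers, which are not shown. A rough sketch of a hard-concrete gate that would feed those collections; every name and constant here is a hypothetical illustration of the mechanism, not the code this snippet actually calls:

def concrete_gate(name, shape, temperature=0.33):
    log_alpha = tf.get_variable(name + "_log_alpha", shape,
                                initializer=tf.zeros_initializer())
    noise = tf.random_uniform(shape, 1e-6, 1.0 - 1e-6)
    # Relaxed Bernoulli sample in (0, 1) acting as a soft on/off gate.
    gate = tf.sigmoid(
        (tf.log(noise) - tf.log(1.0 - noise) + log_alpha) / temperature)
    p_open = tf.sigmoid(log_alpha)  # probability each gate is open
    tf.add_to_collection("CONCRETE", tf.reduce_mean(p_open))
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         tf.reduce_sum(p_open))
    tf.add_to_collection("GATEVALUES", gate)
    return gate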
  def model_fn(features, labels, mode, params):
    """Defines how to train, evaluate and predict from the transformer model."""  
    cluster_spec = cluster.as_dict()
    num_gpus = 2  # hardcoded; could be len(cluster_spec["worker"])
    learning_rate = get_learning_rate(
        learning_rate=params["learning_rate"],
        hidden_size=params["hidden_size"],
        learning_rate_warmup_steps=params["learning_rate_warmup_steps"])
    optimizers = [
        tf.contrib.opt.LazyAdamOptimizer(
            learning_rate,
            beta1=params["optimizer_adam_beta1"],
            beta2=params["optimizer_adam_beta2"],
            epsilon=params["optimizer_adam_epsilon"])
        for _ in range(num_gpus)
    ]

    if params["dtype"] == "fp16":
      optimizers = [tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) for optimizer in optimizers]

    model = transformer.Transformer(params, mode == tf.estimator.ModeKeys.TRAIN)
    grad_list = []
    losses = []
    logits = []
    for gpu_idx in range(num_gpus):
      device_setter = local_device_setter(cluster,
                                          worker_device="gpu:%d" % gpu_idx)
      with tf.device(device_setter):
        logit, loss = create_tower_network(model, params, features, labels)
        logits.append(logit)
        losses.append(loss)
        grad_list.append([
            x for x in optimizers[gpu_idx].compute_gradients(loss)
            if x[0] is not None
        ])

    output_train = tf.reduce_mean(logits, axis=0)
    loss_train = tf.reduce_mean(losses, name='loss')
   
    '''
    grads = []
    all_vars= []
    for tower in grad_list:
      grads.append([x[0] for x in tower])
      all_vars.append([x[1] for x in tower])

    reduced_grad = []
    if num_gpus==1:
      reduced_grad = grads
    else:
      new_all_grads = []
      for grad in zip(*grads):
        summed = nccl_ops.all_sum(grad)
        grads_for_devices = []
        for g in summed:
          with tf.device(g.device):
            g = tf.multiply(g, 1.0 / num_gpus, name='allreduce_avg')
          grads_for_devices.append(g)
        new_all_grads.append(grads_for_devices)
      reduced_grad = list(zip(*new_all_grads))
    grads = [list(zip(gs, vs)) for gs, vs in zip(reduced_grad, all_vars)]
    '''
    from tensorflow.python.distribute import cross_device_utils
    grads = cross_device_utils.aggregate_gradients_using_nccl(grad_list)
    # Apply gradients on each GPU by broadcasting the summed gradient. The
    # global step is incremented once, outside the per-device loop, since
    # apply_gradients is called without a global_step.
    global_step = tf.train.get_global_step()
    update_global_step = tf.assign(global_step, global_step + 1,
                                   name='update_global_step')
    train_ops = []
    for idx, grad_and_vars in enumerate(grads):
      with tf.name_scope('apply_gradients'), tf.device(
          tf.DeviceSpec(device_type="GPU", device_index=idx)):
        train_ops.append(optimizers[idx].apply_gradients(
            grad_and_vars, name='apply_grad_{}'.format(idx)))
    optimize_op = tf.group(update_global_step, *train_ops, name='train_op')
    train_metrics = {"learning_rate": learning_rate}

    tf.identity(loss_train, "cross_entropy")

    if mode == tf.estimator.ModeKeys.TRAIN:
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss_train,
                                        train_op=optimize_op)
    if mode == tf.estimator.ModeKeys.EVAL:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          loss=loss_train,
          predictions={"predictions": output_train},
          eval_metric_ops=metrics.get_eval_metrics(output_train, labels,
                                                   params))
    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=output_train,
          export_outputs={
              "translate": tf.estimator.export.PredictOutput(output_train)
          })