def compare_dense_emb_sok_with_tf(args):
    if (args.global_batch_size % args.local_gpu_num != 0):
        raise ValueError("global_batch_size: %d is not divisible by local_gpu_num: %d"
                        %(args.global_batch_size, args.local_gpu_num))
    if (args.global_batch_size % args.worker_num != 0):
        raise ValueError("global_batch_size: %d is not divisible by worker_num: %d"
                        %(args.global_batch_size, args.worker_num))

    if args.mixed_precision:
        policy = tf.keras.mixed_precision.Policy("mixed_float16")
        tf.keras.mixed_precision.set_global_policy(policy)

    # each worker generates a different dataset
    if args.generate_new_datas:
        if args.use_hashtable:
            vocabulary_size = args.local_gpu_num * args.max_vocabulary_size_per_gpu * args.worker_num
        else:
            vocabulary_size = args.max_vocabulary_size_per_gpu

        worker_batch_size = args.global_batch_size // args.worker_num
        random_samples_local = utils.generate_random_samples(num_of_samples=worker_batch_size * args.iter_num,
                                                             vocabulary_size=vocabulary_size,
                                                             slot_num=args.slot_num,
                                                             max_nnz=args.nnz_per_slot,
                                                             use_sparse_mask=False)
        utils.save_to_file(r"./random_samples_" + str(args.task_id) + r".file", *random_samples_local)
    else:
        random_samples_local = utils.restore_from_file(r"./random_samples_" + str(args.task_id) + r".file")

    if 0 == args.restore_params:
        # each worker generates the same init tensors, because each worker will do the filtering by itself
        init_tensors = utils.get_ones_tensor(max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
                                            embedding_vec_size=args.embedding_vec_size,
                                            num=args.local_gpu_num * args.worker_num)
    else:
        filepath = r"./embedding_variables"
        tf_values_filename = os.path.join(filepath, r"tf_variable.file")
        init_tensors = utils.restore_from_file(tf_values_filename)

    sok_results_local, embedding_variable_name = test_sok_dense_demo(args, init_tensors, *random_samples_local)
    # save the forward embedding vectors from each worker to a file
    utils.save_to_file(r"./sok_embedding_vectors_" + str(args.task_id) + r".file", *sok_results_local)

    # only one process needs to do the TF computation
    if args.task_id != 0:
        return

    # aggregate the datasets from the different workers
    dataset_filenames = [r"./random_samples_" + str(task_id) + r".file"
                         for task_id in range(args.worker_num)]
    random_samples_total = [list() for _ in range(args.iter_num)]
    random_labels_total = [list() for _ in range(args.iter_num)]
    local_batch_size = args.global_batch_size // args.worker_num
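    # Each worker saved its samples as iter_num consecutive slices of size
    # local_batch_size; re-interleave those slices so that iteration i of the single
    # TF run sees exactly the global batch the SOK workers processed jointly at step i.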
    for worker_id in range(args.worker_num):
        samples, labels = utils.restore_from_file(dataset_filenames[worker_id])
        for i in range(args.iter_num):
            random_samples_total[i].extend(samples[i * local_batch_size : (i + 1) * local_batch_size])
            random_labels_total[i].extend(labels[i * local_batch_size : (i + 1) * local_batch_size])
    random_samples_total = np.concatenate(random_samples_total, axis=0)
    random_labels_total = np.concatenate(random_labels_total, axis=0)

    tf_results = test_tf_dense_model(args, init_tensors, random_samples_total, random_labels_total)

    # aggregate the forward embedding vectors from the different workers
    sok_results_filenames = [r"./sok_embedding_vectors_" + str(task_id) + r".file"
                             for task_id in range(args.worker_num)]
    sok_results_total = list()
    for file_name in sok_results_filenames:
        sok_results_local = utils.restore_from_file(file_name)
        sok_results_total.append(sok_results_local)
    
    if (len(sok_results_total[0]) != len(tf_results)):
        raise ValueError("The length of results obtained from sok: %d is not equal to that obtained from TF: %d"
                         %(len(sok_results_total[0]), len(tf_results)))
    if (len(tf_results) != args.iter_num):
        raise ValueError("The length of embedding vectors: %d is not equal to iteration number: %d."
                         %(len(tf_results), args.iter_num))
    
    if 1 == args.restore_params or args.mixed_precision:
        tolerance = 1e-2
    else:
        tolerance = 1e-4

    for i in range(args.iter_num):
        if args.local_gpu_num != 1:
            sok_vector = tf.concat([tf.concat(sok_results_total[task_id][i].values, axis=0)
                                    for task_id in range(args.worker_num)], axis=0)
        else:
            sok_vector = tf.concat([sok_results_total[task_id][i]
                                    for task_id in range(args.worker_num)],
                                    axis=0)
        tf.debugging.assert_near(tf.reshape(sok_vector,
                                            shape=[-1, tf.shape(sok_vector)[-1]]),
                                tf_results[i],
                                atol=tolerance,
                                rtol=tolerance)

    print("\n[INFO]: For Dense Embedding Layer, with MultiWorkerMirroredStrategy, the embedding vectors "+\
          "obtained from sparse operation kit and TensorFlow are consistent for %d iterations"
          ", with mixed_precision = %s"
          %(args.iter_num, args.mixed_precision))

    if 1 == args.save_params:
        check_saved_embedding_variables(args, embedding_variable_name, 
                                        use_hashtable=args.use_hashtable, 
                                        gpu_num=args.worker_num * args.local_gpu_num,
                                        atol=tolerance, rtol=tolerance)
Example #2
def test_sok_multi_dense_emb(args):
    comm_options = tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=0,
        timeout_seconds=None,
        implementation=tf.distribute.experimental.CommunicationImplementation.
        NCCL)

    if args.worker_num == 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        port = 12345
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": [
                    "localhost" + ":" + str(port + i)
                    for i in range(args.worker_num)
                ]
            },
            "task": {
                "type": "worker",
                "index": args.task_id
            }
        })
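        # For example, with worker_num=2 and task_id=0 the block above serializes
        # (illustrative values only):
        #   {"cluster": {"worker": ["localhost:12345", "localhost:12346"]},
        #    "task": {"type": "worker", "index": 0}}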
        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=comm_options)

    replica_batch_size = args.global_batch_size // (args.worker_num * 1)

    dataset = utility.TFDataset(filename=args.file_prefix + str(args.task_id) +
                                ".file",
                                batchsize=replica_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    dynamic_input = (args.dynamic_input == 1)

    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        model = SOKDenseModel(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size_list=args.embedding_vec_size_list,
            slot_num_list=args.slot_num_list,
            nnz_per_slot_list=[
                args.nnz_per_slot for _ in range(len(args.slot_num_list))
            ],
            num_dense_layers=args.num_dense_layers,
            dynamic_input=dynamic_input)

        emb_opt = utils.get_embedding_optimizer(
            args.optimizer)(learning_rate=0.1)
        dense_opt = utils.get_dense_optimizer(
            args.optimizer)(learning_rate=0.1)
        if args.mixed_precision:
            emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(
                emb_opt, initial_scale=1024)

    # set initial values for the embedding variables.
    sok_saver = sok.Saver()
    for i, layer in enumerate(model.embedding_layers):
        init_tensors = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size_list[i],
            num=args.worker_num)
        sok_saver.load_embedding_values(layer.embedding_variable, init_tensors)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        _dtype = loss.dtype
        loss = tf.cast(loss, tf.float32)
        loss = tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)
        return tf.cast(loss, _dtype)
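    # Note: tf.nn.compute_average_loss divides the summed per-example loss by the
    # global batch size (not the per-replica batch size), so summing the per-replica
    # results across replicas reproduces the true global mean loss.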

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, all_vectors = model(inputs, training=True)
            loss = _replica_loss(labels, logit)
            if args.mixed_precision:
                _loss = emb_opt.get_scaled_loss(loss)
            else:
                _loss = loss
        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)
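        # SOK's embedding variables and the ordinary dense variables are split here so
        # that each group can be updated by its own optimizer below.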
        grads, emb_grads = tape.gradient(_loss, [other_variable, emb_variable])
        if args.mixed_precision:
            grads = emb_opt.get_unscaled_gradients(grads)
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                        experimental_aggregate_gradients=False)
        else:
            emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                    experimental_aggregate_gradients=False)
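        # Note: sok.OptimizerScope is presumably needed so that a native TF optimizer
        # can create its slot variables for SOK's embedding variables; SOK's "plugin"
        # optimizers manage those slots internally, so no scope is used in that branch.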

        with tf.control_dependencies(emb_grads):
            # manually all-reduce the dense gradients
            replica_context = tf.distribute.get_replica_context()
            grads = replica_context.all_reduce("sum",
                                               grads,
                                               options=comm_options)
            dense_opt.apply_gradients(zip(grads, other_variable),
                                      experimental_aggregate_gradients=False)

            # manually all-reduce the loss; this is safe because the replica loss has
            # already been used to update the local variables.
            loss = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                              loss,
                                              options=comm_options)
        return loss, all_vectors, logit

    # collect the SOK forward results of every iteration
    sok_results = list()
    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_iter >= 0 and i >= args.stop_iter:
            break

        total_loss, all_vectors, logit = strategy.run(_train_step,
                                                      args=(inputs, labels))
        print("[INFO]: Iteration: {}, loss={}".format(i, total_loss))

        with tf.device("CPU:0"):
            sok_results.append(all_vectors)

    return sok_results
Example #3
def test_tf_multi_dense_emb(args):
    dataset_filenames = [
        args.file_prefix + str(task_id) + ".file"
        for task_id in range(args.worker_num)
    ]

    samples_total = [list() for _ in range(args.dataset_iter_num)]
    labels_total = [list() for _ in range(args.dataset_iter_num)]
    replica_batch_size = args.global_batch_size // args.worker_num
    for worker_id in range(args.worker_num):
        samples, labels = utils.restore_from_file(dataset_filenames[worker_id])
        for i in range(args.dataset_iter_num):
            samples_total[i].extend(samples[i * replica_batch_size:(i + 1) *
                                            replica_batch_size])
            labels_total[i].extend(labels[i * replica_batch_size:(i + 1) *
                                          replica_batch_size])
    samples_total = np.concatenate(samples_total, axis=0)
    labels_total = np.concatenate(labels_total, axis=0)

    dataset = utils.tf_dataset(samples_total,
                               labels_total,
                               batchsize=args.global_batch_size,
                               to_sparse_tensor=False,
                               repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    model = TFDenseModel(
        vocabulary_size=args.max_vocabulary_size_per_gpu * args.worker_num,
        embedding_vec_size_list=args.embedding_vec_size_list,
        slot_num_list=args.slot_num_list,
        nnz_per_slot_list=[
            args.nnz_per_slot for _ in range(len(args.slot_num_list))
        ],
        num_dense_layers=args.num_dense_layers)
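    # The single-process TF reference model holds the full vocabulary
    # (max_vocabulary_size_per_gpu * worker_num) on one device, so it can look up
    # every key that the distributed SOK run may touch.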

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    if args.mixed_precision:
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            optimizer, initial_scale=1024)

    # set initial values for the embedding variables
    for i, param in enumerate(model.embedding_params):
        init_tensors = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu *
            args.worker_num,
            embedding_vec_size=args.embedding_vec_size_list[i],
            num=1)
        param.assign(init_tensors[0])

    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, all_vectors = model(inputs, training=True)
            loss = loss_fn(labels, logit)
            if args.mixed_precision:
                _loss = optimizer.get_scaled_loss(loss)
            else:
                _loss = loss
        grads = tape.gradient(_loss, model.trainable_variables)
        if args.mixed_precision:
            grads = optimizer.get_unscaled_gradients(grads)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss, all_vectors

    # collect the TF forward results of every iteration
    tf_results = list()
    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_iter >= 0 and i >= args.stop_iter:
            break

        loss, all_vectors = _train_step(inputs, labels)
        print("[INFO]: Iteration: {}, loss={}".format(i, loss))

        with tf.device("CPU:0"):
            tf_results.append(all_vectors)
    return tf_results
Example #4
def compare_sok_with_tf(args):
    if (args.global_batch_size % args.gpu_num != 0):
        raise ValueError(
            "global_batch_size: %d is not divisible by gpu_num: %d" %
            (args.global_batch_size, args.gpu_num))

    if args.use_hashtable:
        vocabulary_size = args.max_vocabulary_size_per_gpu * args.gpu_num
    else:
        vocabulary_size = args.max_vocabulary_size_per_gpu

    if args.generate_new_datas:
        random_samples = utils.generate_random_samples(
            num_of_samples=args.global_batch_size * args.iter_num,
            vocabulary_size=vocabulary_size,
            slot_num=args.slot_num,
            max_nnz=args.max_nnz)
        utils.save_to_file(r"./random_samples.file", *random_samples)
    else:
        random_samples = utils.restore_from_file(r"./random_samples.file")

    if (1 == args.restore_params):  # initialize using trained params
        filepath = r"./embedding_variables"

        # Because the Variable consistency was already checked when saving, the
        # TensorFlow Variable file can be used directly here to initialize TF's
        # variable.
        # FIXME: what if not all TensorFlow embedding vectors are used??
        tf_values_filename = os.path.join(filepath, r"tf_variable.file")
        init_tensors = utils.restore_from_file(tf_values_filename)

    else:  # initialize with constant (ones) values
        init_tensors = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size,
            num=args.gpu_num)

    sok_results, embedding_variable_name = test_sok_demo(
        args, init_tensors, *random_samples)
    tf_results = test_tf_demo(args, init_tensors, *random_samples)

    if (len(sok_results) != len(tf_results)):
        raise ValueError(
            "The length of plugin results is not equal to that of tensorflow.")
    if (len(tf_results) != args.iter_num):
        raise ValueError(
            "The length of embedding vectors: %d is not equal to iteration number: %d."
            % (len(tf_results), args.iter_num))

    tolerance = 1e-4
    if args.mixed_precision:
        tolerance = 1e-3

    for i, sok_vector in enumerate(sok_results):
        if args.gpu_num != 1:
            sok_vector = tf.stack(sok_vector.values, axis=0)
        tf.debugging.assert_near(tf.reshape(
            sok_vector, shape=[-1, tf.shape(sok_vector)[-1]]),
                                 tf_results[i],
                                 atol=tolerance,
                                 rtol=tolerance)
    print("\n[INFO]: With MirroredStrategy, the embedding vector obtained from " +\
          "sparse operation kit and tensorflow are consistent for %d iterations."
          " With mixed_precision = %s, and key_dtype = %s, and use_tf_initializer = %s"
          %(args.iter_num, args.mixed_precision, args.key_dtype, args.use_tf_initializer))

    if (1 == args.save_params):
        check_saved_embedding_variables(args,
                                        embedding_variable_name,
                                        use_hashtable=args.use_hashtable,
                                        gpu_num=args.gpu_num,
                                        atol=tolerance,
                                        rtol=tolerance)
Example #5
def compare_dense_emb_sok_with_tf(args):
    if args.global_batch_size % args.gpu_num != 0:
        raise ValueError(
            f"global_batch_size: {args.global_batch_size} is not divisible by"
            f" gpu_num: {args.gpu_num}")

    if args.use_hashtable:
        vocabulary_size = args.max_vocabulary_size_per_gpu * args.gpu_num
    else:
        vocabulary_size = args.max_vocabulary_size_per_gpu

    if args.generate_new_datas:
        replica_batch_size = args.global_batch_size // args.gpu_num
        random_samples = utils.generate_random_samples(
            num_of_samples=replica_batch_size * args.iter_num,
            vocabulary_size=vocabulary_size,
            slot_num=sum(args.slot_num),
            max_nnz=args.nnz_per_slot,
            use_sparse_mask=False)
        utils.save_to_file(
            r"./random_samples_" + str(args.rank_idx) + r".file",
            *random_samples)
    else:
        random_samples = utils.restore_from_file(r"./random_samples_" +
                                                 str(args.rank_idx) + r".file")

    if args.restore_params:
        filepath = r"./embedding_variables"
        # Because the Variable consistency was already checked when saving, the
        # TensorFlow Variable files can be used directly here to initialize TF's
        # Variables.
        init_tensors = list()
        for i in range(len(args.slot_num)):
            tf_values_filename = os.path.join(
                filepath, r"tf_variable_" + str(i) + r".file")
            init_tensors.append(utils.restore_from_file(tf_values_filename))
    else:
        init_tensors = list()
        for i in range(len(args.slot_num)):
            init_tensors.append(
                utils.get_ones_tensor(
                    max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
                    embedding_vec_size=args.embedding_vec_size[i],
                    num=args.gpu_num))

    sok_results, embedding_variable_name = get_sok_results(
        args, init_tensors, *random_samples)
    utils.save_to_file(
        r"./sok_embedding_vectors_" + str(args.rank_idx) + r".file",
        *sok_results)

    if args.rank_idx != 0:
        return

    # aggregate the datasets from the different workers
    dataset_filenames = [
        r"./random_samples_" + str(rank_idx) + r".file"
        for rank_idx in range(args.rank_size)
    ]
    random_samples_total = [list() for _ in range(args.iter_num)]
    random_labels_total = [list() for _ in range(args.iter_num)]
    local_batch_size = args.global_batch_size // args.gpu_num
    for rank_idx in range(args.rank_size):
        samples, labels = utils.restore_from_file(dataset_filenames[rank_idx])
        for i in range(args.iter_num):
            random_samples_total[i].extend(
                samples[i * local_batch_size:(i + 1) * local_batch_size])
            random_labels_total[i].extend(labels[i * local_batch_size:(i + 1) *
                                                 local_batch_size])
    random_samples_total = np.concatenate(random_samples_total, axis=0)
    random_labels_total = np.concatenate(random_labels_total, axis=0)

    tf_results, _ = get_tf_results(args, init_tensors, random_samples_total,
                                   random_labels_total)

    # aggregate the SOK forward results from the different workers
    sok_results_filenames = [
        r"./sok_embedding_vectors_" + str(rank_idx) + r".file"
        for rank_idx in range(args.rank_size)
    ]
    sok_results_total = list()
    for filename in sok_results_filenames:
        sok_results = utils.restore_from_file(filename)
        sok_results_total.append(sok_results)

    if len(sok_results_total[0]) != len(tf_results):
        raise ValueError(
            "The length of sok results is not equal to that of tensorflow.")
    if len(sok_results) != args.iter_num:
        raise ValueError(
            "The length of embedding vectors: %d is not equal to iteration number: %d."
            % (len(sok_results), args.iter_num))

    rtol = 1e-4
    atol = 1e-4
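    # The looser tolerances below are presumably meant to absorb the extra rounding
    # error introduced by restored (trained) parameters, Horovod all-reduce, and
    # float16 arithmetic, respectively.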
    if args.restore_params:
        rtol, atol = 1e-3, 1e-3
    elif args.distributed_tool == "horovod":
        rtol, atol = rtol * 10, atol * 10
    elif args.mixed_precision:
        rtol, atol = 1e-2, 1e-2

    for i in range(args.iter_num):
        sok_vector = np.concatenate([
            sok_results_total[rank_idx][i]
            for rank_idx in range(args.rank_size)
        ], axis=0)
        allclose = np.allclose(sok_vector, tf_results[i], rtol=rtol, atol=atol)
        if not allclose:
            raise ValueError(
                f"\n{sok_vector} \nis not near to \n{tf_results[i]} \nat rtol={rtol}, atol={atol}"
            )

        # TODO: add a verbose option
        if False:
            print("--------------- step: {}---------------------".format(i))
            print("sok_embedding_vector:\n{}".format(sok_vector))
            print("tf_embedding_vector:\n{}".format(tf_results[i]))

    print(
        f"\n[INFO]: For {len(args.slot_num)} Dense Embedding layers, using {args.gpu_num} GPUs + {args.optimizer} optimizer, "
        f"using hashtable? {args.use_hashtable}, dynamic_input? {args.dynamic_input}, "
        "the embedding vectors"
        f" obtained from sok and tf are consistent for {args.iter_num} iterations,"
        f" with mixed_precision = {args.mixed_precision}, key_dtype = {args.key_dtype},"
        f" use_tf_initializer = {args.use_tf_initializer}")

    if args.save_params:
        check_saved_embedding_variables(args,
                                        embedding_variable_name,
                                        use_hashtable=args.use_hashtable,
                                        gpu_num=args.gpu_num,
                                        atol=atol,
                                        rtol=rtol)
Example #6
def test_sok_multi_dense_emb(args):
    assert args.global_batch_size % args.worker_num == 0, \
        "global_batch_size must be divisible by worker_num"
    replica_batch_size = args.global_batch_size // args.worker_num

    dataset = utility.TFDataset(filename=args.file_prefix + str(args.task_id) +
                                ".file",
                                batchsize=replica_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    dynamic_input = (args.dynamic_input == 1)

    # initialize SOK
    sok.Init(global_batch_size=args.global_batch_size)

    model = SOKDenseModel(
        max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
        embedding_vec_size_list=args.embedding_vec_size_list,
        slot_num_list=args.slot_num_list,
        nnz_per_slot_list=[
            args.nnz_per_slot for _ in range(len(args.slot_num_list))
        ],
        num_dense_layers=args.num_dense_layers,
        dynamic_input=dynamic_input,
        use_hashtable=args.use_hashtable)

    emb_opt = utils.get_embedding_optimizer(args.optimizer)(learning_rate=0.1)
    dense_opt = utils.get_dense_optimizer(args.optimizer)(learning_rate=0.1)
    if args.mixed_precision:
        emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(
            emb_opt, initial_scale=1024)

    sok_saver = sok.Saver()
    for i, layer in enumerate(model.embedding_layers):
        init_tensors = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size_list[i],
            num=args.worker_num)

        sok_saver.load_embedding_values(layer.embedding_variable, init_tensors)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        loss = loss_fn(labels, logits)
        _dtype = loss.dtype
        loss = tf.cast(loss, tf.float32)
        loss = tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)
        return tf.cast(loss, _dtype)

    @tf.function
    def _train_step(inputs, labels, first_batch):
        with tf.GradientTape() as tape:
            logit, all_vectors = model(inputs, training=True)
            replica_loss = _replica_loss(labels, logit)
            if args.mixed_precision:
                _loss = emb_opt.get_scaled_loss(replica_loss)
            else:
                _loss = replica_loss

        emb_var, other_var = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        emb_grads, grads = tape.gradient(_loss, [emb_var, other_var])
        if args.mixed_precision:
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)
            grads = emb_opt.get_unscaled_gradients(grads)

        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_var):
                emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                        experimental_aggregate_gradients=False)
        else:
            emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                    experimental_aggregate_gradients=False)

        with tf.control_dependencies(emb_grads):

            grads = [hvd.allreduce(grad) for grad in grads]
            dense_opt.apply_gradients(zip(grads, other_var))
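            # On the first batch, broadcast the dense variables and optimizer slots from
            # rank 0 so that every rank starts from an identical state (standard Horovod
            # practice).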

            if first_batch:
                hvd.broadcast_variables(other_var, root_rank=0)
                hvd.broadcast_variables(dense_opt.variables(), root_rank=0)

            total_loss = hvd.allreduce(replica_loss)
        return total_loss, all_vectors

    sok_results = list()
    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_iter >= 0 and i >= args.stop_iter:
            break

        total_loss, all_vectors = _train_step(inputs, labels, 0 == i)
        print("[INFO]: Iteration: {}, loss={}".format(i, total_loss))

        sok_results.append(all_vectors)
    return sok_results