def test_sok_dense_demo(args, init_tensors, *random_samples):
    """Train SOKDenseDemo under MultiWorkerMirroredStrategy and collect its embedding outputs.

    Workers are addressed as ``args.ips[i]:12345+i``. The embedding table is
    either restored from ``./embedding_variables`` (``restore_params == 1``) or
    loaded from ``init_tensors``. It is optionally dumped back to the same
    directory after training (``save_params == 1``).

    Args:
        args: parsed options (cluster layout, model shape, optimizer, mixed
            precision, restore/save switches, ...).
        init_tensors: initial embedding values used when not restoring.
        *random_samples: forwarded to ``utils.tf_dataset`` to build the input
            pipeline of every worker.

    Returns:
        tuple ``(sok_results, var_name)``: the per-step embedding vectors
        returned by ``strategy.run`` and the ``m_var_name`` of the first
        component of the embedding variable.
    """
    base_port = 12345
    worker_addresses = [
        args.ips[idx] + ":" + str(base_port + idx)
        for idx in range(args.worker_num)
    ]
    cluster_spec = {"worker": worker_addresses}
    task_spec = {"type": "worker", "index": args.task_id}
    os.environ["TF_CONFIG"] = json.dumps({"cluster": cluster_spec, "task": task_spec})

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        sok_dense_demo = SOKDenseDemo(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size,
            slot_num=args.slot_num,
            nnz_per_slot=args.nnz_per_slot,
            use_hashtable=args.use_hashtable)

        make_emb_opt = utils.get_embedding_optimizer(args.optimizer)
        make_dense_opt = utils.get_dense_optimizer(args.optimizer)
        emb_opt = make_emb_opt(learning_rate=0.1)
        dense_opt = make_dense_opt(learning_rate=0.1)
        if args.mixed_precision:
            # Only the embedding optimizer carries the loss scale; its scale is
            # also used to unscale the dense gradients below.
            emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(emb_opt, initial_scale=1024)

    sok_saver = sok.Saver()
    emb_table = sok_dense_demo.embedding_layer.embedding_variable
    if args.restore_params == 1:
        sok_saver.restore_from_file(emb_table, r"./embedding_variables")
    else:
        sok_saver.load_embedding_values(emb_table, init_tensors)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in float32
        # and cast back to the loss's original dtype.
        per_example = loss_fn(labels, logits)
        orig_dtype = per_example.dtype
        averaged = tf.nn.compute_average_loss(
            tf.cast(per_example, tf.float32),
            global_batch_size=args.global_batch_size)
        return tf.cast(averaged, orig_dtype)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, embedding_vector = sok_dense_demo(inputs, training=True)
            loss = _replica_loss(labels, logit)
            scaled_loss = emb_opt.get_scaled_loss(loss) if args.mixed_precision else loss

        emb_vars, dense_vars = sok.split_embedding_variable_from_others(
            sok_dense_demo.trainable_variables)
        dense_grads, emb_grads = tape.gradient(scaled_loss, [dense_vars, emb_vars])
        if args.mixed_precision:
            dense_grads = emb_opt.get_unscaled_gradients(dense_grads)
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

        # Non-plugin (TF) optimizers must run inside sok.OptimizerScope.
        if "plugin" in args.optimizer:
            emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                    experimental_aggregate_gradients=False)
        else:
            with sok.OptimizerScope(emb_vars):
                emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                        experimental_aggregate_gradients=False)
        dense_opt.apply_gradients(zip(dense_grads, dense_vars))
        return loss, embedding_vector

    def _dataset_fn(input_context):
        # Every worker builds its own pipeline from random_samples; no sharding.
        per_replica_batch = input_context.get_per_replica_batch_size(args.global_batch_size)
        return utils.tf_dataset(*random_samples, batchsize=per_replica_batch,
                                to_sparse_tensor=False, repeat=1)

    dataset = strategy.distribute_datasets_from_function(_dataset_fn)

    sok_results = []
    for step, (input_tensors, replica_labels) in enumerate(dataset):
        print("-"*30, "step ", str(step), "-"*30)
        step_loss, embedding_vector = strategy.run(
            _train_step, args=(input_tensors, replica_labels))
        step_loss = strategy.reduce("sum", step_loss, axis=None)
        print("[INFO]: iteration {}, loss {}".format(step, step_loss))
        sok_results.append(embedding_vector)

    # save params to file.
    if args.save_params == 1:
        filepath = r"./embedding_variables"
        utils.try_make_dirs(filepath, chief=args.task_id == 0)
        sok_saver.dump_to_file(emb_table, filepath)

    return sok_results, emb_table.values[0].m_var_name
def run_sok_model(args, dense_variables, vocabulary_tensors, samples, labels):
    """Train a SOKDenseDemo model data-parallel with Horovod and return its losses.

    The global batch (``samples`` / ``labels``) is split evenly across all
    Horovod ranks. Each rank then trains on its own slice for
    ``args.iter_num`` iterations, reusing the same slice every iteration.

    Args:
        args: parsed options; reads ``global_batch_size``,
            ``max_vocabulary_size_per_gpu``, ``embedding_vec_size``,
            ``slot_num``, ``nnz_per_slot``, ``num_dense_layers``,
            ``num_dense_units``, ``optimizer``, ``mixed_precision`` and
            ``iter_num``.
        dense_variables: pair of per-layer value lists; ``dense_variables[0][i]``
            and ``dense_variables[1][i]`` are assigned to the first and second
            trainable variable of dense layer ``i`` (presumably kernel and
            bias — TODO confirm).
        vocabulary_tensors: tensors exposing ``.numpy()``; their values become
            the initial embedding table.
        samples: the full global batch of inputs (sliced per rank here).
        labels: the full global batch of labels (sliced per rank here).

    Returns:
        list with this rank's (not all-reduced) loss tensor for each iteration.

    Raises:
        AssertionError: if ``global_batch_size`` is not divisible by the number
            of Horovod ranks.
    """
    # split sample and labels across *all* ranks
    assert (args.global_batch_size % hvd.size() == 0)
    local_batch_size = args.global_batch_size // hvd.size()
    # Index the slice by the global rank. hvd.local_rank() restarts at 0 on
    # every node, so on multi-node jobs it would hand the same slice to ranks
    # on different nodes and leave other slices unused.
    begin = hvd.rank() * local_batch_size
    end = begin + local_batch_size
    samples = samples[begin:end]
    labels = labels[begin:end]

    sok.Init(global_batch_size=args.global_batch_size)

    model = SOKDenseDemo(
        max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
        embedding_vec_size=args.embedding_vec_size,
        slot_num=args.slot_num,
        nnz_per_slot=args.nnz_per_slot,
        num_dense_layers=args.num_dense_layers,
        num_dense_units=args.num_dense_units)

    # One forward pass builds the layers' variables so the given dense
    # weights can be assigned into them.
    model(samples, training=False)
    for i in range(args.num_dense_layers):
        model.dense_layers[i].trainable_variables[0].assign(
            dense_variables[0][i])
        model.dense_layers[i].trainable_variables[1].assign(
            dense_variables[1][i])

    sok_saver = sok.Saver()
    init_tensors = [tensor.numpy() for tensor in vocabulary_tensors]
    sok_saver.load_embedding_values(model.embedding_layer.embedding_variable,
                                    init_tensors)

    embedding_optimizer = utils.get_embedding_optimizer(
        args.optimizer)(learning_rate=0.1)
    dense_optimizer = utils.get_dense_optimizer(
        args.optimizer)(learning_rate=0.1)
    if args.mixed_precision:
        embedding_optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
            embedding_optimizer, initial_scale=1024)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in
        # float32 and cast back to the loss's original dtype.
        loss = loss_fn(labels, logits)
        _dtype = loss.dtype
        loss = tf.cast(loss, tf.float32)
        loss = tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)
        return tf.cast(loss, dtype=_dtype)

    @tf.function
    def _train_step(inputs, labels, first_batch):
        # Two tapes: dense gradients go through Horovod's DistributedGradientTape,
        # while embedding gradients come from a plain tape.
        with tf.GradientTape() as tape, tf.GradientTape() as emb_tape:
            logit = model(inputs, training=True)
            replica_loss = _replica_loss(labels, logit)
            if args.mixed_precision:
                _loss = embedding_optimizer.get_scaled_loss(replica_loss)
            else:
                _loss = replica_loss

        tape = hvd.DistributedGradientTape(tape)

        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        emb_grads = emb_tape.gradient(_loss, emb_variable)
        grads = tape.gradient(_loss, other_variable)
        if args.mixed_precision:
            emb_grads = embedding_optimizer.get_unscaled_gradients(emb_grads)
            grads = embedding_optimizer.get_unscaled_gradients(grads)

        # Non-plugin (TF) optimizers must run inside sok.OptimizerScope.
        if 'plugin' not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                embedding_optimizer.apply_gradients(
                    zip(emb_grads, emb_variable),
                    experimental_aggregate_gradients=False)
        else:
            embedding_optimizer.apply_gradients(
                zip(emb_grads, emb_variable),
                experimental_aggregate_gradients=False)
        dense_optimizer.apply_gradients(zip(grads, other_variable))

        # Note: broadcast should be done after the first gradient step to ensure optimizer initialization.
        if first_batch:
            hvd.broadcast_variables(other_variable, root_rank=0)
            hvd.broadcast_variables(dense_optimizer.variables(), root_rank=0)

        return replica_loss

    loss_list = []
    for i in range(args.iter_num):
        loss = _train_step(samples, labels, i == 0)
        loss_list.append(loss)
        print("[INFO]: Iteration: {}, loss={}".format(i, loss))
    return loss_list
Exemplo n.º 3
0
def test_sok_multi_dense_emb(args):
    """Train SOKDenseModel under (MultiWorker)MirroredStrategy and collect embedding outputs.

    A single worker uses ``MirroredStrategy``. Several workers use
    ``MultiWorkerMirroredStrategy`` with NCCL communication and workers on
    ``localhost:12345+i``. Each worker reads its own file
    ``args.file_prefix + str(args.task_id) + ".file"``. Embedding gradients go
    through the embedding optimizer, while dense gradients and the loss are
    all-reduced (summed) manually.

    Args:
        args: parsed options (batch size, worker count/id, model shape,
            optimizer, mixed precision, ``stop_iter``, ...).

    Returns:
        list with the model's ``all_vectors`` output for every step that ran.

    Raises:
        ValueError: if ``args.global_batch_size`` is not divisible by
            ``args.worker_num``.
    """
    # Fail fast, before any distributed setup. A truncated per-worker batch
    # would silently disagree with the global batch size that _replica_loss
    # divides by, under-weighting every step.
    if args.global_batch_size % args.worker_num != 0:
        raise ValueError(
            "global_batch_size ({}) must be divisible by worker_num ({})".format(
                args.global_batch_size, args.worker_num))
    # Assumes one replica (GPU) per worker.
    replica_batch_size = args.global_batch_size // args.worker_num

    comm_options = tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=0,
        timeout_seconds=None,
        implementation=tf.distribute.experimental.CommunicationImplementation.
        NCCL)

    if args.worker_num == 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        port = 12345
        os.environ["TF_CONFIG"] = json.dumps({
            "cluster": {
                "worker": [
                    "localhost" + ":" + str(port + i)
                    for i in range(args.worker_num)
                ]
            },
            "task": {
                "type": "worker",
                "index": args.task_id
            }
        })
        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=comm_options)

    dataset = utility.TFDataset(filename=args.file_prefix + str(args.task_id) +
                                ".file",
                                batchsize=replica_batch_size,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    dynamic_input = args.dynamic_input == 1

    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        model = SOKDenseModel(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size_list=args.embedding_vec_size_list,
            slot_num_list=args.slot_num_list,
            nnz_per_slot_list=[
                args.nnz_per_slot for _ in range(len(args.slot_num_list))
            ],
            num_dense_layers=args.num_dense_layers,
            dynamic_input=dynamic_input)

        emb_opt = utils.get_embedding_optimizer(
            args.optimizer)(learning_rate=0.1)
        dense_opt = utils.get_dense_optimizer(
            args.optimizer)(learning_rate=0.1)
        if args.mixed_precision:
            emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(
                emb_opt, initial_scale=1024)

    # set initial value to embedding variables.
    sok_saver = sok.Saver()
    for i, layer in enumerate(model.embedding_layers):
        init_tensors = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size_list[i],
            num=args.worker_num)
        sok_saver.load_embedding_values(layer.embedding_variable, init_tensors)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in
        # float32 and cast back to the loss's original dtype.
        loss = loss_fn(labels, logits)
        _dtype = loss.dtype
        loss = tf.cast(loss, tf.float32)
        loss = tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)
        return tf.cast(loss, _dtype)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, all_vectors = model(inputs, training=True)
            loss = _replica_loss(labels, logit)
            if args.mixed_precision:
                _loss = emb_opt.get_scaled_loss(loss)
            else:
                _loss = loss
        emb_variable, other_variable = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        grads, emb_grads = tape.gradient(_loss, [other_variable, emb_variable])
        if args.mixed_precision:
            grads = emb_opt.get_unscaled_gradients(grads)
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

        # Non-plugin (TF) optimizers must run inside sok.OptimizerScope.
        if "plugin" not in args.optimizer:
            with sok.OptimizerScope(emb_variable):
                emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                        experimental_aggregate_gradients=False)
        else:
            emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                    experimental_aggregate_gradients=False)

        with tf.control_dependencies(emb_grads):
            # manually all-reduce (sum) dense gradients
            replica_context = tf.distribute.get_replica_context()
            grads = replica_context.all_reduce("sum",
                                               grads,
                                               options=comm_options)
            dense_opt.apply_gradients(zip(grads, other_variable),
                                      experimental_aggregate_gradients=False)

            # manually all-reduce loss, it is ok, because replica_loss has already been used to
            # update local variables.
            loss = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                              loss,
                                              options=comm_options)
        return loss, all_vectors, logit

    # save its results
    sok_results = list()
    for i, (inputs, labels) in enumerate(dataset):
        if args.stop_iter >= 0 and i >= args.stop_iter:
            break

        total_loss, all_vectors, logit = strategy.run(_train_step,
                                                      args=(inputs, labels))
        print("[INFO]: Iteration: {}, loss={}".format(i, total_loss))

        with tf.device("CPU:0"):
            sok_results.append(all_vectors)

    return sok_results
Exemplo n.º 4
0
def test_sok_demo(args, init_tensors, *random_samples):
    """Train SOKDemo on sparse inputs under MirroredStrategy and collect embedding outputs.

    The embedding table is restored from ``./embedding_variables`` when
    ``restore_params == 1``. Otherwise it is loaded from ``init_tensors``,
    unless the TF ``Ones`` initializer is requested or ``init_tensors`` is
    empty. After training it is optionally dumped to
    ``./embedding_variables/``.

    Args:
        args: parsed options (model shape, combiner, optimizer, mixed
            precision, initializer choice, restore/save switches, ...).
        init_tensors: initial embedding values (may be empty/None).
        *random_samples: forwarded to ``utils.tf_dataset`` (as sparse tensors).

    Returns:
        tuple ``(sok_results, var_name)``: the per-step embedding vectors and
        the ``m_var_name`` of the first component of the embedding variable.
    """
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        sok.Init(global_batch_size=args.global_batch_size)

        if args.use_tf_initializer:
            embedding_initializer = tf.keras.initializers.Ones()
        else:
            embedding_initializer = None

        plugin_demo = SOKDemo(
            combiner=args.combiner,
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            slot_num=args.slot_num,
            max_nnz=args.max_nnz,
            embedding_vec_size=args.embedding_vec_size,
            use_hashtable=args.use_hashtable,
            key_dtype=args.key_dtype,
            embedding_initializer=embedding_initializer)

        build_emb_opt = utils.get_embedding_optimizer(args.optimizer)
        build_dense_opt = utils.get_dense_optimizer(args.optimizer)
        emb_opt = build_emb_opt(learning_rate=0.1)
        dense_opt = build_dense_opt(learning_rate=0.1)
        if args.mixed_precision:
            emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(
                emb_opt, initial_scale=1024)

    plugin_saver = sok.Saver()
    emb_table = plugin_demo.embedding_layer.embedding_variable

    if args.restore_params == 1:
        # restore from trained parameters
        plugin_saver.restore_from_file(emb_table, r"./embedding_variables")
    elif not args.use_tf_initializer and init_tensors:
        # initialize using randomized initial value
        plugin_saver.load_embedding_values(emb_table, init_tensors)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in
        # float32 and cast back to the loss's original dtype.
        per_example = loss_fn(labels, logits)
        orig_dtype = per_example.dtype
        averaged = tf.nn.compute_average_loss(
            tf.cast(per_example, tf.float32),
            global_batch_size=args.global_batch_size)
        return tf.cast(averaged, orig_dtype)

    @tf.function
    def _train_step(inputs, labels):
        with tf.GradientTape() as tape:
            logit, embedding_vector = plugin_demo(inputs, training=True)
            loss = _replica_loss(labels, logit)
            scaled_loss = emb_opt.get_scaled_loss(loss) if args.mixed_precision else loss

        emb_vars, dense_vars = sok.split_embedding_variable_from_others(
            plugin_demo.trainable_variables)
        dense_grads, emb_grads = tape.gradient(scaled_loss, [dense_vars, emb_vars])
        if args.mixed_precision:
            dense_grads = emb_opt.get_unscaled_gradients(dense_grads)
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

        with tf.control_dependencies(list(emb_grads)):
            # in case NCCL runs concurrently via SOK and TF
            if 'plugin' in args.optimizer:
                emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                        experimental_aggregate_gradients=False)
            else:
                with sok.OptimizerScope(emb_vars):
                    emb_opt.apply_gradients(
                        zip(emb_grads, emb_vars),
                        experimental_aggregate_gradients=False)
            dense_opt.apply_gradients(zip(dense_grads, dense_vars))
            return loss, embedding_vector

    def _dataset_fn(input_context):
        per_replica_batch = input_context.get_per_replica_batch_size(
            args.global_batch_size)
        pipeline = utils.tf_dataset(*random_samples,
                                    batchsize=per_replica_batch,
                                    to_sparse_tensor=True,
                                    repeat=1,
                                    args=args)
        return pipeline.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)

    dataset = strategy.distribute_datasets_from_function(_dataset_fn)

    sok_results = []
    for step, (sparse_tensors, replica_labels) in enumerate(dataset):
        print("-" * 30, "step ", str(step), "-" * 30)
        step_loss, embedding_vector = strategy.run(
            _train_step, args=(sparse_tensors, replica_labels))
        step_loss = strategy.reduce("sum", step_loss, axis=None)
        print("[INFO]: iteration {}, loss {}".format(step, step_loss))
        sok_results.append(embedding_vector)

    # save params to file.
    if args.save_params == 1:
        filepath = r"./embedding_variables/"
        utils.try_make_dirs(filepath)
        plugin_saver.dump_to_file(emb_table, filepath)

    return sok_results, emb_table.values[0].m_var_name
Exemplo n.º 5
0
def get_sok_results(args, init_tensors, *random_samples):
    """Build and run a graph-mode SOK training job and collect embedding outputs.

    Uses the TF1-style graph/session API (``make_initializable_iterator``,
    ``tf.Session``, ``tf.global_variables_initializer``), so ``tf`` here is
    presumably ``tf.compat.v1`` or TF1 — TODO confirm against the imports.

    Args:
        args: parsed options. ``args.distributed_tool`` selects the strategy
            wrapper (``"onedevice"`` or ``"horovod"``). Other fields configure
            the model, optimizers, input pipeline, restore/save behaviour and
            the number of iterations (``iter_num``).
        init_tensors: per-embedding-layer initial values. ``init_tensors[i]``
            is loaded into layer ``i`` unless ``restore_params`` is set or
            ``use_tf_initializer`` is true.
        *random_samples: forwarded to ``utils.tf_dataset`` to build the input
            pipeline.

    Returns:
        tuple ``(sok_results, name)``. ``sok_results`` holds the embedding
        vector value fetched at every step, and ``name`` lists the
        ``m_var_name`` of each embedding layer's variable.

    Raises:
        ValueError: if ``args.distributed_tool`` is not supported.
    """
    if args.distributed_tool == "onedevice":
        strategy = strategy_wrapper.OneDeviceStrategy()
    elif args.distributed_tool == "horovod":
        import horovod.tensorflow as hvd
        hvd.init()
        strategy = strategy_wrapper.HorovodStrategy()
    else:
        raise ValueError(f"{args.distributed_tool} is not supported.")

    with strategy.scope():
        sok_init_op = sok.Init(global_batch_size=args.global_batch_size)

        # With use_tf_initializer the table starts at ones; otherwise values
        # are loaded from init_tensors below.
        embedding_initializer = tf.keras.initializers.Ones(
        ) if args.use_tf_initializer else None

        sok_dense_demo = SOKDemo(
            max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size,
            slot_num=args.slot_num,
            nnz_per_slot=args.nnz_per_slot,
            use_hashtable=args.use_hashtable,
            dynamic_input=args.dynamic_input,
            num_of_dense_layers=0,
            key_dtype=args.key_dtype,
            embedding_initializer=embedding_initializer)

        emb_opt = utils.get_embedding_optimizer(
            args.optimizer)(learning_rate=0.1)
        dense_opt = utils.get_dense_optimizer(
            args.optimizer)(learning_rate=0.1)
        if args.mixed_precision:
            emb_opt = sok.tf.keras.mixed_precision.LossScaleOptimizer(
                emb_opt, 1024)

    # Build one restore/load op per embedding layer. Each op is gated on the
    # previous one via control dependencies so they execute in sequence.
    sok_saver = sok.Saver()
    restore_op = list()
    for i, embedding_layer in enumerate(sok_dense_demo.embedding_layers):
        control_inputs = [restore_op[-1]] if restore_op else None
        with tf.control_dependencies(control_inputs):
            if args.restore_params:
                filepath = r"./embedding_variables"
                op = sok_saver.restore_from_file(
                    embedding_layer.embedding_variable, filepath)
            else:
                if not args.use_tf_initializer:
                    op = sok_saver.load_embedding_values(
                        embedding_layer.embedding_variable, init_tensors[i])
                else:
                    # Values come from the Ones initializer; a dummy op keeps
                    # the chain (and restore_op) well-formed.
                    op = tf.constant(1.0)
            restore_op.append(op)

    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction='none')

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in
        # float32 and cast back to the loss's original dtype.
        loss = loss_fn(labels, logits)
        _dtype = loss.dtype
        loss = tf.cast(loss, tf.float32)
        loss = tf.nn.compute_average_loss(
            loss, global_batch_size=args.global_batch_size)
        return tf.cast(loss, _dtype)

    def _train_step(inputs, labels, training):
        # Builds (does not run) the graph for one training step; _step_fn is
        # executed through the strategy wrapper's run().
        def _step_fn(inputs, labels):
            logit, embedding_vector = sok_dense_demo(inputs, training=training)
            loss = _replica_loss(labels, logit)
            if args.mixed_precision:
                _loss = emb_opt.get_scaled_loss(loss)
            else:
                _loss = loss
            emb_var, other_var = sok.split_embedding_variable_from_others(
                sok_dense_demo.trainable_variables)
            # Gradients for embedding and dense variables in one call;
            # unconnected variables get None instead of zeros.
            grads = tf.gradients(
                _loss,
                emb_var + other_var,
                colocate_gradients_with_ops=True,
                unconnected_gradients=tf.UnconnectedGradients.NONE)
            emb_grads, other_grads = grads[:len(emb_var)], grads[len(emb_var):]
            if args.mixed_precision:
                other_grads = emb_opt.get_unscaled_gradients(other_grads)
                emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

            if "plugin" in args.optimizer:
                emb_train_op = emb_opt.apply_gradients(zip(emb_grads, emb_var))
            else:
                # Non-plugin (TF) optimizers must run inside sok.OptimizerScope.
                with sok.OptimizerScope(emb_var):
                    emb_train_op = emb_opt.apply_gradients(
                        zip(emb_grads, emb_var))
            with tf.control_dependencies([*emb_grads]):
                # in case NCCL runs concurrently via SOK and horovod
                other_grads = strategy.reduce("sum", other_grads)
            other_train_op = dense_opt.apply_gradients(
                zip(other_grads, other_var))

            # The summed loss is only produced after both updates have run.
            with tf.control_dependencies([emb_train_op, other_train_op]):
                total_loss = strategy.reduce("sum", loss)
                total_loss = tf.identity(total_loss)
                return total_loss, embedding_vector

        return strategy.run(_step_fn, inputs, labels)

    # assumes global_batch_size is divisible by gpu_num — TODO confirm
    replica_batch_size = args.global_batch_size // args.gpu_num
    dataset = utils.tf_dataset(*random_samples,
                               batchsize=replica_batch_size,
                               to_sparse_tensor=False,
                               repeat=1,
                               args=args)
    train_iterator = dataset.make_initializable_iterator()
    iterator_init = train_iterator.initializer

    inputs, labels = train_iterator.get_next()
    # Tensors (total_loss, embedding_vector) fetched once per step below.
    graph_results = _train_step(inputs, labels, training=True)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    if "plugin" in args.optimizer:
        # Plugin optimizers expose their own initializer, which must run too.
        init_op = tf.group(init_op, emb_opt.initializer)

    # One dump op per embedding layer, chained like restore_op. Note that
    # utils.try_make_dirs is a Python call, so it runs here at graph
    # construction time (once per layer), not inside the session.
    save_op = list()
    for i, embedding_layer in enumerate(sok_dense_demo.embedding_layers):
        control_inputs = [save_op[-1]] if save_op else None
        with tf.control_dependencies(control_inputs):
            if args.save_params:
                filepath = r"./embedding_variables/"
                utils.try_make_dirs(filepath)
                op = sok_saver.dump_to_file(embedding_layer.embedding_variable,
                                            filepath)
            else:
                op = tf.constant(1.0)
        save_op.append(op)

    sok_results = list()

    config = tf.ConfigProto()
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        # Order matters: SOK init, then variable/iterator init, then loading
        # the embedding values; the graph is frozen before the training loop.
        sess.run(sok_init_op)
        sess.run([init_op, iterator_init])
        sess.run(restore_op)
        sess.graph.finalize()

        for step in range(args.iter_num):
            loss_v, emb_vector_v = sess.run([*graph_results])
            print("*" * 80)
            print(f"Step: {step}, loss: {loss_v}"
                  )  #", embedding_vector:\n{emb_vector_v}")
            sok_results.append(emb_vector_v)

        sess.run(save_op)

    name = list()
    for embedding_layer in sok_dense_demo.embedding_layers:
        name.append(embedding_layer.embedding_variable.m_var_name)

    return sok_results, name
Exemplo n.º 6
0
def test_sok_multi_dense_emb(args):
    """Train SOKDenseModel data-parallel with Horovod and collect embedding outputs.

    Each worker reads its own file ``args.file_prefix + str(args.task_id) +
    ".file"``, and every embedding table is initialised with ones through
    ``utils.get_ones_tensor``. Embedding gradients are applied by the
    embedding optimizer. Dense gradients are all-reduced with
    ``hvd.allreduce`` before the dense optimizer step.

    Args:
        args: parsed options (batch size, worker count/id, model shape,
            optimizer, mixed precision, ``stop_iter``, ...).

    Returns:
        list with the model's ``all_vectors`` output for every step that ran.

    Raises:
        AssertionError: if ``global_batch_size`` is not divisible by
            ``worker_num``.
    """
    assert (args.global_batch_size % args.worker_num == 0)
    per_worker_batch = args.global_batch_size // (args.worker_num)

    input_file = args.file_prefix + str(args.task_id) + ".file"
    dataset = utility.TFDataset(filename=input_file,
                                batchsize=per_worker_batch,
                                as_sparse_tensor=False,
                                repeat=1)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    dynamic_input = args.dynamic_input == 1

    # SOK initialize
    sok.Init(global_batch_size=args.global_batch_size)

    nnz_per_slot_list = [args.nnz_per_slot] * len(args.slot_num_list)
    model = SOKDenseModel(
        max_vocabulary_size_per_gpu=args.max_vocabulary_size_per_gpu,
        embedding_vec_size_list=args.embedding_vec_size_list,
        slot_num_list=args.slot_num_list,
        nnz_per_slot_list=nnz_per_slot_list,
        num_dense_layers=args.num_dense_layers,
        dynamic_input=dynamic_input,
        use_hashtable=args.use_hashtable)

    build_emb_opt = utils.get_embedding_optimizer(args.optimizer)
    build_dense_opt = utils.get_dense_optimizer(args.optimizer)
    emb_opt = build_emb_opt(learning_rate=0.1)
    dense_opt = build_dense_opt(learning_rate=0.1)
    if args.mixed_precision:
        emb_opt = tf.keras.mixed_precision.LossScaleOptimizer(
            emb_opt, initial_scale=1024)

    sok_saver = sok.Saver()
    for layer_idx, emb_layer in enumerate(model.embedding_layers):
        ones_values = utils.get_ones_tensor(
            max_vocab_size_per_gpu=args.max_vocabulary_size_per_gpu,
            embedding_vec_size=args.embedding_vec_size_list[layer_idx],
            num=args.worker_num)
        sok_saver.load_embedding_values(emb_layer.embedding_variable, ones_values)

    loss_fn = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def _replica_loss(labels, logits):
        # Per-example loss averaged over the *global* batch, computed in
        # float32 and cast back to the loss's original dtype.
        per_example = loss_fn(labels, logits)
        orig_dtype = per_example.dtype
        averaged = tf.nn.compute_average_loss(
            tf.cast(per_example, tf.float32),
            global_batch_size=args.global_batch_size)
        return tf.cast(averaged, orig_dtype)

    @tf.function
    def _train_step(inputs, labels, first_batch):
        with tf.GradientTape() as tape:
            logit, all_vectors = model(inputs, training=True)
            replica_loss = _replica_loss(labels, logit)
            scaled_loss = (emb_opt.get_scaled_loss(replica_loss)
                           if args.mixed_precision else replica_loss)

        emb_vars, dense_vars = sok.split_embedding_variable_from_others(
            model.trainable_variables)
        emb_grads, dense_grads = tape.gradient(scaled_loss, [emb_vars, dense_vars])
        if args.mixed_precision:
            emb_grads = emb_opt.get_unscaled_gradients(emb_grads)
            dense_grads = emb_opt.get_unscaled_gradients(dense_grads)

        # Non-plugin (TF) optimizers must run inside sok.OptimizerScope.
        if "plugin" in args.optimizer:
            emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                    experimental_aggregate_gradients=False)
        else:
            with sok.OptimizerScope(emb_vars):
                emb_opt.apply_gradients(zip(emb_grads, emb_vars),
                                        experimental_aggregate_gradients=False)

        with tf.control_dependencies(emb_grads):
            # NOTE(review): hvd.allreduce averages across ranks by default,
            # but _replica_loss already divides by the *global* batch size, so
            # averaging shrinks dense gradients (and the reported loss) by the
            # number of ranks; a sum would give the global-batch gradient.
            # Confirm what the reference model expects before changing it.
            dense_grads = [hvd.allreduce(g) for g in dense_grads]
            dense_opt.apply_gradients(zip(dense_grads, dense_vars))

            # Broadcast after the first step, once optimizer state exists.
            if first_batch:
                hvd.broadcast_variables(dense_vars, root_rank=0)
                hvd.broadcast_variables(dense_opt.variables(), root_rank=0)

            total_loss = hvd.allreduce(replica_loss)
        return total_loss, all_vectors

    sok_results = []
    for step, (inputs, labels) in enumerate(dataset):
        if args.stop_iter >= 0 and step >= args.stop_iter:
            break

        total_loss, all_vectors = _train_step(inputs, labels, step == 0)
        print("[INFO]: Iteration: {}, loss={}".format(step, total_loss))

        sok_results.append(all_vectors)
    return sok_results