コード例 #1
0
    def __init__(self, gpus, batch_size, embedding_type, vocabulary_size,
                 slot_num, embedding_vec_size, optimizer, opt_hparam,
                 update_type, atomic_update, max_feature_num, max_nnz,
                 combiner, gpu_count, initializer=False,
                 name='plugin_embedding'):
        """Initialise the HugeCTR plugin and create one embedding layer.

        Args:
            gpus: GPU ids passed to ``hugectr_tf_ops.init`` as ``visiable_gpus``;
                ``len(gpus)`` is also used as ``batch_size_eval``.
            batch_size: training batch size passed to ``hugectr_tf_ops.init``.
            embedding_type: forwarded to ``create_embedding``.
            vocabulary_size: total vocabulary size, split over ``gpu_count`` GPUs.
            slot_num: forwarded to ``create_embedding``.
            embedding_vec_size: forwarded to ``create_embedding``.
            optimizer: optimizer name, forwarded as ``optimizer_type``.
            opt_hparam: optimizer hyper-parameters, forwarded as ``opt_hparams``.
            update_type: forwarded to ``create_embedding``.
            atomic_update: forwarded to ``create_embedding``.
            max_feature_num: forwarded to ``create_embedding``.
            max_nnz: forwarded to ``create_embedding``.
            combiner: forwarded to ``create_embedding``.
            gpu_count: number of GPUs used to size the per-GPU vocabulary.
            initializer: initial table value, passed as the first positional
                argument of ``create_embedding``. NOTE(review): ``False`` is the
                value other call sites pass; confirm it is the right default.
            name: requested embedding name, forwarded as ``name_``.

        Raises:
            ValueError: if ``gpu_count`` is not positive.
        """
        super(PluginSparseModel, self).__init__()

        if gpu_count <= 0:
            raise ValueError("gpu_count must be positive, got %r" % (gpu_count,))

        # The +1 makes the per-GPU capacities sum to at least vocabulary_size.
        self.vocabulary_size_each_gpu = (vocabulary_size // gpu_count) + 1
        self.slot_num = slot_num
        self.embedding_vec_size = embedding_vec_size
        self.embedding_type = embedding_type
        self.optimizer_type = optimizer
        self.opt_hparam = opt_hparam
        self.update_type = update_type
        self.atomic_update = atomic_update
        self.max_feature_num = max_feature_num
        self.max_nnz = max_nnz
        self.combiner = combiner
        self.gpu_count = gpu_count

        # init() creates the resource manager for the embedding plugin and must
        # only be called once, so this model should only be constructed once.
        hugectr_tf_ops.init(visiable_gpus=gpus,
                            seed=123,
                            key_type='int64',
                            value_type='float',
                            batch_size=batch_size,
                            batch_size_eval=len(gpus))

        # Create one embedding layer. Its embedding_name is made unique by the
        # plugin if more than one embedding layer exists.
        self.embedding_name = hugectr_tf_ops.create_embedding(
            initializer,
            name_=name,
            embedding_type=self.embedding_type,
            optimizer_type=self.optimizer_type,
            max_vocabulary_size_per_gpu=self.vocabulary_size_each_gpu,
            opt_hparams=self.opt_hparam,
            update_type=self.update_type,
            atomic_update=self.atomic_update,
            slot_num=self.slot_num,
            max_nnz=self.max_nnz,
            max_feature_num=self.max_feature_num,
            embedding_vec_size=self.embedding_vec_size,
            combiner=self.combiner)
コード例 #2
0
    def __init__(self, gpus, batch_size, embedding_type='distributed', fprop_version='v3'):
        """Build a sparse model backed by a single ``PluginEmbedding`` layer.

        Args:
            gpus: GPU ids passed to ``hugectr_tf_ops.init``; ``len(gpus)`` is
                also used as ``batch_size_eval`` and as the layer's ``gpu_count``.
            batch_size: training batch size passed to ``hugectr_tf_ops.init``.
            embedding_type: forwarded to ``PluginEmbedding``.
            fprop_version: 'v3' binds ``call_func`` to ``self.call_v3``;
                'v4' binds it to ``self.call_v4``.

        Raises:
            ValueError: if ``fprop_version`` is not 'v3' or 'v4'. This is checked
                before the plugin is initialised, so an invalid configuration
                creates no plugin resources.
        """
        super(PluginSparseModel, self).__init__()

        # Without this check an unknown version left `call_func` unset and the
        # model only failed later, at call time.
        if fprop_version not in ('v3', 'v4'):
            raise ValueError(
                "fprop_version must be 'v3' or 'v4', got %r" % (fprop_version,))

        hugectr_tf_ops.init(visiable_gpus=gpus, seed=123, key_type='int64', value_type='float',
                            batch_size=batch_size, batch_size_eval=len(gpus))

        # Hard-coded embedding configuration for this model.
        self.embedding_layer = PluginEmbedding(vocabulary_size=1737710,
                                               slot_num=26,
                                               embedding_vec_size=32,
                                               gpu_count=len(gpus),
                                               initializer=False,
                                               name='plugin_embedding',
                                               embedding_type=embedding_type,
                                               optimizer='Adam',
                                               opt_hparam=[0.1, 0.9, 0.99, 1e-3],
                                               update_type='LazyGlobal',
                                               combiner='sum',
                                               fprop_version=fprop_version)

        # Bind the forward implementation selected by fprop_version.
        self.call_func = self.call_v3 if fprop_version == 'v3' else self.call_v4
コード例 #3
0
    def _fprop_v4_VS_tf():
        """Compare HugeCTR ``fprop_v4`` with a TensorFlow reference embedding.

        The test generates random, -1-padded keys. It runs one forward pass
        through the HugeCTR plugin and one through ``OriginalEmbedding``; both
        are initialised from the same random table. It then applies one
        backward/optimizer step to each and checks that both forward passes
        agree.

        Reads ``vocabulary_size``, ``slot_num``, ``embedding_vec_size``,
        ``batch_size``, ``max_nnz``, ``embedding_type`` and ``gpus`` from the
        enclosing scope. The plugin is always reset, even when a check fails.

        Raises:
            RuntimeError: if ``vocabulary_size < slot_num``.
            ValueError: if ``embedding_type`` is not 'distributed' or 'localized'.
            tf.errors.InvalidArgumentError: if either forward comparison fails.
        """
        print("[INFO]: Testing fprop_v4 vs tf...")
        if vocabulary_size < slot_num:
            raise RuntimeError("vocabulary_size must > slot.")
        if embedding_type not in ('distributed', 'localized'):
            # Any other value would leave `keys` unset for the slot below.
            raise ValueError("Unsupported embedding_type: %r" % (embedding_type,))

        # Initial embedding table shared by HugeCTR and TF.
        init_value = np.float32(
            np.random.normal(loc=0,
                             scale=1,
                             size=(vocabulary_size, embedding_vec_size)))

        # Input keys of shape [batch_size, slot_num, max_nnz]; -1 marks padding.
        # TODO: Keys in different slots should be unique.
        input_keys = np.ones(shape=(batch_size, slot_num, max_nnz),
                             dtype=np.int64) * -1
        each_slot = vocabulary_size // slot_num
        # Lower bound for the per-slot key count. It becomes 1 after the first
        # empty slot, so at most one slot in the whole batch has no keys.
        min_nnz = 0
        for batch_id in range(batch_size):
            for slot_id in range(slot_num):
                nnz = np.random.randint(
                    low=min_nnz, high=max_nnz + 1,
                    size=1)[0]  # how many keys in this slot
                if nnz == 0:
                    min_nnz = 1
                # Keys for this slot are drawn from the slot's own key range.
                key_low = slot_id * each_slot
                key_high = (slot_id + 1) * each_slot
                if embedding_type == 'distributed':
                    keys = np.random.randint(low=key_low, high=key_high, size=nnz)
                else:  # 'localized'
                    # TODO: key should belong to that slot.
                    # Keep only keys with key % slot_num == slot_id. NOTE: this
                    # loops forever if no key in [key_low, key_high) qualifies,
                    # which can only happen when each_slot < slot_num.
                    keys = []
                    while len(keys) < nnz:
                        # Take the scalar ([0]) so `keys` stays 1-D. Appending
                        # size-1 arrays would give shape (nnz, 1), which cannot
                        # be assigned into the (nnz,) slice below when nnz > 1.
                        key = np.random.randint(low=key_low, high=key_high,
                                                size=1)[0]
                        if key % slot_num == slot_id:
                            keys.append(key)

                input_keys[batch_id, slot_id, 0:nnz] = keys

        # hugectr ops
        hugectr_tf_ops.init(visiable_gpus=gpus,
                            key_type='int64',
                            value_type='float',
                            batch_size=batch_size,
                            batch_size_eval=len(gpus))
        try:
            with tf.GradientTape(persistent=True) as tape:
                embedding_name = hugectr_tf_ops.create_embedding(
                    init_value=init_value,
                    opt_hparams=[0.1, 0.9, 0.99, 1e-5],
                    name_='hugectr_embedding',
                    max_vocabulary_size_per_gpu=(vocabulary_size // len(gpus)) * 2
                    + 1,
                    slot_num=slot_num,
                    embedding_vec_size=embedding_vec_size,
                    max_feature_num=slot_num * max_nnz,
                    embedding_type=embedding_type,
                    max_nnz=max_nnz,
                    update_type='Global')

                # Both implementations consume the keys as a
                # [batch_size * slot_num, max_nnz] matrix with -1 padding removed.
                reshape_input_keys = np.reshape(input_keys, [-1, max_nnz])
                indices = tf.where(reshape_input_keys != -1)
                values = tf.gather_nd(reshape_input_keys, indices)
                # Row (first coordinate) of every non-padding key.
                row_indices = tf.transpose(indices, perm=[1, 0])[0]

                bp_trigger = tf.Variable(initial_value=1.0,
                                         trainable=True,
                                         dtype=tf.float32)

                # The output holds one embedding_vec_size vector per
                # (sample, slot), matching the TF reference's output_shape.
                hugectr_forward = hugectr_tf_ops.fprop_v4(
                    embedding_name=embedding_name,
                    row_indices=row_indices,
                    values=values,
                    bp_trigger=bp_trigger,
                    is_training=True,
                    output_shape=[batch_size, slot_num, embedding_vec_size])

                # tf ops: the same keys as a SparseTensor.
                sparse_tensor = tf.sparse.SparseTensor(indices, values,
                                                       reshape_input_keys.shape)

                # FIXME: if there are too more nnz=0 slots, tf.nn.embedding_lookup_sparse may get wrong results?
                tf_embedding_layer = OriginalEmbedding(
                    vocabulary_size=vocabulary_size,
                    embedding_vec_size=embedding_vec_size,
                    initializer=init_value,
                    combiner='sum',
                    gpus=gpus)

                tf_forward = tf_embedding_layer(
                    sparse_tensor,
                    output_shape=[batch_size, slot_num, embedding_vec_size])

                # Compare the first forward result. This raises
                # InvalidArgumentError on a mismatch.
                tf.debugging.assert_near(hugectr_forward, tf_forward)

                print(
                    "[INFO]: The results from HugeCTR and tf in the first forward propagation are the same."
                )

            # backward: differentiating w.r.t. bp_trigger presumably drives the
            # plugin's backward/update step; the gradient value itself is unused.
            tape.gradient(hugectr_forward, bp_trigger)

            tf_opt = tf.keras.optimizers.Adam(learning_rate=0.1,
                                              beta_1=0.9,
                                              beta_2=0.99,
                                              epsilon=1e-5)
            tf_grads = tape.gradient(tf_forward,
                                     tf_embedding_layer.trainable_weights)
            tf_opt.apply_gradients(
                zip(tf_grads, tf_embedding_layer.trainable_weights))
            del tape  # release the persistent tape's resources

            # Compare the second forward result, after one update on each side.
            hugectr_forward_2 = hugectr_tf_ops.fprop_v4(
                embedding_name=embedding_name,
                row_indices=row_indices,
                values=values,
                bp_trigger=bp_trigger,
                is_training=True,
                output_shape=[batch_size, slot_num, embedding_vec_size])

            tf_forward_2 = tf_embedding_layer(
                sparse_tensor,
                output_shape=[batch_size, slot_num, embedding_vec_size])

            tf.debugging.assert_near(hugectr_forward_2,
                                     tf_forward_2,
                                     rtol=1e-4,
                                     atol=1e-5)

            print(
                "[INFO]: The results from HugeCTR and tf in the second forward propagation are the same."
            )
        finally:
            hugectr_tf_ops.reset()
コード例 #4
0
ファイル: model.py プロジェクト: byshiue/HugeCTR
    def __init__(
            self,
            vocabulary_size,
            embedding_vec_size,
            which_embedding,
            dropout_rate,  # list of float
            deep_layers,  # list of int
            initializer,
            gpus,
            batch_size,
            batch_size_eval,
            embedding_type='localized',
            slot_num=1,
            seed=123):
        """Create the layers of the DeepFM_PluginEmbedding model.

        Args:
            vocabulary_size: vocabulary size of the plugin embedding.
            embedding_vec_size: embedding width used by ``dense_embedding``; the
                plugin table is created one column wider. NOTE(review):
                presumably the extra column is a per-key first-order weight;
                confirm.
            which_embedding: stored on the instance; not used in this method.
            dropout_rate: dropout rate per hidden layer, indexed in step with
                ``deep_layers``.
            deep_layers: units of each hidden Dense layer.
            initializer: plugin initial value; string values are replaced by
                ``False`` before being passed to ``PluginEmbedding``.
            gpus: GPU ids for ``hugectr_tf_ops.init``; ``len(gpus)`` is the
                plugin's ``gpu_count``.
            batch_size: training batch size for ``hugectr_tf_ops.init``.
            batch_size_eval: evaluation batch size for ``hugectr_tf_ops.init``.
            embedding_type: forwarded to ``PluginEmbedding``.
            slot_num: forwarded to ``PluginEmbedding``.
            seed: used for both TF's graph-level seed and the plugin's seed.
        """
        super(DeepFM_PluginEmbedding, self).__init__()
        # Reset Keras global state, then seed TF's graph-level RNG.
        tf.keras.backend.clear_session()
        tf.compat.v1.set_random_seed(seed)

        # Keep constructor arguments on the instance (order kept as-is because
        # Keras tracks attribute assignments).
        self.vocabulary_size = vocabulary_size
        self.embedding_vec_size = embedding_vec_size
        self.which_embedding = which_embedding
        self.dropout_rate = dropout_rate
        self.deep_layers = deep_layers
        self.gpus = gpus
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval
        self.slot_num = slot_num
        self.embedding_type = embedding_type

        # String initializers are not forwarded; the plugin receives False.
        plugin_init = False if isinstance(initializer, str) else initializer
        hugectr_tf_ops.init(visiable_gpus=gpus,
                            seed=seed,
                            key_type='int64',
                            value_type='float',
                            batch_size=batch_size,
                            batch_size_eval=batch_size_eval)
        self.plugin_embedding_layer = PluginEmbedding(
            vocabulary_size=vocabulary_size,
            slot_num=slot_num,
            embedding_vec_size=embedding_vec_size + 1,
            embedding_type=embedding_type,
            gpu_count=len(gpus),
            initializer=plugin_init)

        # Deep tower: (Dense -> Dropout) per hidden layer, then a 1-unit Dense.
        common = {
            'activation': None,
            'use_bias': True,
            'kernel_initializer': 'glorot_normal',
        }
        tower = []
        for layer_no, width in enumerate(self.deep_layers):
            tower.append(
                tf.keras.layers.Dense(units=width,
                                      bias_initializer='glorot_normal',
                                      **common))
            tower.append(tf.keras.layers.Dropout(dropout_rate[layer_no]))
        tower.append(
            tf.keras.layers.Dense(
                units=1,
                bias_initializer=tf.constant_initializer(0.01),
                **common))
        self.deep_dense = tower

        self.add_layer = tf.keras.layers.Add()
        self.y_act = tf.keras.layers.Activation(activation='sigmoid')

        self.dense_multi = Multiply(1)
        self.dense_embedding = Multiply(self.embedding_vec_size)

        self.concat_1 = tf.keras.layers.Concatenate()
        self.concat_2 = tf.keras.layers.Concatenate()
コード例 #5
0
def create_dataset(dataset_names,
                   feature_desc,
                   batch_size,
                   n_epochs=-1,
                   distribute_keys=False,
                   gpu_count=1,
                   embedding_type='localized',
                   use_which_device='cpu'):
    """Build a batched ``tf.data`` pipeline that parses TFRecord examples.

    Args:
        dataset_names: list of TFRecord file names.
        feature_desc: feature description of one sample, passed to
            ``tf.io.parse_example``. It must contain 'label',
            'I1'..'I<utils.NUM_INTEGER_COLUMNS>' and
            'C1'..'C<utils.NUM_CATEGORICAL_COLUMNS>'.
        batch_size: examples per batch; the final partial batch is dropped.
        n_epochs: passed to ``Dataset.repeat``; -1 repeats indefinitely.
        distribute_keys: if True, categorical keys are converted for the
            embedding plugin instead of being returned as a SparseTensor.
        gpu_count: number of GPUs; GPU ids ``0..gpu_count-1`` are made visible
            when ``use_which_device == 'gpu'``.
        embedding_type: forwarded to the plugin.
        use_which_device: 'cpu' or 'gpu'. This selects where keys are
            distributed. 'gpu' also initialises the plugin and stores the
            created embedding's name in the module-level ``embedding_name``.

    Returns:
        A ``tf.data.Dataset`` of 6-tuples whose layout is described in
        ``_parse_fn``.

    Raises:
        ValueError: if ``use_which_device`` is not 'cpu' or 'gpu'.
    """
    # Any other value used to skip the plugin initialisation below while still
    # being treated as 'gpu' when distributing keys in `_parse_fn`.
    if use_which_device not in ('cpu', 'gpu'):
        raise ValueError("use_which_device must be 'cpu' or 'gpu', got %r" %
                         (use_which_device,))

    if use_which_device == 'gpu':
        global embedding_name
        hugectr_tf_ops.init(visiable_gpus=list(range(gpu_count)),
                            key_type='int64',
                            value_type='float',
                            batch_size=batch_size,
                            batch_size_eval=gpu_count)
        embedding_name = hugectr_tf_ops.create_embedding(
            init_value=False,
            embedding_type=embedding_type,
            opt_hparams=[1.0] * 4,
            slot_num=26,
            max_nnz=1,
            max_feature_num=26 * 1,
            name_='hugectr_embedding')

    dataset = tf.data.TFRecordDataset(filenames=dataset_names,
                                      compression_type=None,
                                      buffer_size=100 * 1024 * 1024,
                                      num_parallel_reads=1)
    # Batch the serialized records first so parse_example runs once per batch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(n_epochs)

    # data preprocessing
    @tf.function
    def _parse_fn(serialized, feature_desc, distribute_keys=False):
        """Parse one batch into ``(label, dense, a, b, c, sparse)``.

        ``dense`` is [batch_size, NUM_INTEGER_COLUMNS] int64 and the categorical
        keys start as [batch_size, NUM_CATEGORICAL_COLUMNS, 1]. The last four
        entries depend on the path taken:
          * no distribution: three int64 placeholders and a SparseTensor of
            the keys with -1 entries removed;
          * 'cpu': stacked row offsets, stacked value tensors and nnz_array
            from ``hugectr_tf_ops.distribute_keys``, plus a placeholder
            SparseTensor;
          * 'gpu': row indices, values, a 0 nnz_array and a placeholder
            SparseTensor.
        """
        with tf.name_scope("datareader_map"):
            features = tf.io.parse_example(serialized, feature_desc)
            # split into label + dense + cate
            label = features['label']
            dense = tf.TensorArray(dtype=tf.int64,
                                   size=utils.NUM_INTEGER_COLUMNS,
                                   dynamic_size=False,
                                   element_shape=(batch_size, ))
            cate = tf.TensorArray(dtype=tf.int64,
                                  size=utils.NUM_CATEGORICAL_COLUMNS,
                                  dynamic_size=False,
                                  element_shape=(batch_size, 1))

            for idx in range(utils.NUM_INTEGER_COLUMNS):
                dense = dense.write(idx, features["I" + str(idx + 1)])

            for idx in range(utils.NUM_CATEGORICAL_COLUMNS):
                cate = cate.write(idx, features["C" + str(idx + 1)])

            dense = tf.transpose(dense.stack(),
                                 perm=[1, 0])  # [batchsize, dense_dim]
            cate = tf.transpose(cate.stack(),
                                perm=[1, 0, 2])  # [batchsize, slot_num, nnz]

            # distribute cate-keys to each GPU
            if distribute_keys:

                def _distribute_keys_cpu(cate):
                    # --------------------- convert cate to CSR on CPU ---------------------------- #
                    indices = tf.where(cate != -1)
                    values = tf.gather_nd(cate, indices)
                    row_offsets, value_tensors, nnz_array = hugectr_tf_ops.distribute_keys(
                        indices,
                        values,
                        cate.shape,
                        gpu_count=gpu_count,
                        embedding_type=embedding_type,
                        max_nnz=1)

                    place_holder = tf.sparse.SparseTensor(
                        [[0, 0]], tf.constant([0], dtype=tf.int64),
                        [batch_size * utils.NUM_CATEGORICAL_COLUMNS, 1])
                    return label, dense, tf.stack(row_offsets), tf.stack(
                        value_tensors), nnz_array, place_holder

                def _distribute_keys_gpu(cate):
                    # ---------------------- convert cate to CSR On GPU ------------------------ #
                    cate = tf.reshape(cate, [-1, 1])
                    indices = tf.where(cate != -1)
                    row_indices = tf.transpose(indices, perm=[1, 0])[0]
                    values = tf.gather_nd(cate, indices)

                    nnz_array = tf.constant(0, dtype=tf.int64)

                    place_holder = tf.sparse.SparseTensor(
                        [[0, 0]], tf.constant([0], dtype=tf.int64),
                        [batch_size * utils.NUM_CATEGORICAL_COLUMNS, 1])
                    return label, dense, row_indices, values, nnz_array, place_holder

                if "cpu" == use_which_device:
                    return _distribute_keys_cpu(cate)
                else:
                    return _distribute_keys_gpu(cate)

            else:
                reshape_cate = tf.reshape(cate, [-1, cate.shape[-1]])
                indices = tf.where(reshape_cate != -1)
                values = tf.gather_nd(reshape_cate, indices)
                sparse_tensor = tf.sparse.SparseTensor(
                    indices=indices,
                    values=values,
                    dense_shape=reshape_cate.shape)
                place_holder = tf.constant(1, dtype=tf.int64)
                return label, dense, place_holder, place_holder, place_holder, sparse_tensor

    # `distribute_keys` is passed as a tf.bool tensor, so autograph lowers
    # `if distribute_keys:` in `_parse_fn` to a tf.cond. Both branches therefore
    # return the same 6-tuple structure.
    dataset = dataset.map(
        lambda serialized: _parse_fn(
            serialized, feature_desc,
            tf.convert_to_tensor(distribute_keys, dtype=tf.bool)),
        num_parallel_calls=1,  # tf.data.experimental.AUTOTUNE
        deterministic=False)
    # dataset = dataset.prefetch(buffer_size=16) # tf.data.experimental.AUTOTUNE

    return dataset
コード例 #6
0
def tf_distribute_keys_fprop_v3(embedding_type):
    """Smoke-test ``_distribute_kyes`` followed by ``hugectr_tf_ops.fprop_v3``.

    Steps:
      * Initialise the plugin on GPUs [0, 1, 3, 4].
      * Create an embedding from an 8x4 table holding 1..32.
      * Distribute a fixed batch of 4 samples x 3 slots (max_nnz=2, -1 =
        padding) twice, printing each result.
      * Print the forward output in training mode.
      * Trigger one backward pass.
      * Print the forward output in inference mode.

    Nothing is asserted.
    """
    with tf.GradientTape() as tape:
        with tf.device("/gpu:0"):
            vocabulary_size, slot_num, embedding_vec_size = 8, 3, 4

            # Table rows hold consecutive values 1..32 so outputs are readable.
            table_len = vocabulary_size * embedding_vec_size
            init_value = np.float32(list(range(1, table_len + 1))).reshape(
                vocabulary_size, embedding_vec_size)

            hugectr_tf_ops.init(visiable_gpus=[0, 1, 3, 4],
                                seed=123,
                                key_type='int64',
                                value_type='float',
                                batch_size=4,
                                batch_size_eval=4)
            embedding_name = hugectr_tf_ops.create_embedding(
                init_value=init_value,
                opt_hparams=[1.0, 0.9, 0.99, 1e-3],
                name_='test_embedding',
                max_vocabulary_size_per_gpu=1737710,
                slot_num=slot_num,
                embedding_vec_size=embedding_vec_size,
                max_feature_num=4,
                embedding_type=embedding_type,
                max_nnz=2)

            # [batch=4, slot=3, max_nnz=2]; -1 marks an empty position.
            keys = np.array(
                [[[0, -1], [1, -1], [2, 6]], [[0, -1], [1, -1], [-1, -1]],
                 [[0, -1], [1, -1], [6, -1]], [[0, -1], [1, -1], [2, -1]]],
                dtype=np.int64)

            # Distribute the same keys twice; the second header is preceded by
            # a newline.
            for header in ("row_ptrs", "\nrow_ptrs"):
                row_offsets, value_tensors, nnz_array = _distribute_kyes(
                    tf.convert_to_tensor(keys),
                    gpu_count=4,
                    embedding_type=embedding_type)
                print(header, row_offsets)
                print("\nvalues", value_tensors)
                print("\n", nnz_array)

            bp_trigger = tf.Variable(
                initial_value=[1.0, 2.0],
                trainable=True,
                dtype=tf.float32,
                name='embedding_plugin_bprop_trigger')  # must be trainable

            def run_fprop(is_training):
                # Forward pass over the last distributed keys.
                return hugectr_tf_ops.fprop_v3(
                    embedding_name=embedding_name,
                    row_offsets=row_offsets,
                    nnz_array=nnz_array,
                    value_tensors=value_tensors,
                    is_training=is_training,
                    bp_trigger=bp_trigger,
                    output_shape=[4, slot_num, embedding_vec_size])

            training_out = run_fprop(True)
            print("first step: \n", training_out)

            # Differentiating w.r.t. bp_trigger presumably drives the plugin's
            # backward step; the gradient value itself is not used.
            tape.gradient(training_out, bp_trigger)

            print("second step: \n", run_fprop(False))
コード例 #7
0
def test_forward_distribute_keys_v4(embedding_type):
    """Smoke-test ``hugectr_tf_ops.distribute_keys_v4`` followed by ``fprop_v2``.

    Steps:
      * Initialise the plugin on GPUs [0, 1, 3, 4].
      * Create an embedding from an 8x4 table holding 1..32.
      * Distribute a fixed batch of 4 samples x 3 slots (max_nnz=2, -1 =
        padding) and print the distribution result.
      * Print the forward output in training mode.
      * Trigger one backward pass.
      * Print the forward output in inference mode.

    Nothing is asserted.

    NOTE(review): despite the function name, the forward pass uses
    ``fprop_v2``; confirm this pairing is intended.
    """
    with tf.GradientTape() as tape:
        with tf.device("/gpu:0"):

            vocabulary_size = 8
            slot_num = 3
            embedding_vec_size = 4

            # Table rows hold consecutive values 1..32 so outputs are readable.
            init_value = np.float32([
                i for i in range(1, vocabulary_size * embedding_vec_size + 1)
            ]).reshape(vocabulary_size, embedding_vec_size)

            hugectr_tf_ops.init(visiable_gpus=[0, 1, 3, 4],
                                seed=123,
                                key_type='int64',
                                value_type='float',
                                batch_size=4,
                                batch_size_eval=4)
            embedding_name = hugectr_tf_ops.create_embedding(
                init_value=init_value,
                opt_hparams=[1.0, 0.9, 0.99, 1e-3],
                name_='test_embedding',
                max_vocabulary_size_per_gpu=1737710,
                slot_num=slot_num,
                embedding_vec_size=embedding_vec_size,
                max_feature_num=4,
                embedding_type=embedding_type,
                max_nnz=2)

            # [batch=4, slot=3, max_nnz=2]; -1 marks an empty position.
            keys = np.array(
                [[[0, -1], [1, -1], [2, 6]], [[0, -1], [1, -1], [-1, -1]],
                 [[0, -1], [1, -1], [6, -1]], [[0, -1], [1, -1], [2, -1]]],
                dtype=np.int64)

            row_offsets, value_tensors, nnz_array = hugectr_tf_ops.distribute_keys_v4(
                all_keys=keys,
                gpu_count=4,
                embedding_type=embedding_type,
                max_nnz=2,
                batch_size=4,
                slot_num=3)
            print("row_offsets = ", row_offsets, "\n")
            print("value_tensors = ", value_tensors, "\n")
            print("nnz_array = ", nnz_array, "\n")

            bp_trigger = tf.Variable(
                initial_value=[1.0, 2.0],
                trainable=True,
                dtype=tf.float32,
                name='embedding_plugin_bprop_trigger')  # must be trainable

            forward_result = hugectr_tf_ops.fprop_v2(
                embedding_name=embedding_name,
                row_offsets=row_offsets,
                nnz_array=nnz_array,
                value_tensors=value_tensors,
                is_training=True,
                bp_trigger=bp_trigger,
                output_shape=[4, slot_num, embedding_vec_size])
            print("first step: \n", forward_result)

            # Differentiating w.r.t. bp_trigger presumably drives the plugin's
            # backward step; the gradient value itself is not used.
            tape.gradient(forward_result, bp_trigger)

            forward_result = hugectr_tf_ops.fprop_v2(
                embedding_name=embedding_name,
                row_offsets=row_offsets,
                nnz_array=nnz_array,
                value_tensors=value_tensors,
                is_training=False,
                bp_trigger=bp_trigger,
                output_shape=[4, slot_num, embedding_vec_size])
            print("second step: \n", forward_result)
コード例 #8
0
def test(embedding_type):
    """Smoke-test ``hugectr_tf_ops.fprop`` next to ``tf.nn.embedding_lookup_sparse``.

    Steps:
      * Create an 8x4 table holding 1..32.
      * Run the plugin forward pass in training mode and print it.
      * Trigger one backward pass.
      * Run the plugin forward pass in inference mode and print it.
      * Print a TF sum-combined lookup of the same keys on the initial table.

    Nothing is asserted.
    """
    with tf.GradientTape() as tape:
        with tf.device("/gpu:0"):
            vocabulary_size, slot_num, embedding_vec_size = 8, 3, 4

            # Table rows hold consecutive values 1..32 so outputs are readable.
            init_value = np.arange(
                1, vocabulary_size * embedding_vec_size + 1,
                dtype=np.float32).reshape(vocabulary_size, embedding_vec_size)

            # NOTE(review): key_type is 'uint32' while the keys below are int64;
            # confirm the plugin is expected to accept this combination.
            hugectr_tf_ops.init(visiable_gpus=[0, 1, 3, 4],
                                seed=123,
                                key_type='uint32',
                                value_type='float',
                                batch_size=4,
                                batch_size_eval=4)
            embedding_name = hugectr_tf_ops.create_embedding(
                init_value=init_value,
                opt_hparams=[0.1, 0.9, 0.99, 1e-3],
                name_='test_embedding',
                max_vocabulary_size_per_gpu=5,
                slot_num=slot_num,
                embedding_vec_size=embedding_vec_size,
                max_feature_num=4,
                embedding_type=embedding_type,
                max_nnz=2)

            # [batch=4, slot=3, 4 positions]; -1 marks an empty position.
            keys = np.array(
                [[[0, -1, -1, -1], [1, -1, -1, -1], [2, 6, -1, -1]],
                 [[0, -1, -1, -1], [1, -1, -1, -1], [-1, -1, -1, -1]],
                 [[0, -1, -1, -1], [1, -1, -1, -1], [6, -1, -1, -1]],
                 [[0, -1, -1, -1], [1, -1, -1, -1], [2, -1, -1, -1]]],
                dtype=np.int64)

            sparse_indices = tf.where(keys != -1)  # [N, ndims]
            key_values = tf.gather_nd(keys, sparse_indices)  # [N]

            bp_trigger = tf.Variable(
                initial_value=[1.0, 2.0],
                trainable=True,
                dtype=tf.float32,
                name='embedding_plugin_bprop_trigger')  # must be trainable

            def plugin_fprop(is_training):
                # Plugin forward pass over the non-padding keys.
                return hugectr_tf_ops.fprop(
                    embedding_name=embedding_name,
                    sparse_indices=sparse_indices,
                    values=key_values,
                    dense_shape=keys.shape,
                    output_type=tf.float32,
                    is_training=is_training,
                    bp_trigger=bp_trigger)

            training_out = plugin_fprop(True)
            print("first step: \n", training_out)

            # Differentiating w.r.t. bp_trigger presumably drives the plugin's
            # backward step; the gradient value itself is not used.
            tape.gradient(training_out, bp_trigger)

            print("second step: \n", plugin_fprop(False))

            # TF reference: the same keys flattened to [batch*slot, positions]
            # and looked up in the initial table with a sum combiner.
            flat_keys = keys.reshape(-1, keys.shape[-1])
            flat_indices = tf.where(flat_keys != -1)
            flat_values = tf.gather_nd(flat_keys, flat_indices)
            sparse_tensor = tf.sparse.SparseTensor(flat_indices, flat_values,
                                                   flat_keys.shape)

            tf_forward = tf.nn.embedding_lookup_sparse(init_value,
                                                       sparse_tensor,
                                                       sp_weights=None,
                                                       combiner="sum")
            print("tf: \n", tf_forward)