Example #1
    def _call_alltoall_experimental_columnwise(self, embedding_outputs, bottom_mlp_out):
        # Each embedding output carries this worker's slice of the embedding
        # dimension; concatenating along axis=1 gives
        # [global_batch, num_tables, local_embedding_dim].
        bottom_part_output = tf.concat(embedding_outputs, axis=1)

        global_batch = tf.shape(bottom_part_output)[0]
        world_size = hvd.size()
        local_batch = global_batch // world_size
        num_tables = len(self.table_sizes)

        # Move the local embedding-dimension slice to axis 1 and flatten so that
        # every row is one (sample, dim) pair:
        # [global_batch * local_embedding_dim, num_tables].
        alltoall_input = tf.transpose(bottom_part_output, perm=[0, 2, 1])
        alltoall_input = tf.reshape(alltoall_input, shape=[global_batch * self.local_embedding_dim,
                                                           num_tables])

        # Even split: every destination rank receives the same number of rows
        # (one batch shard each).
        splits = [tf.shape(alltoall_input)[0] // world_size] * world_size

        alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits, ignore_name_scope=True)[0]

        # Each received chunk holds one source worker's embedding-dimension
        # slice for this worker's local batch.
        alltoall_output = tf.split(alltoall_output,
                                   num_or_size_splits=world_size,
                                   axis=0)
        interaction_input = [tf.reshape(x, shape=[local_batch,
                                                  self.local_embedding_dim, num_tables]) for x in alltoall_output]

        interaction_input = tf.concat(interaction_input, axis=1)  # shape=[local_batch, vector_dim, num_tables]
        interaction_input = tf.transpose(interaction_input,
                                         perm=[0, 2, 1])  # shape=[local_batch, num_tables, vector_dim]

        if self.running_bottom_mlp:
            interaction_input = tf.concat([bottom_mlp_out,
                                           interaction_input],
                                          axis=1)  # shape=[local_batch, num_tables + 1, vector_dim]
        return interaction_input
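
Below is a minimal single-process sketch (plain TensorFlow, no Horovod) that mimics the even all-to-all in Example #1 with ordinary tensor ops, to make the columnwise layout explicit. All sizes here (world_size, global_batch, num_tables, local_embedding_dim) are illustrative assumptions, not values taken from the original model.

import tensorflow as tf

world_size = 4            # simulated number of workers
global_batch = 8
local_batch = global_batch // world_size
num_tables = 3
local_embedding_dim = 2   # per-worker slice of the embedding dimension

# Each simulated worker holds its dimension slice for the full global batch:
# [global_batch, num_tables, local_embedding_dim].
per_worker = [tf.random.normal([global_batch, num_tables, local_embedding_dim])
              for _ in range(world_size)]

received = []
for rank in range(world_size):
    chunks = []
    for src in per_worker:
        # Same transpose + reshape as in the method above.
        x = tf.transpose(src, perm=[0, 2, 1])
        x = tf.reshape(x, [global_batch * local_embedding_dim, num_tables])
        # An even split sends row block `rank` of every source to this rank.
        chunks.append(tf.split(x, world_size, axis=0)[rank])
    out = [tf.reshape(c, [local_batch, local_embedding_dim, num_tables]) for c in chunks]
    out = tf.concat(out, axis=1)                  # [local_batch, full_dim, num_tables]
    received.append(tf.transpose(out, perm=[0, 2, 1]))

print(received[0].shape)  # (local_batch, num_tables, world_size * local_embedding_dim)

Each entry of received now has the shape the interaction layer expects on one worker: the local batch, all tables, and the full embedding dimension reassembled from every worker's slice.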
Example #2
    def _call_alltoall(self, embedding_outputs, bottom_mlp_out=None):
        num_tables = len(self.table_sizes)
        if bottom_mlp_out is not None and not self.data_parallel_bottom_mlp:
            # Model-parallel bottom MLP: its output goes through the all-to-all
            # together with the embedding vectors.
            bottom_part_output = tf.concat([bottom_mlp_out] + embedding_outputs,
                                           axis=1)
            num_tables += 1
        else:
            bottom_part_output = tf.concat(embedding_outputs, axis=1)

        global_batch = tf.shape(bottom_part_output)[0]
        world_size = hvd.size()
        local_batch = global_batch // world_size
        embedding_dim = self.embedding_dim

        # Flatten to one row per (sample, table) pair, then split evenly across
        # destination ranks (one batch shard each).
        alltoall_input = tf.reshape(bottom_part_output,
                                    shape=[global_batch * num_tables,
                                           embedding_dim])

        splits = [tf.shape(alltoall_input)[0] // world_size] * world_size

        alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits, ignore_name_scope=True)[0]

        # The receive side is uneven: each source rank contributes local_batch
        # vectors per table it owns.
        vectors_per_worker = [x * local_batch for x in self.rank_to_feature_count]
        alltoall_output = tf.split(alltoall_output,
                                   num_or_size_splits=vectors_per_worker,
                                   axis=0)
        interaction_input = [tf.reshape(x, shape=[local_batch, -1, embedding_dim]) for x in alltoall_output]

        if self.data_parallel_bottom_mlp:
            # Data-parallel bottom MLP: its output is already local to this
            # worker, so it is simply prepended to the received vectors.
            interaction_input = [bottom_mlp_out] + interaction_input

        interaction_input = tf.concat(interaction_input, axis=1)  # shape=[local_batch, num_vectors, vector_dim]
        return interaction_input
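
For Example #2 the receive side is the interesting part: the split sizes are uneven because each source rank owns a different number of tables. A small standalone sketch of that reshaping follows; rank_to_feature_count, local_batch and embedding_dim are chosen purely for illustration.

import tensorflow as tf

local_batch = 4
embedding_dim = 8
rank_to_feature_count = [7, 6, 6, 7]   # assumed table counts per worker

# Rows received from each source rank after the all-to-all.
vectors_per_worker = [count * local_batch for count in rank_to_feature_count]

# Stand-in for the flat all-to-all output: one row per (sample, table) pair.
flat = tf.random.normal([sum(vectors_per_worker), embedding_dim])

chunks = tf.split(flat, num_or_size_splits=vectors_per_worker, axis=0)
per_source = [tf.reshape(c, [local_batch, -1, embedding_dim]) for c in chunks]
interaction_input = tf.concat(per_source, axis=1)
print(interaction_input.shape)  # (local_batch, sum(rank_to_feature_count), embedding_dim)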