def _call_alltoall_experimental_columnwise(self, embedding_outputs, bottom_mlp_out):
    # Each worker holds a slice of the embedding dimension for every table,
    # shape=[global_batch, num_tables, local_embedding_dim] after the concat.
    bottom_part_output = tf.concat(embedding_outputs, axis=1)

    global_batch = tf.shape(bottom_part_output)[0]
    world_size = hvd.size()
    local_batch = global_batch // world_size
    num_tables = len(self.table_sizes)

    # Lay the data out so that an even split along axis 0 sends each worker
    # the rows belonging to its local batch.
    alltoall_input = tf.transpose(bottom_part_output, perm=[0, 2, 1])
    alltoall_input = tf.reshape(alltoall_input,
                                shape=[global_batch * self.local_embedding_dim, num_tables])

    splits = [tf.shape(alltoall_input)[0] // world_size] * world_size

    alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits,
                                   ignore_name_scope=True)[0]
    alltoall_output = tf.split(alltoall_output, num_or_size_splits=world_size, axis=0)

    # Reassemble the full embedding dimension from the slices received from each worker.
    interaction_input = [tf.reshape(x, shape=[local_batch, self.local_embedding_dim, num_tables])
                         for x in alltoall_output]

    interaction_input = tf.concat(interaction_input, axis=1)  # shape=[local_batch, vector_dim, num_tables]
    interaction_input = tf.transpose(interaction_input, perm=[0, 2, 1])  # shape=[local_batch, num_tables, vector_dim]

    if self.running_bottom_mlp:
        # shape=[local_batch, num_tables + 1, vector_dim]
        interaction_input = tf.concat([bottom_mlp_out, interaction_input], axis=1)

    return interaction_input
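
# Illustrative sketch (not part of the class above): a single-process, two-"worker"
# simulation of the columnwise all-to-all layout used by
# _call_alltoall_experimental_columnwise. hvd.alltoall is replaced with a manual
# chunk exchange; all sizes, helper names, and the value encoding below are
# assumptions made for this demonstration only.
import numpy as np
import tensorflow as tf

world_size = 2
global_batch = 4
local_batch = global_batch // world_size
num_tables = 3
local_embedding_dim = 2
embedding_dim = world_size * local_embedding_dim

def columnwise_worker_output(rank):
    # Worker `rank` holds embedding columns [rank*local_embedding_dim, (rank+1)*local_embedding_dim)
    # for every table and every sample of the global batch. Values encode
    # (sample, table, global column) so the final layout can be verified.
    cols = np.arange(rank * local_embedding_dim, (rank + 1) * local_embedding_dim)
    data = np.zeros([global_batch, num_tables, local_embedding_dim])
    for s in range(global_batch):
        for t in range(num_tables):
            data[s, t, :] = s * 100 + t * 10 + cols
    return tf.constant(data)

# Send side: the same transpose + reshape as in the method above, followed by an
# even split so that chunk `dst` holds the rows of destination worker `dst`.
send_chunks = []
for rank in range(world_size):
    x = tf.transpose(columnwise_worker_output(rank), perm=[0, 2, 1])
    x = tf.reshape(x, [global_batch * local_embedding_dim, num_tables])
    send_chunks.append(tf.split(x, world_size, axis=0))

# "All-to-all": worker `rank` receives chunk `rank` from every source, in rank order.
for rank in range(world_size):
    pieces = [tf.reshape(send_chunks[src][rank],
                         [local_batch, local_embedding_dim, num_tables])
              for src in range(world_size)]
    out = tf.transpose(tf.concat(pieces, axis=1), perm=[0, 2, 1])

    # Worker `rank` now holds its local batch with full-width embedding vectors.
    assert out.shape == (local_batch, num_tables, embedding_dim)
    global_sample = rank * local_batch + np.arange(local_batch)
    expected = (global_sample[:, None, None] * 100
                + np.arange(num_tables)[None, :, None] * 10
                + np.arange(embedding_dim)[None, None, :])
    np.testing.assert_array_equal(out.numpy(), expected)
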
def _call_alltoall(self, embedding_outputs, bottom_mlp_out=None):
    num_tables = len(self.table_sizes)

    if bottom_mlp_out is not None and not self.data_parallel_bottom_mlp:
        # A model-parallel bottom MLP output travels through the all-to-all
        # together with the embedding vectors.
        bottom_part_output = tf.concat([bottom_mlp_out] + embedding_outputs, axis=1)
        num_tables += 1
    else:
        bottom_part_output = tf.concat(embedding_outputs, axis=1)

    global_batch = tf.shape(bottom_part_output)[0]
    world_size = hvd.size()
    local_batch = global_batch // world_size
    embedding_dim = self.embedding_dim

    # Flatten to one embedding vector per row; an even split along axis 0
    # sends each worker the rows belonging to its local batch.
    alltoall_input = tf.reshape(bottom_part_output,
                                shape=[global_batch * num_tables, embedding_dim])

    splits = [tf.shape(alltoall_input)[0] // world_size] * world_size

    alltoall_output = hvd.alltoall(tensor=alltoall_input, splits=splits,
                                   ignore_name_scope=True)[0]

    # Workers may own different numbers of tables, so the received buffer is
    # re-split using the per-rank feature counts.
    vectors_per_worker = [x * local_batch for x in self.rank_to_feature_count]
    alltoall_output = tf.split(alltoall_output, num_or_size_splits=vectors_per_worker, axis=0)

    interaction_input = [tf.reshape(x, shape=[local_batch, -1, embedding_dim])
                         for x in alltoall_output]

    if self.data_parallel_bottom_mlp:
        # A data-parallel bottom MLP output is already local; prepend it directly.
        interaction_input = [bottom_mlp_out] + interaction_input

    interaction_input = tf.concat(interaction_input, axis=1)  # shape=[local_batch, num_vectors, vector_dim]

    return interaction_input
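
# Illustrative sketch (not part of the class above): a single-process simulation of
# the table-parallel exchange in _call_alltoall with two unevenly loaded "workers".
# hvd.alltoall is again replaced by a manual chunk swap; the sizes, helper names, and
# value encoding are assumptions made for this demonstration only. It shows how a
# per-rank feature count (analogous to self.rank_to_feature_count) drives the
# receive-side split.
import numpy as np
import tensorflow as tf

world_size = 2
global_batch = 4
local_batch = global_batch // world_size
embedding_dim = 3
rank_to_feature_count = [3, 2]  # worker 0 owns 3 tables, worker 1 owns 2

def tablewise_worker_output(rank):
    # Worker `rank` embeds the whole global batch, but only for its own tables:
    # shape=[global_batch, num_local_tables, embedding_dim]. Values encode
    # (sample, global table index, column) so the final layout can be verified.
    first = sum(rank_to_feature_count[:rank])
    tables = np.arange(first, first + rank_to_feature_count[rank])
    data = np.zeros([global_batch, len(tables), embedding_dim])
    for s in range(global_batch):
        for i, t in enumerate(tables):
            data[s, i, :] = s * 100 + t * 10 + np.arange(embedding_dim)
    return tf.constant(data)

# Send side: flatten to one embedding vector per row and split evenly, so chunk
# `dst` carries the local batch of destination worker `dst` for this worker's tables.
send_chunks = []
for rank in range(world_size):
    x = tf.reshape(tablewise_worker_output(rank),
                   [global_batch * rank_to_feature_count[rank], embedding_dim])
    send_chunks.append(tf.split(x, world_size, axis=0))

# Receive side: worker `rank` gets unevenly sized chunks, one per source rank, and
# re-splits the concatenated buffer with per-source vector counts.
vectors_per_worker = [count * local_batch for count in rank_to_feature_count]
for rank in range(world_size):
    received = tf.concat([send_chunks[src][rank] for src in range(world_size)], axis=0)
    pieces = tf.split(received, vectors_per_worker, axis=0)
    out = tf.concat([tf.reshape(p, [local_batch, -1, embedding_dim]) for p in pieces], axis=1)

    # Worker `rank` now holds every table's vector for its local batch.
    assert out.shape == (local_batch, sum(rank_to_feature_count), embedding_dim)
    global_sample = rank * local_batch + np.arange(local_batch)
    expected = (global_sample[:, None, None] * 100
                + np.arange(sum(rank_to_feature_count))[None, :, None] * 10
                + np.arange(embedding_dim)[None, None, :])
    np.testing.assert_array_equal(out.numpy(), expected)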