def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size):
    # Embedding lookup that keeps the high-frequency (hf) rows of the table on
    # the default device and the low-frequency (lf) remainder on CPU host memory.
    b, s = ids.shape
    ids = flow.flatten(ids)
    unique_ids, unique_ids_idx, _, _ = flow.experimental.unique_with_counts(ids)
    hf_vocab_size_constant = flow.constant(hf_vocab_size, dtype=flow.int32)
    # Partition the unique ids by the frequency threshold.
    hf_indices = flow.argwhere(flow.math.less(unique_ids, hf_vocab_size_constant))
    lf_indices = flow.argwhere(flow.math.greater_equal(unique_ids, hf_vocab_size_constant))
    hf_ids = flow.gather_nd(params=unique_ids, indices=hf_indices)
    lf_ids = flow.gather_nd(params=unique_ids, indices=lf_indices)
    hf_embedding_table = flow.get_variable(
        name=f'hf_{name}',
        shape=(hf_vocab_size, embedding_size),
        dtype=flow.float,
        initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
    )
    hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)
    # Low-frequency ids index into the CPU-resident table, offset by hf_vocab_size.
    lf_ids = lf_ids - hf_vocab_size_constant
    with flow.scope.placement('cpu', '0:0'):
        lf_embedding_table = flow.get_variable(
            name=f'lf_{name}',
            shape=(vocab_size - hf_vocab_size, embedding_size),
            dtype=flow.float,
            initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
        )
        lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)
    # Build a zero buffer of shape (num_unique, embedding_size), scatter both
    # partial lookups into it, then expand back to the original flattened order.
    unique_embedding = flow.reshape(
        flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)
    ) * flow.constant(0.0, dtype=flow.float, shape=(1, embedding_size))
    unique_embedding = flow.tensor_scatter_nd_update(
        params=unique_embedding, updates=hf_embedding, indices=hf_indices)
    unique_embedding = flow.tensor_scatter_nd_update(
        params=unique_embedding, updates=lf_embedding, indices=lf_indices)
    unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx)
    unique_embedding = flow.cast_to_static_shape(unique_embedding)
    unique_embedding = flow.reshape(unique_embedding, shape=(b, s * embedding_size))
    return unique_embedding
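# A minimal NumPy sketch of the split-and-merge logic above, ignoring placement.
# `hybrid_lookup_reference` is a hypothetical helper, not part of the original
# code; `hf_table`/`lf_table` stand in for the two `flow.get_variable` tables.
import numpy as np

def hybrid_lookup_reference(ids, hf_table, lf_table):
    hf_vocab_size = hf_table.shape[0]
    flat = ids.reshape(-1)
    unique_ids, inverse = np.unique(flat, return_inverse=True)
    out = np.empty((unique_ids.size, hf_table.shape[1]), dtype=hf_table.dtype)
    hf_mask = unique_ids < hf_vocab_size
    out[hf_mask] = hf_table[unique_ids[hf_mask]]                     # high-frequency rows
    out[~hf_mask] = lf_table[unique_ids[~hf_mask] - hf_vocab_size]   # offset into lf table
    return out[inverse].reshape(ids.shape[0], -1)                    # restore per-row layout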
def _model(dense_fields, wide_sparse_fields, deep_sparse_fields):
    # Wide part: a (vocab, 1) embedding table split along axis 0 across devices;
    # the per-field scalar embeddings are summed into a single score.
    wide_sparse_fields = flow.parallel_cast(
        wide_sparse_fields, distribute=flow.distribute.broadcast())
    wide_embedding_table = flow.get_variable(
        name='wide_embedding',
        shape=(FLAGS.wide_vocab_size, 1),
        initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
        distribute=flow.distribute.split(0),
    )
    wide_embedding = flow.gather(params=wide_embedding_table, indices=wide_sparse_fields)
    wide_embedding = flow.reshape(
        wide_embedding, shape=(-1, wide_embedding.shape[-1] * wide_embedding.shape[-2]))
    wide_scores = flow.math.reduce_sum(wide_embedding, axis=[1], keepdims=True)
    wide_scores = flow.parallel_cast(
        wide_scores,
        distribute=flow.distribute.split(0),
        gradient_distribute=flow.distribute.broadcast())

    # Deep part: the embedding table is split along axis 1 (the vector dim),
    # so each device holds a slice of every embedding vector.
    deep_sparse_fields = flow.parallel_cast(
        deep_sparse_fields, distribute=flow.distribute.broadcast())
    deep_embedding_table = flow.get_variable(
        name='deep_embedding',
        shape=(FLAGS.deep_vocab_size, FLAGS.deep_embedding_vec_size),
        initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
        distribute=flow.distribute.split(1),
    )
    deep_embedding = flow.gather(params=deep_embedding_table, indices=deep_sparse_fields)
    deep_embedding = flow.parallel_cast(
        deep_embedding,
        distribute=flow.distribute.split(0),
        gradient_distribute=flow.distribute.split(2))
    deep_embedding = flow.reshape(
        deep_embedding, shape=(-1, deep_embedding.shape[-1] * deep_embedding.shape[-2]))
    deep_features = flow.concat([deep_embedding, dense_fields], axis=1)
    for idx, units in enumerate(DEEP_HIDDEN_UNITS):
        deep_features = flow.layers.dense(
            deep_features,
            units=units,
            kernel_initializer=flow.glorot_uniform_initializer(),
            bias_initializer=flow.constant_initializer(0.0),
            activation=flow.math.relu,
            name='fc' + str(idx + 1))
        deep_features = flow.nn.dropout(deep_features, rate=FLAGS.deep_dropout_rate)
    deep_scores = flow.layers.dense(
        deep_features,
        units=1,
        kernel_initializer=flow.glorot_uniform_initializer(),
        bias_initializer=flow.constant_initializer(0.0),
        name='fc' + str(len(DEEP_HIDDEN_UNITS) + 1))

    scores = wide_scores + deep_scores
    return scores
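# A NumPy sketch of what the wide branch computes per sample, assuming a toy
# `wide_table` of shape (vocab, 1): the wide score is just the sum of the
# scalar weights selected by that sample's sparse feature ids.
import numpy as np

wide_table = np.random.uniform(-0.05, 0.05, size=(100, 1)).astype(np.float32)
sparse_ids = np.array([[3, 17, 42], [5, 5, 99]])   # (batch, num_fields)
wide_scores = wide_table[sparse_ids].sum(axis=1)   # (batch, 1)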
def _GatherIndexes(sequence_blob, positions_blob, seq_length, hidden_size):
    # Select the hidden states at the given positions (e.g. masked LM positions),
    # gathering per batch element via batch_dims, then flatten to 2-D.
    output = flow.gather(params=sequence_blob, indices=positions_blob, axis=2, batch_dims=2)
    output = flow.reshape(output, [-1, hidden_size])
    return output
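# A NumPy sketch of the batched gather above: with axis=2 and batch_dims=2,
# the first two dimensions are treated as batch dimensions, and for every
# (i, j) pair the positions index into the third dimension. Shapes are toy values.
import numpy as np

seq = np.random.randn(2, 2, 6, 4)           # (b0, b1, seq_length, hidden)
pos = np.array([[[0, 5], [1, 2]],
                [[3, 3], [4, 0]]])           # (b0, b1, num_positions)
out = np.empty(pos.shape + (seq.shape[-1],))
for i in range(pos.shape[0]):
    for j in range(pos.shape[1]):
        out[i, j] = seq[i, j, pos[i, j]]     # gather along the sequence axis
out = out.reshape(-1, seq.shape[-1])         # matches flow.reshape(..., [-1, hidden_size])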
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 4)),
    indices: flow.typing.Numpy.Placeholder(shape=(12,), dtype=flow.int32),
) -> flow.typing.Numpy:
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
        v = flow.get_variable(
            name="v",
            shape=(1024, 4),
            nd_sbp=["S(0)", "B"],
            initializer=flow.zeros_initializer(),
        )
        x = x + v
        indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
        x = flow.gather(x, indices)
        x = flow.hierarchical_parallel_cast(
            x,
            nd_sbp=["B", "S(0)"],
            grad_mode="manual",
            grad_nd_sbp=["B", "S(0)"],
        )
        x = flow.math.relu(x)
        x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    flow.optimizer.SGD(
        flow.optimizer.PiecewiseConstantScheduler([], [0.001]), momentum=0
    ).minimize(x)
    return x
def margin_loss(loss_m1, loss_m2, loss_m3, s, inputs, labels):
    # Combined margin logits (InsightFace style): for the ground-truth class the
    # logit cos(theta) is replaced by cos(loss_m1 * theta + loss_m2) - loss_m3,
    # with everything scaled by s.
    inputs = inputs * s
    class_num = inputs.shape[1]
    if loss_m1 != 1.0 or loss_m2 != 0.0 or loss_m3 != 0.0:
        if loss_m1 == 1.0 and loss_m2 == 0.0:
            # CosFace-style additive margin: subtract s * m3 at the target class.
            s_m = s * loss_m3
            gt_one_hot = flow.one_hot(
                labels, depth=class_num, on_value=s_m, off_value=0.0, dtype=flow.float
            )
            inputs = inputs - gt_one_hot
        else:
            # ArcFace-style angular margin: recover theta for the target logit,
            # apply the margins in angle space, and write the new logit back.
            labels_expand = flow.reshape(labels, (labels.shape[0], 1))
            zy = flow.gather(inputs, labels_expand, batch_dims=1)
            cos_t = zy * (1 / s)
            t = flow.math.acos(cos_t)
            if loss_m1 != 1.0:
                t = t * loss_m1
            if loss_m2 > 0.0:
                t = t + loss_m2
            body = flow.math.cos(t)
            if loss_m3 > 0.0:
                body = body - loss_m3
            new_zy = body * s
            diff = new_zy - zy
            gt_one_hot = flow.one_hot(
                labels, depth=class_num, on_value=1.0, off_value=0.0, dtype=flow.float
            )
            body = gt_one_hot * diff
            inputs = inputs + body
    return inputs
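# A NumPy sketch of the margin arithmetic above for a single target logit,
# with illustrative ArcFace-like values (m1=1.0, m2=0.5, m3=0.0, s=64); the
# numbers are assumptions for demonstration only.
import numpy as np

m1, m2, m3, s = 1.0, 0.5, 0.0, 64.0
cos_t = 0.8                                   # original cos(theta) for the true class
theta = np.arccos(cos_t)
new_logit = (np.cos(m1 * theta + m2) - m3) * s
print(new_logit)                              # ~26.5, smaller than s * cos_t = 51.2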
def gather_model_parallel_fw_job(
    params: oft.Numpy.Placeholder(params_shape, dtype=flow.float),
    indices: oft.Numpy.Placeholder(indices_shape, dtype=flow.int32),
):
    with flow.scope.placement(device_type, "0:0-3"):
        params = params.with_distribute(flow.distribute.split(split_axis))
        indices = indices.with_distribute(flow.distribute.broadcast())
        return flow.gather(params=params, indices=indices, axis=axis)
def _embedding(name, ids, embedding_size, vocab_size, split_axis=0):
    # Model-parallel embedding lookup: the table is split along `split_axis`
    # across devices while the ids are broadcast to every device.
    ids = flow.parallel_cast(ids, distribute=flow.distribute.broadcast())
    params = flow.get_variable(
        name=name,
        shape=(vocab_size, embedding_size),
        initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
        distribute=flow.distribute.split(split_axis),
    )
    embedding = flow.gather(params=params, indices=ids)
    embedding = flow.reshape(embedding, shape=(-1, embedding.shape[-1] * embedding.shape[-2]))
    return embedding
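# A NumPy sketch of the lookup-and-flatten performed above, ignoring the
# parallel distribution: gather rows by id, then merge the field and vector
# dimensions into one feature axis. Shapes are toy values.
import numpy as np

table = np.random.uniform(-0.05, 0.05, size=(1000, 16)).astype(np.float32)
ids = np.array([[1, 7, 42], [3, 3, 999]])   # (batch, num_fields)
emb = table[ids]                            # (batch, num_fields, 16)
emb = emb.reshape(emb.shape[0], -1)         # (batch, num_fields * 16)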
def __call__(self, tokens):
    """
    tokens shape: (batch_size, seq_length)
    dp sbp: S(0)
    2d sbp: [S(0), B]
    """
    assert len(tokens.shape) == 2
    assert tokens.shape[0] == self.batch_size
    assert tokens.shape[1] == self.seq_length

    with distribute.layer_placement_scope(0):
        wpe = flow.get_variable(
            "wpe",
            shape=(self.seq_length, self.hidden_size),
            initializer=self.wpe_initializer,
            nd_sbp=distribute.get_wpe_parallel_dist(),
        )
        wte = flow.get_variable(
            "wte",
            shape=(self.vocab_size, self.hidden_size),
            initializer=self.wte_initializer,
            nd_sbp=distribute.get_wte_parallel_dist(),
        )
        # 2d sbp sig: [B, S(0)] x [S(0), B] -> [S(0), P] -> [S(0), B]
        # grad 2d sbp sig: [S(0), B](dy) x [S(0), B](index) x [B, S(0)](x)
        #                  -> [P, S(0)](dx) -> [B, S(0)](wte_grad)
        if self.use_fp16:
            h = flow.gather(flow.amp_white_identity(wte), tokens)
            wpe = flow.amp_white_identity(wpe)
        else:
            h = flow.gather(wte, tokens)
        h = distribute.forward_p2b_parallel_cast(h) + wpe
        h = flow.nn.dropout(h, rate=self.embedding_dropout_rate, name="embd_dropout")
    return h, wte
def _EmbeddingLookup(input_ids_blob,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings"):
    embedding_table = flow.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        dtype=flow.float,
        initializer=CreateInitializer(initializer_range))
    output = flow.gather(params=embedding_table, indices=input_ids_blob, axis=0)
    return output, embedding_table
def do_gather(x_blob, i_blob):
    with flow.scope.placement(device_type, "0:0"):
        x = flow.get_variable(
            "params",
            shape=params.shape,
            dtype=flow.float32,
            initializer=flow.constant_initializer(0),
        )
        x = x + x_blob
        y = flow.gather(x, i_blob, axis=axis, batch_dims=batch_dims)
        lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
        flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(y)
    flow.watch_diff(x, compare_fn)
    return y
def _EmbeddingPostprocessor(
    input_blob,
    seq_length,
    embedding_size,
    use_token_type=False,
    token_type_ids_blob=None,
    token_type_vocab_size=16,
    token_type_embedding_name="token_type_embeddings",
    use_position_embeddings=True,
    position_embedding_name="position_embeddings",
    initializer_range=0.02,
    max_position_embeddings=512,
    dropout_prob=0.1,
):
    output = input_blob
    if use_token_type:
        assert token_type_ids_blob is not None
        token_type_table = flow.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, embedding_size],
            dtype=input_blob.dtype,
            initializer=CreateInitializer(initializer_range),
        )
        token_type_embeddings = flow.gather(
            params=token_type_table, indices=token_type_ids_blob, axis=0)
        output = output + token_type_embeddings
    if use_position_embeddings:
        position_table = flow.get_variable(
            name=position_embedding_name,
            shape=[1, max_position_embeddings, embedding_size],
            dtype=input_blob.dtype,
            initializer=CreateInitializer(initializer_range),
        )
        assert seq_length <= max_position_embeddings
        if seq_length != max_position_embeddings:
            position_table = flow.slice(
                position_table, begin=[None, 0, 0], size=[None, seq_length, -1])
        output = output + position_table
    output = _LayerNorm(output, embedding_size)
    output = _Dropout(output, dropout_prob)
    return output
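# A NumPy sketch of the postprocessing above: add a token-type embedding per
# token and a (sliced) position embedding broadcast over the batch. LayerNorm
# and dropout are omitted for brevity; shapes are toy values.
import numpy as np

batch, seq_len, hidden = 2, 5, 8
output = np.random.randn(batch, seq_len, hidden)
token_type_table = np.random.randn(16, hidden)
token_type_ids = np.zeros((batch, seq_len), dtype=np.int64)
position_table = np.random.randn(1, 512, hidden)

output = output + token_type_table[token_type_ids]   # (batch, seq, hidden)
output = output + position_table[:, :seq_len, :]     # broadcast over batch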
def test_fn(
    x: flow.typing.Numpy.Placeholder((1024, 1024)),
    indices: flow.typing.Numpy.Placeholder(shape=(64,), dtype=flow.int32),
) -> flow.typing.Numpy:
    with flow.scope.placement("gpu", "0:0-3", (2, 2)):
        if src[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "S(0)"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["S(0)", "B"])
        elif src[0] == "P":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        elif src[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
            indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            if src[1] == "S(0)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "S(0)"])
            elif src[1] == "S(1)":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(1)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "P":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "S(0)"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
            elif src[1] == "B":
                x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
                indices = flow.hierarchical_parallel_cast(indices, nd_sbp=["B", "B"])
        else:
            raise NotImplementedError
        x = flow.gather(x, indices)
        x = flow.hierarchical_parallel_cast(x, nd_sbp=dst, name="gather_cast")
        if dst[0] == "S(0)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(0)", "S(0)"])
        elif dst[0] == "B":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["B", "B"])
        elif dst[0] == "S(1)":
            x = flow.hierarchical_parallel_cast(x, nd_sbp=["S(1)", "S(1)"])
        else:
            raise NotImplementedError
    x = flow.hierarchical_parallel_cast(x, nd_sbp=["B"])
    return x
def batch_gather_job(
    x=flow.FixedTensorDef(input_shape, dtype=dtype),
    indices=flow.FixedTensorDef(indices_shape, dtype=flow.int32),
):
    return flow.gather(x, indices, batch_dims=axis)
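# A NumPy sketch of gather with batch_dims=1: for each batch row i, the output
# takes x[i, indices[i, j]] rather than indexing into a shared first axis.
# Values are toy data for illustration.
import numpy as np

x = np.arange(12).reshape(3, 4)               # (batch, n)
indices = np.array([[0, 3], [1, 1], [2, 0]])  # (batch, k)
out = x[np.arange(x.shape[0])[:, None], indices]
# out[i, j] == x[i, indices[i, j]] -> [[0, 3], [5, 5], [10, 8]]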