def rnn(x, y_):
    '''
    RNN model, for MNIST dataset.

    Builds a plain (vanilla) RNN over `nsteps` time steps, where each step
    consumes one 28-pixel row of the flattened MNIST image, then classifies
    from the final hidden state.

    Parameters:
        x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
        y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    Return:
        loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
        y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    '''
    print("Building RNN model...")
    diminput = 28    # features consumed per step (one image row)
    dimhidden = 128  # hidden-state width
    dimoutput = 10   # number of classes
    nsteps = 28      # unrolled time steps (one per image row)
    # Input-to-hidden projection.
    weight1 = init.random_normal(shape=(diminput, dimhidden), stddev=0.1, name='rnn_weight1')
    bias1 = init.random_normal(shape=(dimhidden, ), stddev=0.1, name='rnn_bias1')
    # Recurrent projection; input is concat([h, last_state]), hence 2*dimhidden rows.
    weight2 = init.random_normal(shape=(dimhidden + dimhidden, dimhidden), stddev=0.1, name='rnn_weight2')
    bias2 = init.random_normal(shape=(dimhidden, ), stddev=0.1, name='rnn_bias2')
    # Hidden-to-output projection.
    weight3 = init.random_normal(shape=(dimhidden, dimoutput), stddev=0.1, name='rnn_weight3')
    bias3 = init.random_normal(shape=(dimoutput, ), stddev=0.1, name='rnn_bias3')
    # Zero scalar that is broadcast to the batch shape on the first step;
    # not trainable since it is only the initial hidden state.
    last_state = ad.Variable(value=np.zeros((1, )).astype(np.float32),
                             name='initial_state', trainable=False)
    for i in range(nsteps):
        # Slice out step i's input columns: x[:, i*diminput : (i+1)*diminput].
        cur_x = ad.slice_op(x, (0, i * diminput), (-1, diminput))
        h = ad.matmul_op(cur_x, weight1)
        h = h + ad.broadcastto_op(bias1, h)
        if i == 0:
            # Expand the scalar initial state to h's (batch, hidden) shape.
            last_state = ad.broadcastto_op(last_state, h)
        # New hidden state: relu(W2 @ [h ; last_state] + b2).
        s = ad.concat_op(h, last_state, axis=1)
        s = ad.matmul_op(s, weight2)
        s = s + ad.broadcastto_op(bias2, s)
        last_state = ad.relu_op(s)
    final_state = last_state
    # Classification head on the final hidden state.
    x = ad.matmul_op(final_state, weight3)
    y = x + ad.broadcastto_op(bias3, x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    loss = ad.reduce_mean_op(loss, [0])
    return loss, y
def wdl_adult(X_deep, X_wide, y_):
    """Wide & Deep model for the Adult (census income) dataset.

    Parameters:
        X_deep: list of input nodes; the first 8 are categorical feature ids
            fed through embeddings, the last 4 are continuous features.
        X_wide: input node with the wide (cross/raw) features,
            width dim_wide.
        y_: one-hot label node with 2 classes.
    Return:
        (loss, prediction, y_, train_op)
    """
    lr = 5 / 128  # effective learning rate (5 scaled by batch size 128)
    dim_wide = 809  # width of the wide feature vector
    dim_deep = 68   # width of the deep input: 8 embeddings * 8 + 4 continuous = 68
    # Final joint classifier over concat([wide features, deep output(20)]).
    W = init.random_normal([dim_wide+20, 2], stddev=0.1, name="W")
    W1 = init.random_normal([dim_deep, 50], stddev=0.1, name="W1")
    b1 = init.random_normal([50], stddev=0.1, name="b1")
    W2 = init.random_normal([50, 20], stddev=0.1, name="W2")
    b2 = init.random_normal([20], stddev=0.1, name="b2")
    # deep part: embed each of the 8 categorical features to 8 dims and concat
    Embedding = []
    X_deep_input = None
    for i in range(8):
        Embedding_name = "Embedding_deep_" + str(i)
        # vocab size 50 per categorical field — TODO confirm against data
        Embedding.append(init.random_normal([50, 8], stddev=0.1, name=Embedding_name))
        now = ad.embedding_lookup_op(Embedding[i], X_deep[i])
        now = ad.array_reshape_op(now, (-1, 8))
        if X_deep_input is None:
            X_deep_input = now
        else:
            X_deep_input = ad.concat_op(X_deep_input, now, 1)
    # append the 4 continuous features as single columns
    for i in range(4):
        now = ad.array_reshape_op(X_deep[i + 8], (-1, 1))
        X_deep_input = ad.concat_op(X_deep_input, now, 1)
    # two-layer MLP (dropout currently disabled, see commented calls)
    mat1 = ad.matmul_op(X_deep_input, W1)
    add1 = mat1 + ad.broadcastto_op(b1, mat1)
    relu1 = ad.relu_op(add1)
    dropout1 = relu1  # ad.dropout_op(relu1, 0.5)
    mat2 = ad.matmul_op(dropout1, W2)
    add2 = mat2 + ad.broadcastto_op(b2, mat2)
    relu2 = ad.relu_op(add2)
    dropout2 = relu2  # ad.dropout_op(relu2, 0.5)
    dmodel = dropout2
    # wide part: joint linear layer over [wide features ; deep output]
    wmodel = ad.concat_op(X_wide, dmodel, 1)
    wmodel = ad.matmul_op(wmodel, W)
    prediction = wmodel
    loss = ad.softmaxcrossentropy_op(prediction, y_)
    loss = ad.reduce_mean_op(loss, [0])
    opt = optimizer.SGDOptimizer(learning_rate=lr)
    train_op = opt.minimize(loss)
    return loss, prediction, y_, train_op
def cross_layer(x0, x1):
    """One cross layer of a DCN-style network.

    Computes y = x0 * (x1 @ w) + x1 + b, where the (batch, 1) projection
    x1 @ w is broadcast back over x0's feature dimension.

    Parameters:
        x0: input embedding feature, (batch_size, 26 * embedding_size + 13)
        x1: output of the previous layer, same shape as x0
    Return:
        node of the same shape as the inputs
    """
    feat_dim = 26 * 128 + 13
    w = init.random_normal(shape=(feat_dim, 1), stddev=0.01, name='weight')
    b = init.random_normal(shape=(feat_dim, ), stddev=0.01, name='bias')
    # project x1 down to a single scalar per example: (batch_size, 1)
    proj = ad.matmul_op(x1, w)
    crossed = ad.mul_op(x0, ad.broadcastto_op(proj, x0))
    return crossed + x1 + ad.broadcastto_op(b, crossed)
def __call__(self, input, subgraph_size: list, use_sparse: list):
    """
    Build the computation graph, return the output node.

    Pipeline: linear transform -> split node features by subgraph ->
    in-graph message passing (per-partition adjacency) ->
    inter-graph message passing (cross-partition adjacency) -> concat.

    Parameters:
        input: node features, rows laid out partition-by-partition.
        subgraph_size: number of nodes in each of the self.npart partitions;
            used to slice the rows belonging to each subgraph.
        use_sparse: per-partition flag — use sparse csrmm for the diagonal
            (intra-partition) adjacency when True, dense matmul otherwise.
    Return:
        concatenation of all partitions' output node features.
    """
    x = ad.matmul_op(input, self.weight)
    msg = x + ad.broadcastto_op(self.bias, x)
    output_nodes = []
    msgs = []
    split_at = 0
    # message passing for each subgraph
    for i in range(self.npart):
        # rows [split_at, split_at + subgraph_size[i]) belong to partition i
        sliced_msg = ad.slice_op(node=msg, begin=(split_at, 0),
                                 size=(subgraph_size[i], self.out_features))
        split_at += subgraph_size[i]
        msgs.append(sliced_msg)
        if use_sparse[i]:
            output = ad.csrmm_op(self.mp[i][i], sliced_msg)
        else:
            output = ad.matmul_op(self.mp[i][i], sliced_msg)
        output_nodes.append(output)
    # message passing between subgraphs: mp[i][j] routes partition i's
    # messages to partition j (always sparse here)
    for i in range(self.npart):
        for j in range(self.npart):
            if i == j:
                continue
            output_nodes[j] = output_nodes[j] + ad.csrmm_op(
                self.mp[i][j], msgs[i])
    # concat all the remaining nodes back into one feature matrix
    result = output_nodes[0]
    for i in range(1, self.npart):
        result = ad.concat_op(result, output_nodes[i])
    return result
def fc(x, shape):
    """Fully-connected layer: flatten x to (-1, shape[0]) then apply an
    affine transform with randomly-initialized weight and bias.

    Parameters:
        x: input node (any leading shape; flattened to shape[0] columns).
        shape: (in_features, out_features) of the weight matrix.
    Return:
        node of shape (-1, shape[-1]).
    """
    w = init.random_normal(shape=shape, stddev=0.1)
    b = init.random_normal(shape=shape[-1:], stddev=0.1)
    flat = ad.array_reshape_op(x, (-1, shape[0]))
    linear = ad.matmul_op(flat, w)
    return linear + ad.broadcastto_op(b, linear)
def fc(x, shape, name):
    """Named fully-connected (affine) layer.

    Parameters:
        x: 2-D input node with shape[0] columns.
        shape: (in_features, out_features) of the weight matrix.
        name: prefix for the parameter names ('<name>_weight', '<name>_bias').
    Return:
        node of shape (-1, shape[-1]).
    """
    w = init.random_normal(shape=shape, stddev=0.1, name=name + '_weight')
    b = init.random_normal(shape=shape[-1:], stddev=0.1, name=name + '_bias')
    projected = ad.matmul_op(x, w)
    return projected + ad.broadcastto_op(b, projected)
def train_hetu(args):
    """Distributed (parameter-server) GraphSage training loop.

    One worker process per rank: builds a 2-layer GraphSage + linear
    classifier graph, then trains on mini-batches drawn from a
    DistributedGraphSageSampler until `args.num_epoch` epochs complete.

    Parameters:
        args: namespace with .path (dataset dir containing meta.yml),
              .hidden_size and .num_epoch.
    """
    # dataset metadata: feature dim, class count, per-rank partition sizes
    with open(os.path.join(args.path, "meta.yml"), 'rb') as f:
        meta = yaml.load(f.read(), Loader=yaml.FullLoader)
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    # worker identity comes from the PS launcher's environment
    rank = int(os.environ["WORKER_ID"])
    nrank = int(os.environ["DMLC_NUM_WORKER"])
    ctx = ndarray.gpu(rank)  # one GPU per worker, indexed by rank
    x_ = ad.Variable(name="x_")
    y_ = ad.Variable(name="y_")
    mask_ = ad.Variable(name="mask_")
    gcn1 = GraphSage(meta["feature"], hidden_layer_size, activation="relu", dropout=0.1)
    # GraphSage output width is 2*hidden (self + neighbor concat) — TODO confirm
    gcn2 = GraphSage(2*hidden_layer_size, hidden_layer_size, activation="relu", dropout=0.1)
    x = gcn1(x_)
    x = gcn2(x)
    W = initializers.xavier_uniform(shape=(2*hidden_layer_size, meta["class"]))
    B = initializers.zeros(shape=(meta["class"],))
    x = ad.matmul_op(x, W)
    y = x + ad.broadcastto_op(B, x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    # mask excludes padded/out-of-batch nodes from the loss
    loss = ad.mul_op(loss, mask_)
    loss = ad.reduce_mean_op(loss, [0])
    opt = optimizer.SGDOptimizer(0.1)
    train_op = opt.minimize(loss)
    executor = ad.Executor([loss, y, train_op], ctx=ctx, comm_mode='PS')
    distributed.ps_init(rank, nrank)
    batch_size = 4000
    with DistributedGraphSageSampler(args.path, batch_size, 2, 2,
                                     rank=rank, nrank=nrank) as sampler:
        epoch = 0
        nnodes = 0
        start = time.time()
        while True:
            g_sample, mask = sampler.sample()
            mp_val = mp_matrix(g_sample, ndarray.gpu(rank))
            feed_dict = {
                gcn1.mp : mp_val,
                gcn2.mp : mp_val,
                mask_ : ndarray.array(mask, ctx=ctx),
                x_ : ndarray.array(g_sample.x, ctx=ctx),
                y_ : ndarray.array(convert_to_one_hot(g_sample.y, max_val=g_sample.num_classes), ctx=ctx)
            }
            loss_val, y_predicted, _ = executor.run(feed_dict = feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            # count of correct predictions among masked (real) nodes
            acc = ((y_predicted == g_sample.y) * mask).sum()
            # keep workers in lock-step each iteration
            distributed.ps_get_worker_communicator().BarrierWorker()
            nnodes += batch_size
            # one "epoch" = having sampled as many nodes as this rank's partition
            if nnodes > meta["partition"]["nodes"][rank]:
                nnodes = 0
                epoch += 1
                print("Epoch :", epoch, time.time() - start)
                print("Train accuracy:", acc/mask.sum())
                start = time.time()
                if epoch >= num_epoch:
                    break
def vgg_fc(x, in_feat, out_feat, name):
    """Fully-connected layer used by the VGG head.

    Parameters:
        x: 2-D input node, (batch, in_feat).
        in_feat: input feature count.
        out_feat: output feature count.
        name: prefix for parameter names ('<name>_weight', '<name>_bias').
    Return:
        node of shape (batch, out_feat).
    """
    w = init.random_normal(shape=(in_feat, out_feat), stddev=0.1,
                           name=name + '_weight')
    b = init.random_normal(shape=(out_feat, ), stddev=0.1,
                           name=name + '_bias')
    h = ad.matmul_op(x, w)
    return h + ad.broadcastto_op(b, h)
def dense(input_tensor, fan_in, fan_out, activation=None, kernel_initializer=init.xavier_normal, bias_initializer=init.zeros):
    """Dense (affine) layer with optional activation.

    Parameters:
        input_tensor: 2-D input node, (batch, fan_in).
        fan_in: input feature count.
        fan_out: output feature count.
        activation: optional callable applied to the affine output.
        kernel_initializer / bias_initializer: parameter initializers.
    Return:
        node of shape (batch, fan_out).
    """
    w = kernel_initializer(name='dense_weights', shape=(fan_in, fan_out))
    b = bias_initializer(name='dense_bias', shape=(fan_out, ))
    out = ad.matmul_op(input_tensor, w)
    out = out + ad.broadcastto_op(b, out)
    return out if activation is None else activation(out)
def residual_layer(x0, input_dim, hidden_dim):
    """Two-layer MLP block with a skip connection.

    Computes relu(W2 @ relu(W1 @ x0 + b1) + b2 + x0), i.e. a bottleneck
    through hidden_dim and back to input_dim, added to the input.

    Parameters:
        x0: input node, (batch, input_dim).
        input_dim: feature width of x0 (and of the output).
        hidden_dim: width of the intermediate layer.
    Return:
        node of shape (batch, input_dim).
    """
    # NOTE: the original assigned `embedding_len = input_dim` but never used
    # it; the dead local has been removed.
    weight_1 = init.random_normal(shape=(input_dim, hidden_dim), stddev=0.1, name='weight_1')
    bias_1 = init.random_normal(shape=(hidden_dim, ), stddev=0.1, name='bias_1')
    weight_2 = init.random_normal(shape=(hidden_dim, input_dim), stddev=0.1, name='weight_2')
    bias_2 = init.random_normal(shape=(input_dim, ), stddev=0.1, name='bias_2')
    x0w = ad.matmul_op(x0, weight_1)  # (batch, hidden_dim)
    x0w_b = x0w + ad.broadcastto_op(bias_1, x0w)
    relu1 = ad.relu_op(x0w_b)
    x1w = ad.matmul_op(relu1, weight_2)  # (batch, input_dim)
    x1w_b = x1w + ad.broadcastto_op(bias_2, x1w)
    residual = x1w_b + x0  # skip connection
    y = ad.relu_op(residual)
    return y
def __call__(self, x):
    """
    Build the computation graph, return the output node.

    Pipeline: optional dropout -> affine transform -> sparse
    message-passing (csrmm with self.mp) -> optional activation.
    Only "relu" (or None) is supported as an activation.
    """
    h = x
    if self.dropout > 0:
        # keep probability is 1 - dropout rate
        h = ad.dropout_op(h, 1 - self.dropout)
    h = ad.matmul_op(h, self.weight)
    affine = h + ad.broadcastto_op(self.bias, h)
    out = ad.CuSparse.csrmm_op(self.mp, affine)
    if self.activation is None:
        return out
    if self.activation == "relu":
        return ad.relu_op(out)
    raise NotImplementedError
def logreg(x, y_):
    '''
    Logistic Regression model, for MNIST dataset.

    A single zero-initialized affine layer followed by softmax
    cross-entropy, averaged over the batch.

    Parameters:
        x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
        y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    Return:
        loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
        y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    '''
    print("Build logistic regression model...")
    w = init.zeros((784, 10), name='logreg_weight')
    b = init.zeros((10, ), name='logreg_bias')
    scores = ad.matmul_op(x, w)
    y = scores + ad.broadcastto_op(b, scores)
    loss = ad.reduce_mean_op(ad.softmaxcrossentropy_op(y, y_), [0])
    return loss, y
def test_broadcast(shape1=(3, 1), shape2=(2, 3, 4)):
    """Check hetu's broadcastto_op forward and gradient against TensorFlow.

    Broadcasts a random tensor of `shape1` to `shape2` with both frameworks
    and asserts the outputs and input gradients match.
    Requires GPU 1 and a TF1-style session (tf.Session / tf.gradients).
    """
    ctx = ndarray.gpu(1)
    x = np.random.random(shape1).astype(np.float32)
    y = np.random.random(shape2).astype(np.float32)
    ath_x = ad.Variable(name='x', value=x)
    ath_z = ad.Variable(name='y', value=y)  # only supplies the target shape
    ath_y = ad.broadcastto_op(ath_x, ath_z)
    ath_grad = ad.gradients(ath_y, [ath_x])[0]
    executor = ad.Executor([ath_y, ath_grad], ctx=ctx, enable_lazy=False)
    ath_results = [var.asnumpy() for var in executor.run()]
    # reference computation with TensorFlow (local import: test-only dep)
    import tensorflow as tf
    tf_x = tf.convert_to_tensor(x)
    tf_y = tf.broadcast_to(tf_x, shape2)
    tf_grad = tf.gradients(tf_y, tf_x)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf_results = sess.run([tf_y, tf_grad])
    np.testing.assert_allclose(ath_results[0], tf_results[0])
    # gradient shapes may differ (e.g. an extra leading axis), so compare
    # after reshaping TF's gradient to hetu's gradient shape
    np.testing.assert_allclose(ath_results[1],
                               np.reshape(tf_results[1], ath_results[1].shape))
    print('Passed broadcast shape op test with shape ', shape1, shape2)
def train_hetu(num_epoch):
    """Single-GPU GraphSage training loop using GraphSageSampler mini-batches.

    Builds a 2-layer GraphSage + linear classifier, trains with Adam on
    masked softmax cross-entropy, and evaluates on the full graph after
    each epoch. Uses module-level globals: graph, graph_full,
    hidden_layer_size, train_split, mp_matrix, convert_to_one_hot.

    Parameters:
        num_epoch: number of epochs to train before stopping.
    """
    ctx = ndarray.gpu(0)
    x_ = ad.Variable(name="x_")
    y_ = ad.Variable(name="y_")
    mask_ = ad.Variable(name="mask_")
    gcn1 = GraphSage(graph.num_features, hidden_layer_size, activation="relu", dropout=0.1)
    # GraphSage output width is 2*hidden (self + neighbor concat) — TODO confirm
    gcn2 = GraphSage(2*hidden_layer_size, hidden_layer_size, activation="relu", dropout=0.1)
    x = gcn1(x_)
    x = gcn2(x)
    W = initializers.xavier_uniform(shape=(2*hidden_layer_size, graph.num_classes))
    B = initializers.zeros(shape=(graph.num_classes,))
    x = ad.matmul_op(x, W)
    y = x + ad.broadcastto_op(B, x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    # mask excludes padded/out-of-batch nodes from the loss;
    # NOTE(review): no reduce_mean here, unlike the sibling train_hetu —
    # presumably the optimizer handles the unreduced loss; confirm.
    loss = ad.mul_op(loss, mask_)
    opt = optimizer.AdamOptimizer(0.01)
    train_op = opt.minimize(loss)
    executor = ad.Executor([loss, y, train_op], ctx=ctx)
    def eval():
        # Full-graph evaluation with dropout switched to inference mode.
        start = time.time()
        ad.Dropout.DropoutOp.phase = "eval"
        mp_val = mp_matrix(graph_full, ctx)
        feed_dict = {
            gcn1.mp : mp_val,
            gcn2.mp : mp_val,
            x_ : ndarray.array(graph_full.x, ctx=ctx),
        }
        executor_eval = ad.Executor([y], ctx=ctx)
        y_predicted, = executor_eval.run(feed_dict=feed_dict)
        y_predicted = y_predicted.asnumpy().argmax(axis=1)
        # accuracy on the held-out nodes (after train_split)
        acc = (y_predicted == graph_full.y)[train_split:].sum()
        print("Test accuracy:", acc/len(y_predicted[train_split:]))
        ad.Dropout.DropoutOp.phase = "training"
    epoch = 0
    nnodes = 0
    batch_size = 1000
    with GraphSageSampler(graph, batch_size, depth=2, num_sample_thread=4) as sampler:
        start = time.time()
        while True:
            g_sample, mask = sampler.sample()
            mp_val = mp_matrix(g_sample, ctx)
            #print(time.time() - start)
            feed_dict = {
                gcn1.mp : mp_val,
                gcn2.mp : mp_val,
                mask_ : ndarray.array(mask,ctx=ctx),
                x_ : ndarray.array(g_sample.x, ctx=ctx),
                y_ : ndarray.array(convert_to_one_hot(g_sample.y, max_val=graph.num_classes), ctx=ctx)
            }
            loss_val, y_predicted, _ = executor.run(feed_dict = feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            acc = ((y_predicted == g_sample.y) * mask).sum()
            # print(i, "Train loss :", loss_val.asnumpy().mean())
            # print(i, "Train accuracy:", acc/len(y_predicted))
            nnodes += batch_size
            # one "epoch" = having sampled as many nodes as the full graph has
            if nnodes > graph_full.num_nodes:
                nnodes = 0
                epoch += 1
                print("Epoch :", epoch, time.time() - start)
                print("Train accuracy:", acc/mask.sum())
                eval()
                start = time.time()
                if epoch >= num_epoch:
                    break
def train_hetu(num_epoch):
    """Single-GPU GraphSage training loop fed by a RandomWalkSampler.

    Builds a 2-layer GraphSage + linear classifier, trains with Adam on
    softmax cross-entropy (no mask in this variant), printing loss/accuracy
    per iteration and evaluating on the full graph every 100 iterations.
    Uses module-level globals: graph, graph_full, hidden_layer_size,
    train_split, transform, mp_matrix, convert_to_one_hot.

    Parameters:
        num_epoch: number of sampler iterations to run.
    """
    ctx = ndarray.gpu(0)
    x_ = ad.Variable(name="x_")
    y_ = ad.Variable(name="y_")
    gcn1 = GraphSage(graph.num_features, hidden_layer_size, activation="relu", dropout=0.1)
    # GraphSage output width is 2*hidden (self + neighbor concat) — TODO confirm
    gcn2 = GraphSage(2 * hidden_layer_size, hidden_layer_size, activation="relu", dropout=0.1)
    x = gcn1(x_)
    x = gcn2(x)
    W = initializers.xavier_uniform(shape=(2 * hidden_layer_size, graph.num_classes))
    B = initializers.zeros(shape=(graph.num_classes, ))
    x = ad.matmul_op(x, W)
    y = x + ad.broadcastto_op(B, x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    opt = optimizer.AdamOptimizer(0.01)
    train_op = opt.minimize(loss)
    executor = ad.Executor([loss, y, train_op], ctx=ctx)
    def eval():
        # Full-graph evaluation with dropout switched to inference mode.
        start = time.time()
        ad.Dropout.DropoutOp.phase = "eval"
        mp_val = mp_matrix(graph_full, ctx)
        feed_dict = {
            gcn1.mp: mp_val,
            gcn2.mp: mp_val,
            x_: ndarray.array(graph_full.x, ctx=ctx),
        }
        executor_eval = ad.Executor([y], ctx=ctx)
        y_predicted, = executor_eval.run(feed_dict=feed_dict)
        y_predicted = y_predicted.asnumpy().argmax(axis=1)
        # accuracy on the held-out nodes (after train_split)
        acc = (y_predicted == graph_full.y)[train_split:].sum()
        print("Test accuracy:", acc / len(y_predicted[train_split:]))
        ad.Dropout.DropoutOp.phase = "training"
    with RandomWalkSampler(graph, 4000, 2, transformer=transform, num_sample_thread=3) as sampler:
        for i in range(num_epoch):
            start = time.time()
            # this sampler already returns the message-passing matrix
            g_sample, mp_val = sampler.sample()
            #mp_val = mp_matrix(g_sample, ctx)
            #print(time.time() - start)
            feed_dict = {
                gcn1.mp: mp_val,
                gcn2.mp: mp_val,
                x_: ndarray.array(g_sample.x, ctx=ctx),
                y_: ndarray.array(convert_to_one_hot(g_sample.y, max_val=graph.num_classes), ctx=ctx)
            }
            loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            acc = (y_predicted == g_sample.y).sum()
            print(i, "Train loss :", loss_val.asnumpy().mean())
            print(i, "Train accuracy:", acc / len(y_predicted))
            if (i + 1) % 100 == 0:
                eval()
            print(time.time() - start)
def lstm(x, y_):
    '''
    LSTM model, for MNIST dataset.

    Manually unrolled LSTM over `nsteps` time steps, each consuming one
    28-pixel row of the flattened image, classifying from the final
    hidden state.

    Parameters:
        x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
        y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    Return:
        loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
        y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    '''
    print("Building LSTM model...")
    diminput = 28    # features consumed per step (one image row)
    dimhidden = 128  # hidden/cell state width
    dimoutput = 10   # number of classes
    nsteps = 28      # unrolled time steps
    # Per-gate parameters: *_w acts on the step input, *_u on the previous
    # hidden state, *_b is the bias.
    forget_gate_w = init.random_normal(shape=(diminput, dimhidden), stddev=0.1, name="lstm_forget_gate_w")
    forget_gate_u = init.random_normal(shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_forget_gate_u")
    forget_gate_b = init.random_normal(shape=(dimhidden, ), stddev=0.1, name="lstm_forget_gate_b")
    input_gate_w = init.random_normal(shape=(diminput, dimhidden), stddev=0.1, name="lstm_input_gate_w")
    input_gate_u = init.random_normal(shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_input_gate_u")
    input_gate_b = init.random_normal(shape=(dimhidden, ), stddev=0.1, name="lstm_input_gate_b")
    output_gate_w = init.random_normal(shape=(diminput, dimhidden), stddev=0.1, name="lstm_output_gate_w")
    output_gate_u = init.random_normal(shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_output_gate_u")
    output_gate_b = init.random_normal(shape=(dimhidden, ), stddev=0.1, name="lstm_output_gate_b")
    tanh_w = init.random_normal(shape=(diminput, dimhidden), stddev=0.1, name="lstm_tanh_w")
    tanh_u = init.random_normal(shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_tanh_u")
    tanh_b = init.random_normal(shape=(dimhidden, ), stddev=0.1, name="lstm_tanh_b")
    # Classification head.
    out_weights = init.random_normal(shape=(dimhidden, dimoutput), stddev=0.1, name="lstm_out_weight")
    out_bias = init.random_normal(shape=(dimoutput, ), stddev=0.1, name="lstm_out_bias")
    # Zero scalar broadcast to (batch, hidden) on the first step.
    initial_state = ad.Variable(value=np.zeros((1, )).astype(np.float32),
                                name='initial_state', trainable=False)
    for i in range(nsteps):
        # Slice out step i's input columns: x[:, i*diminput : (i+1)*diminput].
        cur_x = ad.slice_op(x, (0, i * diminput), (-1, diminput))
        # forget gate
        if i == 0:
            # First step: materialize zero cell/hidden states at the batch
            # shape inferred from this matmul's output.
            temp = ad.matmul_op(cur_x, forget_gate_w)
            last_c_state = ad.broadcastto_op(initial_state, temp)
            last_h_state = ad.broadcastto_op(initial_state, temp)
            cur_forget = ad.matmul_op(last_h_state, forget_gate_u) + temp
        else:
            cur_forget = ad.matmul_op(last_h_state, forget_gate_u) + ad.matmul_op(
                cur_x, forget_gate_w)
        cur_forget = cur_forget + ad.broadcastto_op(forget_gate_b, cur_forget)
        cur_forget = ad.sigmoid_op(cur_forget)
        # input gate
        cur_input = ad.matmul_op(last_h_state, input_gate_u) + ad.matmul_op(
            cur_x, input_gate_w)
        cur_input = cur_input + ad.broadcastto_op(input_gate_b, cur_input)
        cur_input = ad.sigmoid_op(cur_input)
        # output gate
        cur_output = ad.matmul_op(last_h_state, output_gate_u) + ad.matmul_op(
            cur_x, output_gate_w)
        cur_output = cur_output + ad.broadcastto_op(output_gate_b, cur_output)
        cur_output = ad.sigmoid_op(cur_output)
        # candidate cell state (tanh branch)
        cur_tanh = ad.matmul_op(last_h_state, tanh_u) + ad.matmul_op(
            cur_x, tanh_w)
        cur_tanh = cur_tanh + ad.broadcastto_op(tanh_b, cur_tanh)
        cur_tanh = ad.tanh_op(cur_tanh)
        # c_t = f * c_{t-1} + i * c_tilde ;  h_t = tanh(c_t) * o
        last_c_state = ad.mul_op(last_c_state, cur_forget) + ad.mul_op(
            cur_input, cur_tanh)
        last_h_state = ad.tanh_op(last_c_state) * cur_output
    # Classify from the final hidden state.
    x = ad.matmul_op(last_h_state, out_weights)
    y = x + ad.broadcastto_op(out_bias, x)
    loss = ad.softmaxcrossentropy_op(y, y_)
    loss = ad.reduce_mean_op(loss, [0])
    return loss, y
def wdl_adult(whatever):
    """Wide & Deep model for the Adult dataset with built-in dataloaders.

    Loads the Adult data, wraps features/labels in dataloader ops (train +
    validate splits), and builds the same wide&deep graph as the node-input
    variant of wdl_adult.

    Parameters:
        whatever: unused (kept for interface compatibility with callers).
    Return:
        (loss, prediction, y_, train_op)
    """
    batch_size = 128
    lr = 5
    dim_wide = 809  # width of the wide feature vector
    lr_ = lr / batch_size  # scale learning rate by batch size
    dim_deep = 68   # deep input: 8 embeddings * 8 + 4 continuous = 68
    from .load_data import load_adult_data
    x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = load_adult_data()
    # Final joint classifier over concat([wide features, deep output(20)]).
    W = init.random_normal([dim_wide+20, 2], stddev=0.1, name="W")
    W1 = init.random_normal([dim_deep, 50], stddev=0.1, name="W1")
    b1 = init.random_normal([50], stddev=0.1, name="b1")
    W2 = init.random_normal([50, 20], stddev=0.1, name="W2")
    b2 = init.random_normal([20], stddev=0.1, name="b2")
    X_wide = dl.dataloader_op([
        [x_train_wide, batch_size, 'train'],
        [x_test_wide, batch_size, 'validate'],
    ])
    y_ = dl.dataloader_op([
        [y_train, batch_size, 'train'],
        [y_test, batch_size, 'validate'],
    ])
    # deep part: embed each of the 8 categorical features to 8 dims and concat
    Embedding = []
    X_deep = []
    X_deep_input = None
    for i in range(8):
        X_deep_name = "x_deep_" + str(i)
        Embedding_name = "Embedding_deep_" + str(i)
        X_deep.append(dl.dataloader_op([
            [x_train_deep[:,i], batch_size, 'train'],
            [x_test_deep[:,i], batch_size, 'validate'],
        ]))
        # vocab size 50 per categorical field — TODO confirm against data
        Embedding.append(init.random_normal([50, 8], stddev=0.1, name=Embedding_name))
        now = ad.embedding_lookup_op(Embedding[i], X_deep[i])
        now = ad.array_reshape_op(now, (-1, 8))
        if X_deep_input is None:
            X_deep_input = now
        else:
            X_deep_input = ad.concat_op(X_deep_input, now, 1)
    # append the 4 continuous features as single columns
    for i in range(4):
        X_deep_name = "x_deep_" + str(8+i)
        X_deep.append(dl.dataloader_op([
            [x_train_deep[:,8+i], batch_size, 'train'],
            [x_test_deep[:,8+i], batch_size, 'validate'],
        ]))
        # FIX: reshape with (-1, 1) instead of (batch_size, 1), matching the
        # other wdl_adult variant and tolerating a partial final batch.
        now = ad.array_reshape_op(X_deep[i + 8], (-1, 1))
        X_deep_input = ad.concat_op(X_deep_input, now, 1)
    # two-layer MLP (dropout currently disabled, see commented calls)
    mat1 = ad.matmul_op(X_deep_input, W1)
    add1 = mat1 + ad.broadcastto_op(b1, mat1)
    relu1 = ad.relu_op(add1)
    dropout1 = relu1  # ad.dropout_op(relu1, 0.5)
    mat2 = ad.matmul_op(dropout1, W2)
    add2 = mat2 + ad.broadcastto_op(b2, mat2)
    relu2 = ad.relu_op(add2)
    dropout2 = relu2  # ad.dropout_op(relu2, 0.5)
    dmodel = dropout2
    # wide part: joint linear layer over [wide features ; deep output]
    wmodel = ad.concat_op(X_wide, dmodel, 1)
    wmodel = ad.matmul_op(wmodel, W)
    prediction = wmodel
    loss = ad.softmaxcrossentropy_op(prediction, y_)
    loss = ad.reduce_mean_op(loss, [0])
    opt = optimizer.SGDOptimizer(learning_rate=lr_)
    train_op = opt.minimize(loss)
    return loss, prediction, y_, train_op
def multihead_attention(queries, keys, values, config, query_act=None, key_act=None, value_act=None, attention_mask=None, causality=False):
    """Multi-head scaled dot-product attention with residual + layer norm.

    Parameters:
        queries: (N, T_q, d_model) node.
        keys, values: (N, T_k, d_model) nodes.
        config: object providing batch_size, d_model, num_heads, maxlen2,
            dropout_rate.
        query_act/key_act/value_act: optional activations for the Q/K/V
            linear projections.
        attention_mask: optional (N, T) padding mask; nonzero = keep —
            TODO confirm mask polarity against where_op semantics.
        causality: if True, apply a lower-triangular future mask of length
            config.maxlen2 - 1.
    Return:
        (N, T_q, d_model) node.
    """
    def transpose_for_scores(input_tensor):
        # (N*T, d_model) -> (N, h, T, d_model/h)
        output_tensor = ad.array_reshape_op(input_tensor, [
            config.batch_size, -1, config.num_heads,
            config.d_model // config.num_heads
        ])
        output_tensor = ad.transpose_op(output_tensor, [0, 2, 1, 3])
        return output_tensor
    batch_size = config.batch_size
    hidden_size = config.d_model
    num_attention_heads = config.num_heads
    caus_len = config.maxlen2 - 1
    attention_probs_dropout_prob = config.dropout_rate
    size_per_head = hidden_size // num_attention_heads
    # reshape to 2d
    queries2d = ad.array_reshape_op(queries, [-1, hidden_size])  # (N * T_q, d_model)
    keys2d = ad.array_reshape_op(keys, [-1, hidden_size])  # (N * T_k, d_model)
    values2d = ad.array_reshape_op(values, [-1, hidden_size])  # (N * T_k, d_model)
    # linear transformation
    query_layer = dense(queries2d, hidden_size, hidden_size, query_act)  # (N * T_q, d_model)
    key_layer = dense(keys2d, hidden_size, hidden_size, key_act)  # (N * T_k, d_model)
    value_layer = dense(values2d, hidden_size, hidden_size, value_act)  # (N * T_k, d_model)
    # transpose to per-head layout
    query_layer = transpose_for_scores(query_layer)  # (N, h, T_q, d_model/h)
    key_layer = transpose_for_scores(key_layer)  # (N, h, T_k, d_model/h)
    value_layer = transpose_for_scores(value_layer)  # (N, h, T_k, d_model/h)
    # score: Q @ K^T, scaled by sqrt(head size)
    attention_scores = ad.batch_matmul_op(query_layer, key_layer, trans_B=True)  # (N, h, T_q, T_k)
    attention_scores = attention_scores * (1.0 / np.sqrt(float(size_per_head)))
    # padding mask: replace masked positions with a large negative number
    # so softmax drives them to ~0
    if attention_mask is not None:
        zeros = ad.Variable('no_mask', value=np.array((0, ), dtype=np.float32),
                            trainable=False)
        adder = ad.Variable('attention_mask',
                            value=np.array((-2**32 + 1, ), dtype=np.float32),
                            trainable=False)
        zeros = ad.broadcastto_op(zeros, attention_mask)
        adder = ad.broadcastto_op(adder, attention_mask)
        attention_mask = ad.where_op(attention_mask, zeros, adder)  # (N, T)
        attention_mask = ad.array_reshape_op(attention_mask, [batch_size, 1, 1, -1])
        attention_scores = attention_scores + ad.broadcastto_op(
            attention_mask, attention_scores)
    # causal (future) mask: keep scores on/below the diagonal, large
    # negative elsewhere
    if causality:
        tril = ad.Variable(name='tril',
                           value=np.tril(np.ones((caus_len, caus_len))),
                           trainable=False)  # (T, T)
        future_masks = ad.broadcast_shape_op(
            tril, [batch_size, num_attention_heads, caus_len, caus_len])
        adder = ad.Variable('future_mask',
                            value=np.array((-2**32 + 1, ), dtype=np.float32),
                            trainable=False)
        adder = ad.broadcastto_op(adder, future_masks)
        attention_scores = ad.where_op(future_masks, attention_scores, adder)  # (N, h, T, T)
    # probs
    attention_probs = ad.softmax_op(attention_scores)
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
    # weighted sum of values, back to (N, T_q, d_model) layout
    context_layer = ad.batch_matmul_op(attention_probs, value_layer)
    context_layer = ad.transpose_op(context_layer, [0, 2, 1, 3])
    outputs = ad.array_reshape_op(
        context_layer, [batch_size, -1, num_attention_heads * size_per_head])
    # Residual connection
    outputs = outputs + queries  # (N, T_q, d_model)
    # Normalize
    outputs = layer_norm(outputs, hidden_size)  # (N, T_q, d_model)
    return outputs