def construct(self, input_ids, attn_mask, token_type_ids, context_mask, square_mask,
              packing_mask, cache_mask, para_start_mapping, sent_end_mapping):
    """construct function"""
    state = self.encoder(attn_mask, input_ids, token_type_ids)
    # dst_type/dst_type2 are module-level cast targets (e.g. fp16 for the
    # batched matmul, fp32 for the downstream head).
    para_state = self.bmm(ops.Cast()(para_start_mapping, dst_type), ops.Cast()(state, dst_type))  # [B, 2, D]
    sent_state = self.bmm(ops.Cast()(sent_end_mapping, dst_type), ops.Cast()(state, dst_type))  # [B, max_sent, D]
    q_type, start, end, para_logit, sent_logit = self.downstream(ops.Cast()(para_state, dst_type2),
                                                                 ops.Cast()(sent_state, dst_type2),
                                                                 state, context_mask)
    # Pairwise span scores: outer[b, i, j] = start[b, i] + end[b, j].
    outer = start[:, :, None] + end[:, None]
    # Mask out invalid (i, j) pairs with a large negative value, then read the
    # best span off with a row-wise and a column-wise argmax.
    outer_mask = cache_mask
    outer_mask = square_mask * outer_mask[None]
    outer = outer - 1e30 * (1 - outer_mask)
    outer = outer - 1e30 * packing_mask[:, :, None]
    max_row = ops.ReduceMax()(outer, 2)
    y1 = ops.Argmax()(max_row)
    max_col = ops.ReduceMax()(outer, 1)
    y2 = ops.Argmax()(max_col)
    return start, end, q_type, para_logit, sent_logit, y1, y2
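# A minimal NumPy sketch (illustrative, not part of the model above) of the
# span-selection step: start/end logits are summed into a pairwise score
# matrix, invalid pairs are masked with a large negative value (here a simple
# upper-triangular mask stands in for the model's square/packing masks), and
# the best span is recovered with a row-wise and a column-wise argmax.
import numpy as np

start = np.array([0.1, 2.0, 0.3, 0.5])    # start logits over 4 tokens
end = np.array([0.2, 0.1, 1.5, 0.4])      # end logits over 4 tokens
outer = start[:, None] + end[None, :]     # outer[i, j] = start[i] + end[j]
valid = np.triu(np.ones((4, 4)))          # only allow end index j >= start index i
outer = outer - 1e30 * (1 - valid)
y1 = outer.max(axis=1).argmax()           # best start index
y2 = outer.max(axis=0).argmax()           # best end index
print(y1, y2)                             # -> 1 2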
def __call__(self, x, y):
    """Run one evaluation step: forward pass, loss, and top-1 accuracy."""
    x = self.net(x)
    loss = self.loss(x, y)
    x = self.softmax(x)
    predicts = ops.Argmax(output_type=mstype.int32)(x)
    # Compare predicted class indices against the labels on the host side.
    acc = np.sum(predicts.asnumpy() == y.asnumpy()) / len(y.asnumpy())
    return loss.asnumpy(), acc
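# Standalone sketch of the accuracy computation above (the logits and labels
# are made-up sample data): Argmax over the class axis turns probabilities
# into predicted labels, and accuracy is the fraction that match the targets.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, ops

logits = Tensor(np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], np.float32))
labels = Tensor(np.array([1, 0, 0], np.int32))
predicts = ops.Argmax(output_type=mstype.int32)(logits)   # -> [1, 0, 1]
acc = np.sum(predicts.asnumpy() == labels.asnumpy()) / len(labels.asnumpy())
print(acc)                                                # -> 0.666...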
def test_time_distributed_argmax_no_reshape_axis():
    inputs = np.random.randint(0, 10, [3, 4])
    argmax = ops.Argmax(output_type=mindspore.int32, axis=1)
    output_expect = argmax(Tensor(inputs, mindspore.float32)).asnumpy()
    # Repeat the same slice along a new time axis; every time step should
    # then reproduce the un-distributed result.
    inputs = inputs.reshape([3, 1, 4]).repeat(6, axis=1)
    time_distributed = TestTimeDistributed(argmax, time_axis=1)
    output = time_distributed(Tensor(inputs, mindspore.float32)).asnumpy()
    for i in range(output.shape[1]):
        assert np.all(output[:, i] == output_expect)
    print("Argmax op with no reshape axis wrapped successfully")
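# Pure-NumPy reference for the time-distributed semantics checked above
# (illustrative only; x is made-up data): the wrapped op is applied
# independently to each time slice and the results are stacked back up.
import numpy as np

x = np.random.randint(0, 10, [3, 6, 4])                    # (batch, time, features)
per_step = np.stack([x[:, t].argmax(axis=-1)               # op on each time slice
                     for t in range(x.shape[1])], axis=1)  # -> (batch, time)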
def __init__(self, axis=0):
    super(NetArgmax, self).__init__()
    # Argmax along the given axis, returning int32 indices.
    self.argmax = ops.Argmax(axis, output_type=mstype.int32)
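# A minimal sketch of how such a cell is typically completed and used; the
# construct method and the sample input below are assumptions, not part of
# the snippet above.
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor, ops

class NetArgmax(nn.Cell):
    def __init__(self, axis=0):
        super(NetArgmax, self).__init__()
        self.argmax = ops.Argmax(axis, output_type=mstype.int32)

    def construct(self, x):
        # Assumed forward pass: simply apply the wrapped Argmax.
        return self.argmax(x)

net = NetArgmax(axis=1)
out = net(Tensor(np.array([[1., 20., 5.], [67., 8., 9.]], np.float32)))
print(out)  # -> [1 0]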