def __init__(self, num_segments, dyn_a=True, dyn_b=True):
    """Wrap the UnsortedSegmentMax primitive with optional dynamic-shape inputs.

    Args:
        num_segments (int): Number of output segments handed to the op.
        dyn_a (bool): When True, route the data input through the
            GPU dynamic-shape converter. Default: True.
        dyn_b (bool): When True, route the segment-ids input through the
            GPU dynamic-shape converter. Default: True.
    """
    super(UnsortedSegmentMaxDynNet, self).__init__()
    # Flags presumably read by construct() (not visible here) to decide which
    # inputs get converted to dynamic shapes — confirm against the class body.
    self.to_dyn_1 = dyn_a
    self.to_dyn_2 = dyn_b
    self.num_segments = num_segments
    self.unsorted_segment_max = P.UnsortedSegmentMax()
    self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
def __init__(self, vocab_size, embedding_size, field_size, param_init='normal',
             target='CPU', slice_mode='batch_slice', feature_num_list=None,
             max_norm=None, sparse=True, operator='SUM'):
    """Build a multi-field embedding lookup that merges per-field embeddings.

    Args:
        vocab_size (int): Size of the embedding table's vocabulary.
        embedding_size (int): Width of each embedding vector.
        field_size (int): Number of fields; must be a positive int.
        param_init (str): Initializer for the embedding table. Default: 'normal'.
        target (str): Device hosting the lookup. Default: 'CPU'.
        slice_mode (str): Parallel slicing strategy; one of
            'table_row_slice', 'batch_slice', 'table_column_slice'.
        feature_num_list: Forwarded to the base class; semantics defined there.
        max_norm: Optional norm clipping value, forwarded to the base class.
        sparse (bool): Use sparse lookup. Default: True.
        operator (str): Per-field merge reduction: 'SUM', 'MAX' or 'MEAN'.

    Raises:
        ValueError: If ``operator`` is not one of 'SUM', 'MAX', 'MEAN', or if
            ``slice_mode`` is unsupported under (semi-)auto parallel mode.
    """
    super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size,
                                                    param_init, target, slice_mode,
                                                    feature_num_list, max_norm, sparse)
    self.field_size = validator.check_positive_int(field_size, 'field_size')
    self.operator = operator

    self.mul = P.Mul()
    self.inf_mask_mul = P.Mul()
    self.bias_add = P.Add()
    self.inf_add = P.Add()
    self.merge_op = None
    self.count_op = P.UnsortedSegmentSum()
    self.abs = P.Abs()
    self.equal = P.Equal()
    self.add = P.Add()
    self.cast = P.Cast()
    self.div_no_nan = P.DivNoNan()
    self.expand = P.ExpandDims()
    self.max_mask_mul = P.Mul()
    self.max_no_equal = P.NotEqual()

    if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
        self.merge_op = P.UnsortedSegmentSum()
    elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
        self.merge_op = P.UnsortedSegmentMax()
    elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
        # MEAN is realized as a segment SUM later divided by per-segment
        # counts (count_op + div_no_nan), hence the SUM primitive here too.
        self.merge_op = P.UnsortedSegmentSum()
    else:
        raise ValueError(
            "The operator supports ['SUM', 'MAX', 'MEAN'], but found: " + str(operator))

    parallel_mode = _get_parallel_mode()
    is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                         ParallelMode.AUTO_PARALLEL)
    if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
        # Hoist the (constant per process) group size out of the shard calls.
        group_size = get_group_size()
        self.merge_op.shard(((group_size, 1, 1), (group_size, 1)))
        self.expand.shard(((group_size,),))
        self.bias_add.shard(((1, 1), (1, 1)))
        self.mul.shard(((group_size, 1, 1), (group_size, 1, 1)))
        self.count_op.shard(((group_size, 1), (group_size, 1)))
        self.add.shard(((group_size,), (group_size,)))
        self.div_no_nan.shard(((group_size, 1), (group_size, 1)))
        self.max_mask_mul.shard(((group_size, 1), (group_size, 1)))
        self.max_no_equal.shard(((1,), ()))
        if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            # MAX needs extra ops (inf masking) and different merge/count layouts.
            self.equal.shard(((group_size, 1, 1), ()))
            self.inf_mask_mul.shard(((group_size, 1, 1), ()))
            self.merge_op.shard(((group_size, 1), (group_size,)))
            self.count_op.shard(((group_size,), (group_size,)))
            self.inf_add.shard(((group_size, 1, 1), (group_size, 1, 1)))
    elif slice_mode == "table_column_slice" and is_auto_parallel:
        group_size = get_group_size()
        self.merge_op.shard(((1, 1, group_size), (1, 1)))
        self.div_no_nan.shard(((1, group_size), (1, 1)))
        self.bias_add.shard(((1, 1), (1, 1)))
        self.mul.shard(((1, 1, 1), (1, 1, group_size)))
        self.count_op.shard(((1, 1), (1, 1)))
        self.add.shard(((1,), (1,)))
        self.max_mask_mul.shard(((1, group_size), (1, 1)))
        self.expand.shard(((1,),))
        self.max_no_equal.shard(((1,), ()))
        if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.equal.shard(((1, 1, 1), ()))
            self.inf_mask_mul.shard(((1, 1, 1), ()))
            self.merge_op.shard(((1, group_size), (1,)))
            self.count_op.shard(((1,), (1,)))
            self.inf_add.shard(((1, 1, group_size), (1, 1, 1)))
    else:
        if is_auto_parallel:
            # Fixed: the original message contained a backslash continuation
            # inside the string literal, embedding raw spaces in the output.
            raise ValueError(
                "slice_mode should be one of ['table_row_slice', 'batch_slice', "
                "'table_column_slice'], but got " + str(slice_mode))

    # Min value for fp32; used to mask empty segments for the MAX operator.
    self.negative_inf_value = -3.402823466E+38
def __init__(self, num_segments):
    """Thin network wrapper around the UnsortedSegmentMax primitive.

    Args:
        num_segments (int): Number of output segments for the op.
    """
    super(UnsortedSegmentMaxNet, self).__init__()
    self.num_segments = num_segments
    self.unsorted_segment_max = P.UnsortedSegmentMax()
def test_3d_single_init():
    """UnsortedSegmentMax on a 3D input, running one net at two num_segments.

    The input rows are assigned via segment_ids [3, 0, 1, -1]: a negative id
    drops that row, and segments with no contributor are filled with the fp32
    minimum (-3.4028235e+38). The expectation is computed programmatically
    from that rule instead of two large hand-maintained literals, which were
    duplicated and error-prone.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_np = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)
    input_x = Tensor(input_np, dtype=mindspore.float32)
    ids = [3, 0, 1, -1]
    segment_ids = Tensor(ids, mstype.int32)
    net = P.UnsortedSegmentMax()
    fp32_min = np.finfo(np.float32).min  # -3.4028235e+38, the empty-segment fill

    for num_segments in (4, 6):
        # Reference result: start from the fill value, take element-wise max
        # of every input row whose (non-negative) id maps to the segment.
        expect = np.full((num_segments, 5, 3), fp32_min, dtype=np.float32)
        for row, seg in enumerate(ids):
            if seg >= 0:
                expect[seg] = np.maximum(expect[seg], input_np[row])
        output = net(input_x, segment_ids, num_segments).asnumpy()
        np.testing.assert_array_almost_equal(output, expect)
def __init__(self, strategy1, strategy2, num_segments):
    """Network applying UnsortedSegmentMax with an explicit shard strategy.

    Args:
        strategy1: Shard strategy tuple for the op's data input.
        strategy2: Shard strategy tuple for the op's segment-ids input.
        num_segments (int): Number of output segments.
    """
    super(Net, self).__init__()
    self.num_segments = num_segments
    self.virtual_dataset = _VirtualDataset()
    self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2))