Example 1
 def __init__(self, num_segments, dyn_a=True, dyn_b=True):
     super(UnsortedSegmentMaxDynNet, self).__init__()
     self.unsorted_segment_max = P.UnsortedSegmentMax()
     self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
     self.num_segments = num_segments
     self.to_dyn_1 = dyn_a
     self.to_dyn_2 = dyn_b
Example 2
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 field_size,
                 param_init='normal',
                 target='CPU',
                 slice_mode='batch_slice',
                 feature_num_list=None,
                 max_norm=None,
                 sparse=True,
                 operator='SUM'):
        super(MultiFieldEmbeddingLookup,
              self).__init__(vocab_size, embedding_size, param_init, target,
                             slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = validator.check_positive_int(field_size,
                                                       'field_size')
        self.operator = operator

        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.Add()
        self.inf_add = P.Add()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.Add()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
            self.merge_op = P.UnsortedSegmentSum()
        else:
            raise ValueError(
                "The operator supports ['SUM', 'MAX', 'MEAN'], but found: " +
                str(operator))

        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                             ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"
                          ] and is_auto_parallel:
            self.merge_op.shard(
                ((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(), ), ))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(
                ((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(), ), (get_group_size(), )))
            self.div_no_nan.shard(
                ((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(
                ((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1, ), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(
                    ((get_group_size(), 1), (get_group_size(), )))
                self.count_op.shard(
                    ((get_group_size(), ), (get_group_size(), )))
                self.inf_add.shard(
                    ((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1, ), (1, )))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1, ), ))
            self.max_no_equal.shard(((1, ), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1, )))
                self.count_op.shard(((1, ), (1, )))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                raise ValueError(
                    "slice_mode should be  ['table_row_slice', 'batch_slice' and \
                       'table_column_slice'], but get " + str(slice_mode))

        # Min value for fp32
        self.negative_inf_value = -3.402823466E+38
Example 3
 def __init__(self, num_segments):
     super(UnsortedSegmentMaxNet, self).__init__()
     self.unsorted_segment_max = P.UnsortedSegmentMax()
     self.num_segments = num_segments
Example 4
def test_3d_single_init():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3),
                     dtype=mindspore.float32)
    segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
    net = P.UnsortedSegmentMax()

    num_segments = 4
    output = net(input_x, segment_ids, num_segments).asnumpy()
    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
                        [1.2000000e+01, 1.3000000e+01,
                         1.4000000e+01]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)

    num_segments = 6
    output = net(input_x, segment_ids, num_segments).asnumpy()
    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38,
                         -3.4028235e+38]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)
 def __init__(self, strategy1, strategy2, num_segments):
     super(Net, self).__init__()
     self.virtual_dataset = _VirtualDataset()
     self.merge_op = P.UnsortedSegmentMax().shard((strategy1, strategy2))
     self.num_segments = num_segments