def __init__(self,
              strategy1=None,
              strategy2=None,
              strategy3=None,
              axis=0,
              init_flag=True,
              split_tuple=(4, 4),
              split_string="manual_split",
              param_shape=(8, 8)):
     super().__init__()
     self.gatherv2 = P.EmbeddingLookup().shard(strategy1)
     self.gatherv2.add_prim_attr(split_string, split_tuple)
     self.gatherv2.add_prim_attr("primitive_target", "CPU")
     self.mul = P.Mul().shard(strategy2)
     self.reshape = P.Reshape()
     self.matmul = P.MatMul().shard(strategy3)
     self.matmul.add_prim_attr("forward_reduce_scatter", True)
     if init_flag:
         self.param = Parameter(initializer("ones", param_shape,
                                            ms.float32),
                                name="gatherv2_param")
     else:
         self.param = Parameter(Tensor(np.ones(param_shape),
                                       dtype=ms.float32),
                                name="gatherv2_param")
     self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32),
                                 name="mul_weight")
     self.matmul_weight = Parameter(initializer("ones", (64, 16),
                                                ms.float32),
                                    name="matmul_weight")
     self.axis = axis
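The snippet above is only an __init__ body; in the MindSpore parallel tests such snippets sit inside nn.Cell subclasses. Below is a minimal, runnable sketch (the class name, construct body, and concrete shapes are assumptions, and the sharding strategies are omitted so it runs stand-alone) showing how the operators created above are typically chained: embedding lookup, element-wise mul, reshape, then matmul.

# Hedged sketch, not the original test class: shard()/add_prim_attr() calls dropped.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P


class LookupMatMulNet(nn.Cell):          # hypothetical name
    def __init__(self, param_shape=(8, 8)):
        super().__init__()
        self.gatherv2 = P.EmbeddingLookup()
        self.mul = P.Mul()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.param = Parameter(initializer("ones", param_shape, ms.float32),
                               name="gatherv2_param")
        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32),
                                    name="mul_weight")
        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32),
                                       name="matmul_weight")

    def construct(self, indices):
        # EmbeddingLookup takes (params, indices, offset)
        out = self.gatherv2(self.param, indices, 0)   # (8, 8) table, (8, 8) indices -> (8, 8, 8)
        out = self.mul(out, self.mul_weight)          # element-wise with (8, 8, 8)
        out = self.reshape(out, (8, 64))              # flatten for the matmul
        out = self.matmul(out, self.matmul_weight)    # (8, 64) @ (64, 16)
        return out


net = LookupMatMulNet()
indices = Tensor(np.ones((8, 8)), dtype=ms.int32)
print(net(indices).shape)   # expected (8, 16)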
Example #2
 def __init__(self, shape, offset, reduce_scatter_flag, split_num):
     super().__init__()
     self.index = Tensor(np.ones(shape), dtype=ms.int32)
     self.offset = offset
     self.reduce_scatter_flag = reduce_scatter_flag
     self.split_num = split_num
     self.elu = P.EmbeddingLookup()
     self.mm = P.BatchMatMul()
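Example #2 creates an EmbeddingLookup/BatchMatMul pair but stops at __init__. A hedged, self-contained sketch of how such a test cell is usually completed and called follows; the class name, construct body, and shapes are assumptions, and the reduce_scatter_flag/split_num handling is left out.

# Hedged sketch: P.EmbeddingLookup expects (params, indices, offset); here the
# indices live on the cell and the embedding table arrives as a network input.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class LookupBmmNet(nn.Cell):             # hypothetical name
    def __init__(self, shape, offset):
        super().__init__()
        self.index = Tensor(np.ones(shape), dtype=ms.int32)
        self.offset = offset
        self.elu = P.EmbeddingLookup()
        self.mm = P.BatchMatMul()

    def construct(self, table, y):
        out = self.elu(table, self.index, self.offset)  # gather rows of `table`
        out = self.mm(out, y)                           # batch matmul with `y`
        return out


net = LookupBmmNet(shape=(2, 4), offset=0)
table = Tensor(np.ones((8, 8)), ms.float32)
y = Tensor(np.ones((2, 8, 8)), ms.float32)
print(net(table, y).shape)   # expected (2, 4, 8)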
Example #3
 def __init__(self,
              vocab_size,
              embedding_size,
              param_init='normal',
              target='CPU',
              slice_mode='batch_slice',
              manual_shapes=None):
     super(EmbeddingLookup, self).__init__()
     self.target = target
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     self.gatherv2 = P.GatherV2()
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
     self.embedding_table = Parameter(initializer(
         param_init, [vocab_size, embedding_size]),
                                      name='embedding_table')
     parallel_mode = _get_parallel_mode()
     is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                          ParallelMode.AUTO_PARALLEL)
     if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
         if not manual_shapes:
             raise ValueError(
                 "in slice field mode, the manual_shapes should not be none"
             )
         if not isinstance(manual_shapes, tuple):
             raise TypeError(
                 "manual_shapes type must be tuple(int) cannot be {}!".
                 format(type(manual_shapes)))
         for dim in manual_shapes:
              Validator.check_integer('manual shape dim', dim, 0, Rel.GT,
                                     self.cls_name)
         self.gatherv2.add_prim_attr("manual_split", manual_shapes)
         self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
         self.gatherv2.set_strategy(
             ((get_group_size(), 1), (1, get_group_size())))
         self.embeddinglookup.set_strategy(
             ((get_group_size(), 1), (1, get_group_size())))
     elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
         self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
         self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
     elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
         self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
         self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
     elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
         self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
         self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
     else:
         if is_auto_parallel:
             raise ValueError(
                 "slice_mode should support mode in nn.EmbeddingLookUpSplitMode, but get "
                 + str(slice_mode))
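Example #3 is an early version of the nn.EmbeddingLookup layer itself. A hedged usage sketch (stand-alone, no parallel context; the concrete sizes are assumptions): the layer is called with an indices tensor and returns embeddings of shape indices.shape + (embedding_size,).

# Hedged usage sketch for the layer defined above.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

lookup = nn.EmbeddingLookup(vocab_size=10, embedding_size=4, target='CPU')
indices = Tensor(np.array([[1, 3], [4, 7]]), ms.int32)
output = lookup(indices)
print(output.shape)   # (2, 2, 4) -- indices shape + (embedding_size,)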
Example #4
 def __init__(self, target='CPU'):
     super(EmbeddingLookup, self).__init__()
     self.target = target
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     self.gatherv2 = P.GatherV2()
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
Example #5
 def __init__(self,
              shape,
              offset,
              strategy1=None,
              strategy2=None,
              target="Device"):
     super().__init__()
     self.index = Tensor(np.ones(shape), dtype=ms.int32)
     self.offset = offset
     self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr(
         "primitive_target", target)
     self.mm = P.BatchMatMul().set_strategy(strategy2)
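Example #5 adds shard strategies and a configurable target to the same EmbeddingLookup/BatchMatMul pattern. A hedged configuration sketch follows (the device count and strategy values are assumptions) showing how such strategies are usually prepared in a semi-auto-parallel test; the tuples would then be passed as strategy1/strategy2 when constructing the cell above.

# Hedged configuration sketch; only sets the parallel context and strategies.
from mindspore import context

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))        # embedding table split by rows, indices replicated
strategy2 = ((1, 1, 1), (1, 1, 1))  # BatchMatMul inputs left unsplit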
Example #6
 def __init__(self,
              vocab_size,
              embedding_size,
              param_init='normal',
              target='CPU',
              slice_mode='batch_slice',
              manual_shapes=None,
              max_norm=None,
              sparse=True,
              vocab_cache_size=0):
     super(EmbeddingLookup, self).__init__()
     validator.check_value_type('sparse', sparse, [bool], self.cls_name)
     self.vocab_size = validator.check_positive_int(vocab_size,
                                                    'vocab_size')
     self.vocab_cache_size = validator.check_non_negative_int(
         vocab_cache_size, 'vocab_cache_size')
     self.target = target
     self.sparse = sparse
     self.cache_enable = self.vocab_cache_size > 0
     self.forward_unique = False
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     if not sparse and target == 'CPU':
         raise ValueError(
             'When target is CPU, embedding_lookup must be sparse.')
     if sparse:
         self.gatherv2 = P.SparseGatherV2()
     else:
         self.gatherv2 = P.Gather()
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
     enable_ps = _get_ps_context("enable_ps")
     if enable_ps:
         self._process_vocab_cache(slice_mode)
     self.embedding_size = validator.check_positive_int(
         embedding_size, 'embedding_size')
     self.embedding_table = Parameter(initializer(
         param_init, [self.vocab_size, self.embedding_size]),
                                      name='embedding_table')
     parallel_mode = _get_parallel_mode()
     is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                          ParallelMode.AUTO_PARALLEL)
     self.gather_revert = P.Gather()
     self.reshape_first = P.Reshape()
     self.reshape = P.Reshape()
     self.unique = P.Unique()
     self.shape = P.Shape()
     if is_auto_parallel:
         self.unique = P.Unique().shard(((1, ), ))
     if self.cache_enable and enable_ps:
         self._set_voacb_cache_enable_for_ps(vocab_cache_size,
                                             embedding_size, vocab_size)
         if is_auto_parallel:
             self.unique.add_prim_attr('cache_enable', True)
     indices_shape_size = 2
     if slice_mode == "field_slice" and is_auto_parallel:
         if not manual_shapes:
             raise ValueError(
                 "in slice field mode, the manual_shapes should not be none"
             )
         if not isinstance(manual_shapes, tuple):
             raise TypeError(
                 "manual_shapes type must be tuple(int) cannot be {}!".
                 format(type(manual_shapes)))
         for dim in manual_shapes:
             validator.check_positive_int(dim, 'manual shape dim',
                                          self.cls_name)
         self.gatherv2.add_prim_attr("manual_split", manual_shapes)
         self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
         self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), (1, get_group_size())))
     elif slice_mode == "table_row_slice" and is_auto_parallel:
         full_batch = _get_full_batch()
         if (target == 'DEVICE'
                 and not full_batch) or (self.cache_enable and enable_ps
                                         and sparse):
             indices_shape_size = 1
             self.gather_revert.shard(((1, 1), (get_group_size(), )))
             self.forward_unique = True
         indices_strategy = (1, ) * indices_shape_size
         self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), indices_strategy))
     elif slice_mode == "table_column_slice" and is_auto_parallel:
         if target == 'DEVICE':
             indices_shape_size = 1
             self.gather_revert.shard(((1, get_group_size()), (1, )))
             self.forward_unique = True
         indices_strategy = (1, ) * indices_shape_size
         self.gatherv2.shard(((1, get_group_size()), indices_strategy))
         self.embeddinglookup.shard(
             ((1, get_group_size()), indices_strategy))
     elif slice_mode == "batch_slice" and is_auto_parallel:
         indices_strategy = [get_group_size()]
         indices_strategy.extend([1] * (indices_shape_size - 1))
         indices_strategy = tuple(indices_strategy)
         self.gatherv2.shard(((1, 1), indices_strategy))
         self.embeddinglookup.shard(((1, 1), indices_strategy))
     else:
         if is_auto_parallel:
             raise ValueError(
                 "slice_mode should support mode in nn.EmbeddingLookup, but get "
                 + str(slice_mode))
     if self.cache_enable and not enable_ps:
         if parallel_mode != ParallelMode.STAND_ALONE:
             raise ValueError(
                 "parallel mode haven't supported cache enable yet.")
         self._set_cache_enable()
     self.embedding_table.unique = self.forward_unique
     self.max_norm = max_norm
     if self.max_norm is not None:
         self.max_norm = validator.check_positive_float(
             self.max_norm, 'max_norm', self.cls_name)
         self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
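Example #6 wires up P.Unique, P.Gather and gather_revert so that, on a device target with forward_unique enabled, indices are deduplicated before the gather and the result is scattered back into the original order. The sketch below illustrates that idea with the same primitives; it is a simplified stand-in, not the layer's actual construct, and the table and index values are assumptions.

# Hedged sketch of the forward-unique lookup path: dedupe, gather, revert.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

def unique_lookup(table, indices, embedding_size):
    reshape = P.Reshape()
    unique = P.Unique()
    gather = P.Gather()
    shape = P.Shape()

    flat = reshape(indices, (-1,))          # flatten indices
    unique_id, unique_idx = unique(flat)    # dedupe, keep reverse mapping
    rows = gather(table, unique_id, 0)      # gather only the unique rows
    restored = gather(rows, unique_idx, 0)  # scatter back to original order
    return reshape(restored, shape(indices) + (embedding_size,))

table = Tensor(np.arange(20, dtype=np.float32).reshape(10, 2))
idx = Tensor(np.array([[1, 1], [3, 1]]), ms.int32)
print(unique_lookup(table, idx, 2).shape)   # (2, 2, 2)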
Example #7
 def __init__(self,
              vocab_size,
              embedding_size,
              param_init='normal',
              target='CPU',
              slice_mode='batch_slice',
              manual_shapes=None,
              max_norm=None):
     super(EmbeddingLookup, self).__init__()
     self.target = target
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     self.gatherv2 = P.GatherV2()
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
     self.vocab_size = validator.check_value_type('vocab_size', vocab_size,
                                                  [int], self.cls_name)
     self.embedding_size = validator.check_value_type(
         'embedding_size', embedding_size, [int], self.cls_name)
     self.embedding_table = Parameter(initializer(
         param_init, [self.vocab_size, self.embedding_size]),
                                      name='embedding_table')
     parallel_mode = _get_parallel_mode()
     is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                          ParallelMode.AUTO_PARALLEL)
     if slice_mode == "field_slice" and is_auto_parallel:
         if not manual_shapes:
             raise ValueError(
                 "in slice field mode, the manual_shapes should not be none"
             )
         if not isinstance(manual_shapes, tuple):
             raise TypeError(
                 "manual_shapes type must be tuple(int) cannot be {}!".
                 format(type(manual_shapes)))
         for dim in manual_shapes:
             validator.check_positive_int(dim, 'manual shape dim',
                                          self.cls_name)
         self.gatherv2.add_prim_attr("manual_split", manual_shapes)
         self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
         self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), (1, get_group_size())))
     elif slice_mode == "table_row_slice" and is_auto_parallel:
         self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
         self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
     elif slice_mode == "table_column_slice" and is_auto_parallel:
         self.gatherv2.shard(((1, get_group_size()), (1, 1)))
         self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
     elif slice_mode == "batch_slice" and is_auto_parallel:
         self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
         self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
     else:
         if is_auto_parallel:
             raise ValueError(
                 "slice_mode should support mode in nn.EmbeddingLookup, but get "
                 + str(slice_mode))
     self.max_norm = max_norm
     if self.max_norm is not None:
         self.max_norm = validator.check_positive_float(
             self.max_norm, 'max_norm', self.cls_name)
         self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
Example #8
 def __init__(self, shape, offset):
     super().__init__()
     self.index = Tensor(np.ones(shape), dtype=ms.int32)
     self.offset = offset
     self.elu = P.EmbeddingLookup()
     self.mm = P.BatchMatMul()
Example #9
 def __init__(self, offset):
     super().__init__()
     self.op = op.EmbeddingLookup()
     self.offset = offset
Example #10
 def __init__(self, offset):
     super(Net, self).__init__()
     self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
     self.offset = offset
Example #11
 def __init__(self, offset):
     super(Net, self).__init__()
     self.embedding = P.EmbeddingLookup()
     self.offset = offset
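The last examples only store an offset and the raw primitive. Below is a hedged sketch of how such a minimal net is typically completed and exercised; the class name, construct body, and shapes are assumptions. The primitive's call convention is (params, indices, offset), and rows whose index minus offset falls outside the table come back as zeros.

# Hedged sketch of a complete offset-lookup cell and a sample call.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class OffsetLookupNet(nn.Cell):          # hypothetical name
    def __init__(self, offset):
        super().__init__()
        self.embedding = P.EmbeddingLookup()
        self.offset = offset

    def construct(self, params, indices):
        # rows are addressed as indices - offset; out-of-range rows are zero
        return self.embedding(params, indices, self.offset)


net = OffsetLookupNet(offset=4)
params = Tensor(np.ones((8, 3)), ms.float32)
indices = Tensor(np.array([[5, 2], [8, 5]]), ms.int32)
print(net(params, indices).shape)   # expected (2, 2, 3)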