def __init__(self):
    super(NetWrapper, self).__init__()
    self.unq = P.Unique()
    self.add = P.TensorAdd()
    self.expand_dims = P.ExpandDims()
    self.cast = P.Cast()
    self.net = Net()
def __init__(self):
    super(NetWrapper, self).__init__()
    self.unq = P.Unique()
    self.add = P.Add()
    self.expand_dims = P.ExpandDims()
    self.cast = P.Cast()
    self.net = SparseApplyFtrlNet()
def rowtensor_deduplicate_indices_slices(grad):
    """Uniquify the indices and sum the 'values' rows that correspond to duplicated indices."""
    indices = grad.indices
    values = grad.values

    unique_indices, index_position = P.Unique()(indices)
    summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])

    return RowTensor(unique_indices, summed_values, grad.dense_shape)
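# A minimal NumPy sketch of the deduplication above, not the MindSpore graph code: rows of the
# gradient that share an index are merged into a single row, mirroring Unique + UnsortedSegmentSum.
# Names below are illustrative only; note that np.unique sorts, while P.Unique's ordering may differ.
import numpy as np

def dedup_rows_numpy(indices, values):
    """Sum rows of `values` that share the same index; return (unique_indices, summed_rows)."""
    unique_indices, index_position = np.unique(indices, return_inverse=True)
    summed = np.zeros((unique_indices.shape[0],) + values.shape[1:], dtype=values.dtype)
    np.add.at(summed, index_position, values)  # segment-sum keyed by the inverse positions
    return unique_indices, summed

# Example: rows 0 and 2 both target index 3, so they are merged into one row.
idx = np.array([3, 1, 3], dtype=np.int32)
val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
print(dedup_rows_numpy(idx, val))  # ([1, 3], [[2., 2.], [4., 4.]])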
def __init__(self):
    super().__init__()
    self.unique = P.Unique().shard(((1, ), ))
    self.relu = P.ReLU()
    self.mul = P.Mul()
    self.embedding_lookp = P.Gather().shard(((8, 1), (1, )))
    self.embedding_table = Parameter(initializer('normal', [2000, 128]), name='embedding_table')
    self.gatherv2 = P.Gather().shard(((1, 1), (1, )))
    self.reshape = P.Reshape()
    self.matmul = P.MatMul()
    self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")
def _set_cache_enable(self): """EmbeddingLookup cache check for not ps env, which is only support 'ascend'.""" if self.target != 'DEVICE': raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.") if not self.sparse: raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.") if get_context("device_target") != 'Ascend': raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.") logger.info("EmbeddingLookup cache enable takes effect.") self.forward_unique = True self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU') self.unique.add_prim_attr('cache_enable', True) self.embedding_table.cache_enable = self.cache_enable self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def _set_cache_enable(self): """EmbeddingLookup cache check for not ps env.""" if self.target != 'DEVICE': logger.warning( "The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " "so it will be ignored.") return if not self.sparse: logger.warning( "The configuration of 'vocab_cache_size' is valid only 'sparse' is true, " "so it will be ignored.") return logger.info("EmbeddingLookup cache enable takes effect.") self.forward_unique = True self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU') self.unique.add_prim_attr('cache_enable', True) self.embedding_table.cache_enable = self.cache_enable self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def __init__(self):
    super(UniqueSquare, self).__init__()
    self.unique = P.Unique()
    self.square = P.Square()
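# Hedged NumPy sketch of what a Unique-then-Square pipeline computes. The construct method is not
# shown above, so the composition order here is an assumption; np.unique also sorts, which P.Unique
# does not guarantee.
import numpy as np

x = np.array([1, 2, 5, 2], np.int32)
unique_vals = np.unique(x)       # stand-in for P.Unique(), keeping only distinct values
print(np.square(unique_vals))    # [ 1  4 25]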
def __init__(self):
    super(Net, self).__init__()
    self.unique = P.Unique()
def __init__(self, config):
    super(WideDeepModel, self).__init__()
    self.batch_size = config.batch_size
    host_device_mix = bool(config.host_device_mix)
    parameter_server = bool(config.parameter_server)
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if is_auto_parallel:
        self.batch_size = self.batch_size * get_group_size()
    is_field_slice = config.field_slice
    sparse = config.sparse
    self.field_size = config.field_size
    self.vocab_size = config.vocab_size
    self.vocab_cache_size = config.vocab_cache_size
    self.emb_dim = config.emb_dim
    self.deep_layer_dims_list = config.deep_layer_dim
    self.deep_layer_act = config.deep_layer_act
    self.init_args = config.init_args
    self.weight_init, self.bias_init = config.weight_bias_init
    self.weight_bias_init = config.weight_bias_init
    self.emb_init = config.emb_init
    self.drop_out = config.dropout_flag
    self.keep_prob = config.keep_prob
    self.deep_input_dims = self.field_size * self.emb_dim
    self.layer_dims = self.deep_layer_dims_list + [1]
    self.all_dim_list = [self.deep_input_dims] + self.layer_dims

    init_acts = [('Wide_b', [1], self.emb_init)]
    var_map = init_var_dict(self.init_args, init_acts)
    self.wide_b = var_map["Wide_b"]
    self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
                                    self.weight_bias_init, self.deep_layer_act,
                                    convert_dtype=True, drop_out=config.dropout_flag)
    self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
                                    self.weight_bias_init, self.deep_layer_act,
                                    convert_dtype=True, drop_out=config.dropout_flag)
    self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
                                    self.weight_bias_init, self.deep_layer_act,
                                    convert_dtype=True, drop_out=config.dropout_flag)
    self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
                                    self.weight_bias_init, self.deep_layer_act,
                                    convert_dtype=True, drop_out=config.dropout_flag)
    self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5],
                                    self.weight_bias_init, self.deep_layer_act,
                                    use_activation=False, convert_dtype=True,
                                    drop_out=config.dropout_flag)

    self.wide_mul = P.Mul()
    self.deep_mul = P.Mul()
    self.reduce_sum = P.ReduceSum(keep_dims=False)
    self.reshape = P.Reshape()
    self.deep_reshape = P.Reshape()
    self.square = P.Square()
    self.shape = P.Shape()
    self.tile = P.Tile()
    self.concat = P.Concat(axis=1)
    self.cast = P.Cast()
    self.unique = P.Unique().shard(((1, ), ))
    self.wide_gatherv2 = P.GatherV2()
    self.deep_gatherv2 = P.GatherV2()

    if is_auto_parallel and sparse and not is_field_slice:
        target = 'DEVICE'
        if host_device_mix:
            target = 'CPU'
        self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
                                                       slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
        if config.deep_table_slice_mode == "column_slice":
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
                                                           slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
            self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()), ))
            self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
            self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size)
            self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1)))
            self.deep_reshape.add_prim_attr("skip_redistribution", True)
        else:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
                                                           slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
        self.reduce_sum.add_prim_attr("cross_batch", True)
        self.embedding_table = self.deep_embeddinglookup.embedding_table
    elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
        manual_shapes = tuple((s[0] for s in config.manual_shape))
        self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                       slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
                                                       manual_shapes=manual_shapes)
        self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
                                                       slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
                                                       manual_shapes=manual_shapes)
        self.deep_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1)))
        self.wide_mul.shard(((1, get_group_size(), 1), (1, get_group_size(), 1)))
        self.reduce_sum.shard(((1, get_group_size(), 1), ))
        self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()), ))
        self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
        self.embedding_table = self.deep_embeddinglookup.embedding_table
    elif parameter_server:
        cache_enable = self.vocab_cache_size > 0
        target = 'DEVICE' if cache_enable else 'CPU'
        if not cache_enable:
            sparse = True
        if is_auto_parallel and config.full_batch and cache_enable:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
                                                           slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
                                                           sparse=sparse, vocab_cache_size=self.vocab_cache_size)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
                                                           slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE,
                                                           sparse=sparse, vocab_cache_size=self.vocab_cache_size)
        else:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
                                                           sparse=sparse, vocab_cache_size=self.vocab_cache_size)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
                                                           sparse=sparse, vocab_cache_size=self.vocab_cache_size)
        self.embedding_table = self.deep_embeddinglookup.embedding_table
        self.deep_embeddinglookup.embedding_table.set_param_ps()
        self.wide_embeddinglookup.embedding_table.set_param_ps()
    else:
        self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                       target='DEVICE', sparse=sparse)
        self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
                                                       target='DEVICE', sparse=sparse)
        self.embedding_table = self.deep_embeddinglookup.embedding_table
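# Conceptual NumPy sketch of how the pieces above typically combine at forward time: the wide part
# looks up a scalar weight per feature id and reduces over fields (Mul + ReduceSum + wide_b), while
# the deep part looks up per-field embeddings and flattens them for the DenseLayer stack. This
# class's construct method is not shown, so the shapes and names below are assumptions.
import numpy as np

batch, field_size, emb_dim = 4, 3, 8
rng = np.random.default_rng(0)
ids = rng.integers(0, 2000, size=(batch, field_size))
vals = np.ones((batch, field_size), dtype=np.float32)            # per-feature values/mask

wide_table = rng.standard_normal((2000, 1)).astype(np.float32)   # wide lookup table, width 1
deep_table = rng.standard_normal((2000, emb_dim)).astype(np.float32)
wide_b = 0.0

# Wide part: scalar weight per id, scaled by the feature value and summed over fields.
wide_out = (wide_table[ids].squeeze(-1) * vals).sum(axis=1, keepdims=True) + wide_b
# Deep part: embeddings per field, scaled, then flattened into the MLP input.
deep_in = (deep_table[ids] * vals[..., None]).reshape(batch, field_size * emb_dim)
print(wide_out.shape, deep_in.shape)  # (4, 1) (4, 24)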
def __init__(self, vocab_size, embedding_size, param_init='normal', target='CPU',
             slice_mode='batch_slice', manual_shapes=None, max_norm=None,
             sparse=True, vocab_cache_size=0):
    super(EmbeddingLookup, self).__init__()
    validator.check_value_type('sparse', sparse, [bool], self.cls_name)
    self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
    self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
    self.target = target
    self.sparse = sparse
    self.cache_enable = self.vocab_cache_size > 0
    self.forward_unique = False
    if target not in ('CPU', 'DEVICE'):
        raise ValueError("Attr 'target' of 'EmbeddingLookup' Op passed " + str(target) +
                         ", should be one of values in 'CPU', 'DEVICE'.")
    if not sparse and target == 'CPU':
        raise ValueError("When target is CPU, embedding_lookup must be sparse.")
    if sparse:
        self.gatherv2 = P.SparseGatherV2()
    else:
        self.gatherv2 = P.Gather()
    self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
    enable_ps = _get_ps_context("enable_ps")
    if enable_ps:
        self._process_vocab_cache(slice_mode)
    self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
    self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                     name='embedding_table')
    parallel_mode = _get_parallel_mode()
    is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    self.gather_revert = P.Gather()
    self.reshape_first = P.Reshape()
    self.reshape = P.Reshape()
    self.unique = P.Unique()
    self.shape = P.Shape()
    if is_auto_parallel:
        self.unique = P.Unique().shard(((1, ), ))
    if self.cache_enable and enable_ps:
        self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
        if is_auto_parallel:
            self.unique.add_prim_attr('cache_enable', True)
    indices_shape_size = 2
    if slice_mode == "field_slice" and is_auto_parallel:
        if not manual_shapes:
            raise ValueError("in slice field mode, the manual_shapes should not be none")
        if not isinstance(manual_shapes, tuple):
            raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
        for dim in manual_shapes:
            validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
        self.gatherv2.add_prim_attr("manual_split", manual_shapes)
        self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
        self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
        self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
    elif slice_mode == "table_row_slice" and is_auto_parallel:
        full_batch = _get_full_batch()
        if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
            indices_shape_size = 1
            self.gather_revert.shard(((1, 1), (get_group_size(), )))
            self.forward_unique = True
        indices_strategy = (1, ) * indices_shape_size
        self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
        self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
    elif slice_mode == "table_column_slice" and is_auto_parallel:
        if target == 'DEVICE':
            indices_shape_size = 1
            self.gather_revert.shard(((1, get_group_size()), (1, )))
            self.forward_unique = True
        indices_strategy = (1, ) * indices_shape_size
        self.gatherv2.shard(((1, get_group_size()), indices_strategy))
        self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
    elif slice_mode == "batch_slice" and is_auto_parallel:
        indices_strategy = [get_group_size()]
        indices_strategy.extend([1] * (indices_shape_size - 1))
        indices_strategy = tuple(indices_strategy)
        self.gatherv2.shard(((1, 1), indices_strategy))
        self.embeddinglookup.shard(((1, 1), indices_strategy))
    else:
        if is_auto_parallel:
            raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " + str(slice_mode))
    if self.cache_enable and not enable_ps:
        if parallel_mode != ParallelMode.STAND_ALONE:
            raise ValueError("parallel mode haven't supported cache enable yet.")
        self._set_cache_enable()
    self.embedding_table.unique = self.forward_unique
    self.max_norm = max_norm
    if self.max_norm is not None:
        self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
        self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
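# A small NumPy sketch of the `forward_unique` path configured above: flatten the indices, look up
# only the unique ids, then scatter the rows back to their original positions with a second gather
# (reshape_first -> Unique -> gatherv2 -> gather_revert -> reshape). Names here are illustrative;
# this is not the nn.EmbeddingLookup construct itself.
import numpy as np

embedding_table = np.arange(10 * 4, dtype=np.float32).reshape(10, 4)
input_ids = np.array([[7, 2, 7], [2, 0, 7]], np.int32)

flat_ids = input_ids.reshape(-1)                                   # reshape_first
unique_ids, revert_pos = np.unique(flat_ids, return_inverse=True)  # unique
unique_rows = embedding_table[unique_ids]                          # gather on deduplicated ids only
out = unique_rows[revert_pos].reshape(*input_ids.shape, 4)         # gather_revert + reshape
print(out.shape)  # (2, 3, 4)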
def __init__(self, axis=0):
    super(Net, self).__init__()
    self.unique = P.Unique()
    self.reshape = P.Reshape()
    self.concat = P.Concat(axis=axis)
def __init__(self):
    super(Net, self).__init__()
    self.unique = P.Unique()
    self.dynamic_assign = P.DynamicAssign()
    self.param = Parameter(Tensor(np.zeros((5, ), np.int32)), name="assign_x")
def __init__(self):
    super(Net, self).__init__()
    self.unique = P.Unique().add_prim_attr("primitive_target", "CPU")
def __init__(self):
    super(UniqueSquare, self).__init__()
    self.unique = P.Unique().add_prim_attr("primitive_target", "CPU")
    self.square = P.Square()
def __init__(self):
    super(Net, self).__init__()
    self.unq = P.Unique()
    self.gather = P.Gather()
    self.yy = Tensor(np.ones([8], dtype=np.int32))
def __init__(self):
    super(Net, self).__init__()
    self.unq = P.Unique()
    self.gather = P.GatherV2()
def __init__(self):
    super(Net, self).__init__()
    self.unq = P.Unique()
    self.addn = P.AddN()
def __init__(self):
    super(Net, self).__init__()
    self.unq = P.Unique()
    self.segment_ids = Tensor([0, 0, 1, 2, 1, 1, 1, 1], mstype.int32)
    self.sum = P.UnsortedSegmentSum()
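# NumPy sketch of UnsortedSegmentSum with the segment_ids above: values that share a segment id are
# accumulated into one output slot. How the construct wires the Unique output into this operator is
# not shown, so this only illustrates the operator's semantics; the value array is made up.
import numpy as np

values = np.array([1., 2., 3., 4., 5., 6., 7., 8.], np.float32)
segment_ids = np.array([0, 0, 1, 2, 1, 1, 1, 1])
num_segments = 3

out = np.array([values[segment_ids == s].sum() for s in range(num_segments)])
print(out)  # [ 3. 29.  4.] -> segment 0: 1+2, segment 1: 3+5+6+7+8, segment 2: 4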
def __init__(self):
    super(NetUniqueDynamic, self).__init__()
    self.convert = inner.GpuConvertToDynamicShape()
    self.unique = P.Unique()
    self.split = P.Split(0, 2)