Example #1
    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """Enable the embedding lookup cache for parameter server (PS) training."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        _set_cache_enable(True)
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
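
For context, the helper above only takes effect when parameter server training mode is active. A minimal sketch, assuming the usual MindSpore PS setup (not part of the example itself):

# Sketch only: turn on parameter server mode before building the network, so that
# vocab_cache_size and the cache-enable path above are actually used. The MS_ROLE
# environment variable selects whether the process runs as worker, server, or scheduler.
from mindspore import context

context.set_ps_context(enable_ps=True)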
Example #2
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 param_init='normal',
                 target='CPU',
                 slice_mode='batch_slice',
                 manual_shapes=None,
                 max_norm=None,
                 sparse=True,
                 vocab_cache_size=0):
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.target = target
        if target not in ('CPU', 'DEVICE'):
            raise ValueError(
                'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
                str(target) +
                ', should be one of \'CPU\' or \'DEVICE\'.')
        if not sparse and target == 'CPU':
            raise ValueError(
                'When target is CPU, embedding_lookup must be sparse.')
        enable_ps = context.get_ps_context("enable_ps")
        if not enable_ps and vocab_cache_size > 0:
            logger.warning(
                "The configuration of 'vocab_cache_size' is valid only in parameter server training mode; "
                "the current mode is not parameter server training mode, so it will be ignored."
            )
            vocab_cache_size = 0
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.GatherV2()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
            'primitive_target', 'CPU')
        self.vocab_size = validator.check_positive_int(vocab_size,
                                                       'vocab_size')
        self.vocab_cache_size = validator.check_non_negative_int(
            vocab_cache_size, 'vocab_cache_size')
        self.embedding_size = validator.check_positive_int(
            embedding_size, 'embedding_size')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                             ParallelMode.AUTO_PARALLEL)
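        # Caching is enabled only for a positive vocab_cache_size; under (semi-)auto
        # parallel the per-device cache is scaled by the device group size and the
        # embedding table is allocated at the cache size.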
        self.cache_enable = self.vocab_cache_size > 0
        if self.cache_enable:
            if is_auto_parallel:
                self.vocab_cache_size = self.vocab_cache_size * get_group_size()
            self.vocab_size = self.vocab_cache_size

        self.embedding_table = Parameter(initializer(
            param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        if self.cache_enable:
            self.embedding_table.cache_enable = True
            _set_cache_enable(True)
            if _is_role_worker():
                _insert_hash_table_size(self.embedding_table.name,
                                        vocab_cache_size, embedding_size,
                                        vocab_size)
        self.forward_unique = False
        self.gather_revert = P.GatherV2()
        self.unique = P.Unique().shard(((1, ), ))
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        indices_shape_size = 2
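        # The shard strategies below are applied per slice_mode under auto parallel:
        # "field_slice", "table_row_slice", "table_column_slice" or "batch_slice".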
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(
                    "in field_slice mode, manual_shapes must not be None")
            if not isinstance(manual_shapes, tuple):
                raise TypeError(
                    "manual_shapes must be a tuple of int, but got {}!".format(
                        type(manual_shapes)))
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim',
                                             self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(
                ((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            if target == 'DEVICE' and not self.cache_enable:
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(), )))
                self.forward_unique = True
            indices_strategy = (1, ) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(
                ((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1, )))
                self.forward_unique = True
            indices_strategy = (1, ) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(
                ((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError(
                    "slice_mode should be one of the supported modes of nn.EmbeddingLookup, but got "
                    + str(slice_mode))
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(
                self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
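
A minimal usage sketch for the cell defined above, assuming the standard mindspore.nn.EmbeddingLookup call convention (an index tensor in, the gathered embedding rows out); the sizes below are arbitrary:

import mindspore as ms
import mindspore.nn as nn

# Sketch only: look up 16-dimensional embeddings for a (2, 3) batch of indices
# on the default CPU target with sparse gathering enabled.
net = nn.EmbeddingLookup(vocab_size=2000, embedding_size=16, target='CPU', sparse=True)
indices = ms.Tensor([[1, 3, 4], [5, 7, 9]], ms.int32)
output = net(indices)  # expected shape: (2, 3, 16)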