import numpy as np

import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# TARGET is assumed; in the original test file it names the backend under
# test, e.g. "GPU" or "Ascend".
TARGET = "GPU"


def test_rc_pynative_fp16_int32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    num_sample = 10
    seed = 5
    dtype = ms.int32
    expect = np.array(
        [[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]],
        dtype=np.int32)

    output = P.RandomCategorical(dtype)(x, num_sample, seed)
    diff = output.asnumpy() - expect
    assert expect.dtype == output.asnumpy().dtype
    assert np.all(diff == 0)
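The hard-coded expect array above is tied to the backend RNG at seed 5, so it typically only holds for the device it was generated on; the output shape (batch, num_sample) and the requested dtype, however, are backend-independent. A shape-and-dtype-only variant, assuming the same imports and TARGET as above, might look like this:

def test_rc_output_shape_and_dtype():
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)
    x = Tensor(np.random.randn(2, 5), ms.float16)
    output = P.RandomCategorical(ms.int32)(x, 10, 5)
    # 2 rows of logits, 10 samples per row, int32 indices.
    assert output.asnumpy().shape == (2, 10)
    assert output.asnumpy().dtype == np.int32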
class RCnet(nn.Cell):
    def __init__(self, num_sample, seed=0, dtype=ms.int64):
        super(RCnet, self).__init__()
        self.rc = P.RandomCategorical(dtype)
        self.num_sample = num_sample
        self.seed = seed

    def construct(self, x):
        # Sample num_sample indices per row of logits, as in the test above.
        return self.rc(x, self.num_sample, self.seed)
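A minimal usage sketch for RCnet under graph mode, reusing the imports and the assumed TARGET from the first example; the sampled indices depend on the backend RNG, so the sketch only runs the cell.

def run_rcnet_graph_mode():
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)
    x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    net = RCnet(num_sample=10, seed=5, dtype=ms.int32)
    # Each of the 2 input rows yields num_sample sampled column indices,
    # so the result has shape (2, 10).
    return net(x)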
    def __init__(self,
                 decoder: Model,
                 model_config=None,
                 generate_length: int = 1,
                 tokenizer: Optional[GPT2Tokenizer] = None,
                 topk_num: int = 0,
                 topp_prob: float = 1.0,
                 temperature: float = 1.0,
                 min_tokens_to_keep: int = 1,
                 early_stop: bool = False,
                 demo_mode: bool = False,
                 return_ids: bool = False,
                 return_last_token_logits: bool = False,
                 append_eos: bool = False):

       
        assert model_config is not None, 'A model config is required for sampling.'
        
        self.model_config = model_config
        self.topk_num = topk_num
        self.topp_prob = topp_prob
        self.temperature = temperature
        self.min_tokens_to_keep = min_tokens_to_keep
        
        self.decoder = decoder
        self.tokenizer = tokenizer
        self.reshape = P.Reshape()
        self.cumsum = P.CumSum()
        self.onehot = P.OneHot()
        self.generate_length = generate_length
        self.seq_length = model_config.seq_length
        self.batch_size = model_config.batch_size
        self.vocab_size = model_config.vocab_size
        
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.concat = P.Concat()
        self.early_stop = early_stop
        self.demo_mode = demo_mode
        self.return_ids = return_ids
        self.return_last_token_logits = return_last_token_logits
        self.append_eos = append_eos
        self.device_target = get_context("device_target")

        # Choose the sampling primitive according to the device target
        # (see the call sketch after this __init__): Multinomial on GPU,
        # RandomCategorical on Ascend.
        if self.device_target == "GPU":
            self.sample_function = P.Multinomial(seed=1)
        elif self.device_target == "Ascend":
            self.sample_function = P.RandomCategorical(mstype.int32)
        else:
            raise NotImplementedError("Device Target {} not supported.".format(self.device_target))

        self.filter_distribution = TopKTopP_Filter(self.batch_size,
                                                   self.vocab_size,
                                                   k=self.topk_num,
                                                   p=self.topp_prob,
                                                   temperature=self.temperature,
                                                   min_tokens_to_keep=self.min_tokens_to_keep)

        if self.tokenizer is not None:
            self.eos_id = self.tokenizer.eos_token_id
            self.eos_text = self.tokenizer.eos_token
        else:
            self.eos_id = model_config.vocab_size - 1
            self.eos_text = "<|endoftext|>"

        if self.demo_mode:
            assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format(
                self.batch_size)
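The two primitives selected above take different arguments, so a generation loop has to dispatch on the device target when drawing the next token. The following is a minimal, hedged sketch of such a dispatch; the helper name _sample_from_distribution and the fixed num_sample/seed values are assumptions, not part of the original model code.

    def _sample_from_distribution(self, distribution):
        """Hypothetical helper: draw one index per batch row from `distribution`."""
        if self.device_target == "GPU":
            # Multinomial(seed=1) is called as multinomial(probabilities, num_samples).
            word_index = self.sample_function(distribution, 1)
        else:
            # RandomCategorical(int32) is called as categorical(logits, num_sample, seed).
            word_index = self.sample_function(distribution, 1, 1)
        return word_index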
Example #4
class RandomCategoricalNet(nn.Cell):
    def __init__(self, num_sample):
        super(RandomCategoricalNet, self).__init__()
        self.random_categorical = P.RandomCategorical(mstype.int64)
        self.num_sample = num_sample

    def construct(self, logits, seed=0):
        return self.random_categorical(logits, self.num_sample, seed)
Example #5
    def __init__(
        self,
        decoder,
        config=None,
        batch_size=None,
        tokenizer=None,
        generate_length=1,
        topk_num=0,
        topp_prob=1.0,
        temperature=1.0,
        min_tokens_to_keep=1,
        early_stop=False,
        demo_mode=False,
        return_ids=False,
        return_last_token_logits=False,
        append_eos=False,
    ):

        assert config is not None, 'A model config is required for sampling.'

        self.decoder = decoder
        self.config = config
        self.tokenizer = tokenizer
        self.generate_length = generate_length
        self.topk_num = topk_num
        self.topp_prob = topp_prob
        self.temperature = temperature
        self.min_tokens_to_keep = min_tokens_to_keep
        self.early_stop = early_stop
        self.demo_mode = demo_mode
        self.return_ids = return_ids
        self.return_last_token_logits = return_last_token_logits
        self.append_eos = append_eos

        self.seq_length = config.seq_length
        self.batch_size = config.batch_size if batch_size is None else batch_size
        self.vocab_size = config.vocab_size

        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reshape = P.Reshape()
        self.cumsum = P.CumSum()
        self.onehot = P.OneHot()
        self.cast = P.Cast()
        self.concat = P.Concat()
        self.sample_function = P.RandomCategorical(mstype.int32)
        self.filter_distribution = TopKTopP_Filter(
            batch_size=self.batch_size,
            vocab_size=self.vocab_size,
            k=self.topk_num,
            p=self.topp_prob,
            temperature=self.temperature,
            min_tokens_to_keep=self.min_tokens_to_keep)

        if self.tokenizer is not None:
            self.eos_id = self.tokenizer.eos_token_id
            self.eos_text = self.tokenizer.eos_token
        else:
            self.eos_id = config.vocab_size - 1
            self.eos_text = "<|endoftext|>"

        if self.demo_mode:
            assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format(
                self.batch_size)