Пример #1
0
 def comparison_ops(self):
     """Evaluate torch's comparison and ordering ops on two random 4-vectors.

     Returns:
         tuple: 30 op results, in the order the ops are listed below.
     """
     lhs, rhs = torch.randn(4), torch.randn(4)
     results = [
         torch.allclose(lhs, rhs),
         torch.argsort(lhs),
         torch.eq(lhs, rhs),
         torch.equal(lhs, rhs),
         torch.ge(lhs, rhs),
         torch.greater_equal(lhs, rhs),
         torch.gt(lhs, rhs),
         torch.greater(lhs, rhs),
         torch.isclose(lhs, rhs),
         torch.isfinite(lhs),
         torch.isin(lhs, rhs),
         torch.isinf(lhs),
         torch.isposinf(lhs),
         torch.isneginf(lhs),
         torch.isnan(lhs),
         torch.isreal(lhs),
         torch.kthvalue(lhs, 1),
         torch.le(lhs, rhs),
         torch.less_equal(lhs, rhs),
         torch.lt(lhs, rhs),
         torch.less(lhs, rhs),
         torch.maximum(lhs, rhs),
         torch.minimum(lhs, rhs),
         torch.fmax(lhs, rhs),
         torch.fmin(lhs, rhs),
         torch.ne(lhs, rhs),
         torch.not_equal(lhs, rhs),
         torch.sort(lhs),
         torch.topk(lhs, 1),
         torch.msort(lhs),
     ]
     return tuple(results)
    def test_forced_bos_token_logits_processor(self):
        """Check that ForcedBOSTokenLogitsProcessor forces ``bos_token_id`` only at length 1.

        With a single token of input, every score except the BOS one must be
        ``-inf`` and the BOS score must be 0; with a longer input the
        processor must leave all scores finite.
        """
        vocab_size = 20
        batch_size = 4
        bos_token_id = 0

        logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
        # Every token id except the forced one. Indexing with this list checks
        # ids on both sides of bos_token_id, not only the ones above it.
        other_token_ids = [i for i in range(vocab_size) if i != bos_token_id]

        # check that all scores are -inf except the bos_token_id score
        input_ids = ids_tensor((batch_size, 1), vocab_size=vocab_size)
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores)
        self.assertTrue(torch.isneginf(scores[:, other_token_ids]).all())
        self.assertListEqual(scores[:, bos_token_id].tolist(), batch_size * [0])  # score for bos_token_id should be zero

        # check that bos_token_id is not forced if current length is greater than 1
        input_ids = ids_tensor((batch_size, 4), vocab_size=vocab_size)
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores)
        self.assertFalse(torch.isinf(scores).any())
Пример #3
0
 def forward(self):
     """Evaluate each comparison/ordering op once on two scalar integer tensors.

     Only the number of op results is returned; the result values themselves
     are discarded.

     Returns:
         int: the number of op results produced (37).
     """
     a = torch.tensor(0)
     b = torch.tensor(1)
     # len() accepts exactly one argument, so the results are packed into a
     # single tuple. Passing them as 37 separate arguments always raised
     # "TypeError: len() takes exactly one argument".
     return len((
         torch.allclose(a, b),
         torch.argsort(a),
         torch.eq(a, b),
         torch.eq(a, 1),
         torch.equal(a, b),
         torch.ge(a, b),
         torch.ge(a, 1),
         torch.greater_equal(a, b),
         torch.greater_equal(a, 1),
         torch.gt(a, b),
         torch.gt(a, 1),
         torch.greater(a, b),
         torch.isclose(a, b),
         torch.isfinite(a),
         torch.isin(a, b),
         torch.isinf(a),
         torch.isposinf(a),
         torch.isneginf(a),
         torch.isnan(a),
         torch.isreal(a),
         torch.kthvalue(a, 1),
         torch.le(a, b),
         torch.le(a, 1),
         torch.less_equal(a, b),
         torch.lt(a, b),
         torch.lt(a, 1),
         torch.less(a, b),
         torch.maximum(a, b),
         torch.minimum(a, b),
         torch.fmax(a, b),
         torch.fmin(a, b),
         torch.ne(a, b),
         torch.ne(a, 1),
         torch.not_equal(a, b),
         torch.sort(a),
         torch.topk(a, 1),
         torch.msort(a),
     ))
    def test_forced_eos_token_logits_processor(self):
        """Check that ForcedEOSTokenLogitsProcessor forces ``eos_token_id`` just before ``max_length``.

        At input length ``max_length - 1`` every score except the EOS one must
        be ``-inf`` and the EOS score must be 0; one step earlier the
        processor must leave all scores finite.
        """
        vocab_size = 20
        batch_size = 4
        eos_token_id = 0
        max_length = 5

        logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
        # Every token id except the forced one. Indexing with this list checks
        # ids on both sides of eos_token_id, not only the ones above it.
        other_token_ids = [i for i in range(vocab_size) if i != eos_token_id]

        # check that all scores are -inf except the eos_token_id when max_length is reached
        # (input is one token short of max_length, so the next token must be EOS)
        input_ids = ids_tensor((batch_size, max_length - 1), vocab_size=vocab_size)
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores)
        self.assertTrue(torch.isneginf(scores[:, other_token_ids]).all())
        self.assertListEqual(scores[:, eos_token_id].tolist(), batch_size * [0])  # score for eos_token_id should be zero

        # check that eos_token_id is not forced if max_length is not reached
        input_ids = ids_tensor((batch_size, max_length - 2), vocab_size=vocab_size)
        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores = logits_processor(input_ids, scores)
        self.assertFalse(torch.isinf(scores).any())
Пример #5
0
    def forward(self):
        """Execute one run of the remote model over the ppx protocol.

        Sends a ``Run`` message, then serves the ``Sample``, ``Observe`` and
        ``Tag`` requests coming back from the remote side. Each request is
        answered with the matching ``*Result`` message. The loop ends when a
        ``RunResult`` arrives.

        Returns:
            The ``RunResult`` payload, converted with
            ``self._protocol_tensor_to_variable``.

        Raises:
            RuntimeError: if a ``Sample``/``Observe`` carries an unsupported
                distribution type, if a ``Reset`` arrives (protocol out of
                sync), or on any other unexpected message.
            ZeroLikelihoodException: if ``self._kill_on_zero_likelihood`` is
                set and an observed value has a ``-inf`` log-probability.
        """
        body_types = ppx_MessageBody.MessageBody()

        # construct and send the Run request
        builder = flatbuffers.Builder(64)
        ppx_Run.RunStart(builder)
        run_body = ppx_Run.RunEnd(builder)
        self._send_ppx_message(builder, body_types.Run, run_body)

        while True:
            reply = self._requester.receive_reply()
            message_body = self._get_message_body(reply)

            if isinstance(message_body, ppx_RunResult.RunResult):
                return self._protocol_tensor_to_variable(message_body.Result())

            elif isinstance(message_body, ppx_Sample.Sample):
                address = message_body.Address().decode('utf-8')
                # An empty name on the wire is mapped to None.
                name = message_body.Name().decode('utf-8') or None
                control = bool(message_body.Control())
                # NOTE(review): the Sample message also has a Replace() flag
                # that is not forwarded to state.sample -- confirm intended.
                dist = self._protocol_to_distribution(message_body)
                if dist is None:
                    raise RuntimeError(
                        'ppx (Python): Sample from an unexpected distribution requested.'
                    )
                result = state.sample(distribution=dist,
                                      control=control,
                                      name=name,
                                      address=address)
                builder = flatbuffers.Builder(64)
                result = self._variable_to_protocol_tensor(builder, result)
                ppx_SampleResult.SampleResultStart(builder)
                ppx_SampleResult.SampleResultAddResult(builder, result)
                sample_result = ppx_SampleResult.SampleResultEnd(builder)
                self._send_ppx_message(builder, body_types.SampleResult,
                                       sample_result)

            elif isinstance(message_body, ppx_Observe.Observe):
                address = message_body.Address().decode('utf-8')
                name = message_body.Name().decode('utf-8') or None
                value = self._protocol_tensor_to_variable(message_body.Value())
                distribution_type = message_body.DistributionType()
                if distribution_type == ppx_Distribution.Distribution().NONE:
                    dist = None
                else:
                    dist = self._protocol_to_distribution(message_body)
                    if dist is None:
                        raise RuntimeError(
                            'ppx (Python): Sample from an unexpected distribution requested: {}'
                            .format(distribution_type))

                result = state.observe(distribution=dist,
                                       value=value,
                                       name=name,
                                       address=address)
                # result is None if the observed value is not provided (neither
                # as an observe argument nor in the observation dictionary).
                # dist is None for distribution-less observes: there is no
                # log-probability to check, and dist.log_prob would raise.
                if (self._kill_on_zero_likelihood and dist is not None
                        and result is not None
                        and torch.any(torch.isneginf(
                            dist.log_prob(result, sum=True)))):
                    raise ZeroLikelihoodException(name)
                builder = flatbuffers.Builder(64)
                ppx_ObserveResult.ObserveResultStart(builder)
                observe_result = ppx_ObserveResult.ObserveResultEnd(builder)
                self._send_ppx_message(builder, body_types.ObserveResult,
                                       observe_result)

            elif isinstance(message_body, ppx_Tag.Tag):
                address = message_body.Address().decode('utf-8')
                name = message_body.Name().decode('utf-8') or None
                value = self._protocol_tensor_to_variable(message_body.Value())
                state.tag(value=value, name=name, address=address)
                builder = flatbuffers.Builder(64)
                ppx_TagResult.TagResultStart(builder)
                tag_result = ppx_TagResult.TagResultEnd(builder)
                self._send_ppx_message(builder, body_types.TagResult,
                                       tag_result)

            elif isinstance(message_body, ppx_Reset.Reset):
                raise RuntimeError(
                    'ppx (Python): Received a reset request. Protocol out of sync.'
                )
            else:
                raise RuntimeError(
                    'ppx (Python): Received unexpected message.')

    def _send_ppx_message(self, builder, body_type, body):
        """Wrap an already-built message body in a ppx Message and send it.

        Args:
            builder: the ``flatbuffers.Builder`` that *body* was built with.
            body_type: a ``ppx_MessageBody.MessageBody()`` member naming the
                type of *body*.
            body: the offset returned by the body table's ``...End(builder)``.
        """
        ppx_Message.MessageStart(builder)
        ppx_Message.MessageAddBodyType(builder, body_type)
        ppx_Message.MessageAddBody(builder, body)
        message = ppx_Message.MessageEnd(builder)
        builder.Finish(message)
        self._requester.send_request(builder.Output())

    def _protocol_to_distribution(self, message_body):
        """Build the distribution object described by a Sample/Observe message.

        Args:
            message_body: a decoded ``ppx_Sample.Sample`` or
                ``ppx_Observe.Observe`` message.

        Returns:
            A ``Uniform``/``Normal``/``Categorical``/``Poisson``/``Bernoulli``/
            ``Beta``/``Exponential``/``Gamma``/``LogNormal``/``Binomial``/
            ``Weibull`` instance. Returns ``None`` when the message's
            distribution type is none of these (including ``NONE``); the
            caller decides whether that is an error.
        """
        kinds = ppx_Distribution.Distribution()
        distribution_type = message_body.DistributionType()
        to_variable = self._protocol_tensor_to_variable

        def unpack(table_type):
            # Distribution() returns a generic table; re-initialise its bytes
            # as the concrete table type selected by DistributionType().
            table = message_body.Distribution()
            concrete = table_type()
            concrete.Init(table.Bytes, table.Pos)
            return concrete

        if distribution_type == kinds.Uniform:
            uniform = unpack(ppx_Uniform.Uniform)
            return Uniform(to_variable(uniform.Low()),
                           to_variable(uniform.High()))
        if distribution_type == kinds.Normal:
            normal = unpack(ppx_Normal.Normal)
            return Normal(to_variable(normal.Mean()),
                          to_variable(normal.Stddev()))
        if distribution_type == kinds.Categorical:
            categorical = unpack(ppx_Categorical.Categorical)
            return Categorical(to_variable(categorical.Probs()))
        if distribution_type == kinds.Poisson:
            poisson = unpack(ppx_Poisson.Poisson)
            return Poisson(to_variable(poisson.Rate()))
        if distribution_type == kinds.Bernoulli:
            bernoulli = unpack(ppx_Bernoulli.Bernoulli)
            return Bernoulli(to_variable(bernoulli.Probs()))
        if distribution_type == kinds.Beta:
            beta = unpack(ppx_Beta.Beta)
            return Beta(to_variable(beta.Concentration1()),
                        to_variable(beta.Concentration0()))
        if distribution_type == kinds.Exponential:
            exponential = unpack(ppx_Exponential.Exponential)
            return Exponential(to_variable(exponential.Rate()))
        if distribution_type == kinds.Gamma:
            gamma = unpack(ppx_Gamma.Gamma)
            return Gamma(to_variable(gamma.Concentration()),
                         to_variable(gamma.Rate()))
        if distribution_type == kinds.LogNormal:
            log_normal = unpack(ppx_LogNormal.LogNormal)
            return LogNormal(to_variable(log_normal.Loc()),
                             to_variable(log_normal.Scale()))
        if distribution_type == kinds.Binomial:
            binomial = unpack(ppx_Binomial.Binomial)
            return Binomial(to_variable(binomial.TotalCount()),
                            to_variable(binomial.Probs()))
        if distribution_type == kinds.Weibull:
            weibull = unpack(ppx_Weibull.Weibull)
            return Weibull(to_variable(weibull.Scale()),
                           to_variable(weibull.Concentration()))
        return None
Пример #6
0
    # print(torch.inner(c, d))
    """flip"""
    # x is the 2x2 matrix [[0, 1], [2, 3]]; flipud reverses the rows (dim 0),
    # fliplr reverses the columns (dim 1).
    x = torch.arange(4).view(2, 2)
    print(torch.flipud(x))
    print(torch.fliplr(x))

    # logical: element-wise comparisons and boolean ops between c and d.
    # NOTE(review): c and d are not defined in this excerpt; presumably they
    # are tensors created earlier in the enclosing scope -- confirm.
    print("logical function:")
    print(torch.eq(c, d))
    print(torch.ne(c, d))
    print(torch.gt(c, d))
    print(torch.logical_and(c, d))
    print(torch.logical_or(c, d))
    print(torch.logical_xor(c, d))
    print(torch.logical_not(c))
    print(torch.equal(c, d))  # single bool: True only if same size and all elements equal
    # NOTE(review): torch.rand samples from [0, 1), so .bool() is True for
    # essentially every element and the all/any demos below are uninformative;
    # a threshold such as `torch.rand(2, 2) > 0.5` would give a mixed mask.
    a = torch.rand(2, 2).bool()
    print(a)
    print(torch.all(a))
    print(torch.all(a, dim=0))  # reduce over dim 0: one result per column
    print(torch.all(a, dim=1))  # reduce over dim 1: one result per row
    print(torch.any(a))
    print(torch.any(a, dim=0))  # reduce over dim 0: one result per column
    print(torch.any(a, dim=1))  # reduce over dim 1: one result per row
    # Mix of finite values, +inf, -inf and NaN to contrast the is* predicates.
    to_test = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])
    print(torch.isfinite(to_test))
    print(torch.isinf(to_test))
    print(torch.isposinf(to_test))
    print(torch.isneginf(to_test))
    print(torch.isnan(to_test))