Пример #1
0
def predict(encoded_text: torch.Tensor,
            model: nn.Module,
            k: int = 1,
            device: torch.device = "cpu") -> tuple:
    """Sample a next-token prediction for every sequence in a batch.

    Switches ``model`` to eval mode (it is left in eval mode afterwards), runs
    it on ``encoded_text`` without recording autograd history, and draws ``k``
    multinomial samples per row from the logits at the last position along
    dim 1. The class drawn most often in each row is the prediction.

    Args:
        encoded_text: Batch of encoded inputs; dim 0 is the batch dimension.
            Presumably integer token ids of shape (batch, seq) — confirm
            against the model.
        model: Module whose output is indexable: ``out[0]`` must be logits
            with at least two dims, and ``out[1]`` is handed back unchanged.
        k: Number of multinomial draws per row (``total_count``).
        device: Device that ``encoded_text`` is moved to before the forward
            pass.

    Returns:
        A ``(prediction, out[1])`` tuple. ``prediction`` has shape
        ``(batch,)`` and holds one class index per row.
    """
    model.eval()

    # Inference only: not building an autograd graph saves memory and keeps
    # out[1] (possibly a recurrent state the caller feeds back — assumption)
    # from chaining graphs across successive calls.
    with torch.no_grad():
        out = model(encoded_text.to(device))

    # Keep only the last position along dim 1. Assuming the model emits one
    # logit vector per input position, this is the next-token distribution.
    logits = out[0][:, -1]

    # k draws per row; `sample` holds per-class counts shaped like `logits`.
    sample = Multinomial(k, logits=logits).sample()

    # Reduce over the class dimension only. A flat argmax() collapses the
    # whole batch to a single scalar, and the reshape fails for batch > 1.
    prediction = sample.argmax(dim=-1).reshape((encoded_text.shape[0], ))

    return prediction, out[1]
Пример #2
0
    def update_environment(self, block, trial, responses):
        """Sample outcomes for the chosen arms, then advance the state.

        Looks up the arm type selected by each response among the current
        trial's offers and draws one multinomial sample per row of the
        matching reward probabilities. It then calls ``update_states`` for
        ``trial + 1`` with those samples.

        Returns:
            ``[responses, (stimuli, next_states)]``. ``stimuli`` maps
            ``'locations'`` to ``responses`` and ``'features'`` to the argmax
            of each sampled outcome along the last dim. ``next_states`` is
            whatever ``update_states`` returns.
        """
        trial_offers = self.offers[block][trial]
        chosen_types = self.arm_types[trial_offers, responses]

        # One probability row per subject, picked by that subject's arm type.
        prob_table = self.states['probs']
        reward_probs = prob_table[block, trial, range(self.nsub), chosen_types]
        sampled = Multinomial(probs=reward_probs).sample()

        # Built before update_states runs, matching the original ordering.
        stimuli = {'locations': responses, 'features': sampled.argmax(-1)}

        next_states = self.update_states(
            block,
            trial + 1,
            responses=responses,
            outcomes=sampled,
        )
        return [responses, (stimuli, next_states)]