def predict(
    encoded_text: torch.Tensor,
    model: nn.Module,
    k: int = 1,
    device: "torch.device | str" = "cpu",
) -> "tuple[torch.Tensor, object]":
    """Sample one prediction per batch row from the model's final-step logits.

    The model is switched to eval mode and called on ``encoded_text`` moved to
    ``device`` (the model itself is not moved here). Gradient tracking is not
    disabled; wrap the call in ``torch.no_grad()`` if it is not needed.

    Args:
        encoded_text: Input tensor; its first dimension is used as the batch
            size of the returned prediction.
        model: Module whose output is indexable, with the logits at index 0
            and a second value at index 1 that is passed through unchanged.
        k: Total count of the multinomial draw. The category drawn most often
            in ``k`` trials is returned for each row.
        device: Device the input is moved to before the forward pass.

    Returns:
        A ``(prediction, out[1])`` pair, where ``prediction`` has shape
        ``(encoded_text.shape[0],)`` and holds the sampled category indices.
    """
    model.eval()
    out = model(encoded_text.to(device))
    # Keep only the last entry along dim 1 -- presumably the final sequence
    # position of (batch, seq, vocab) logits, i.e. the next-token
    # distribution. TODO confirm against the model's output layout.
    logits = out[0][:, -1]
    counts = Multinomial(k, logits=logits).sample()
    # Take the argmax per row over the category dimension. A bare argmax()
    # flattens the batch into a single scalar index, which cannot be reshaped
    # to (batch,) when the batch holds more than one row.
    prediction = counts.argmax(dim=-1).reshape((encoded_text.shape[0],))
    return prediction, out[1]
def update_environment(self, block, trial, responses):
    """Sample outcomes for the chosen arms and advance the environment state.

    Looks up the offers of the given block and trial, resolves the arm type
    selected by each response, draws one multinomial sample per subject from
    the matching reward probabilities, and passes the draw to
    ``self.update_states`` for ``trial + 1``.

    Args:
        block: Index of the current block.
        trial: Index of the current trial within the block.
        responses: Selected arm locations; used together with the offers to
            index ``self.arm_types`` (presumably one entry per subject --
            verify against the caller).

    Returns:
        A list ``[responses, (stimuli, new_states)]`` where ``stimuli`` is a
        dict with keys ``'locations'`` (the responses) and ``'features'``
        (argmax of the sampled outcomes over the last dimension), and
        ``new_states`` is whatever ``self.update_states`` returns.
    """
    # Offers presented in this block/trial.
    trial_offers = self.offers[block][trial]

    # Arm type picked out by each (offer, response) pair.
    chosen_types = self.arm_types[trial_offers, responses]

    # Reward probabilities of the chosen arm type, one row per subject.
    subjects = range(self.nsub)
    reward_probs = self.states['probs'][block, trial, subjects, chosen_types]

    # Draw the outcomes before updating the states so the update sees them.
    outcomes = Multinomial(probs=reward_probs).sample()
    stimuli = dict(locations=responses, features=outcomes.argmax(-1))

    new_states = self.update_states(
        block,
        trial + 1,
        responses=responses,
        outcomes=outcomes,
    )
    return [responses, (stimuli, new_states)]