from typing import List

import numpy as np
import torch

# NOTE: `TorchModelV2`, `TensorType`, and the `MIN_LOG_NN_OUTPUT` /
# `MAX_LOG_NN_OUTPUT` log-std clamp bounds are assumed to come from the
# surrounding package (in RLlib: `ray.rllib.models.torch.torch_modelv2`,
# `ray.rllib.utils.typing`, and module-level constants, respectively).


def __init__(self, inputs: List[TensorType], model: TorchModelV2):
    # If inputs are not a torch Tensor, make them one and make sure they
    # are on the correct device.
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.from_numpy(inputs)
        if isinstance(model, TorchModelV2):
            inputs = inputs.to(next(model.parameters()).device)
    super().__init__(inputs, model)
    # Store the last sample here.
    self.last_sample = None
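# Usage sketch (an illustration, not part of the original file). The class
# name `TorchCategorical` and the `model` object below are assumptions: any
# concrete subclass of this wrapper and any TorchModelV2 instance would do.
# Passing a numpy array works because this constructor converts it to a
# torch.Tensor and moves it onto the device of `model`'s parameters:
#
#     logits = np.random.randn(32, 4).astype(np.float32)  # raw NN outputs
#     dist = TorchCategorical(logits, model)
#     assert isinstance(dist.inputs, torch.Tensor)
#     assert dist.last_sample is None  # populated by the first sample()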
def __init__(
    self,
    inputs: List[TensorType],
    model: TorchModelV2,
    low: float = -1.0,
    high: float = 1.0,
):
    """Parameterizes the distribution via `inputs`.

    Args:
        low (float): The lowest possible sampling value
            (excluding this value).
        high (float): The highest possible sampling value
            (excluding this value).
    """
    super().__init__(inputs, model)
    assert low < high
    # Make sure high and low are torch tensors.
    self.low = torch.from_numpy(np.array(low))
    self.high = torch.from_numpy(np.array(high))
    # Place on correct device.
    if isinstance(model, TorchModelV2):
        device = next(model.parameters()).device
        self.low = self.low.to(device)
        self.high = self.high.to(device)
    mean, log_std = torch.chunk(self.inputs, 2, dim=-1)
    self._num_vars = mean.shape[1]
    assert log_std.shape[1] == self._num_vars
    # Clip `std` values (coming from NN) to reasonable values.
    self.log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)
    # Clip loc too, for numerical stability reasons.
    mean = torch.clamp(mean, -3, 3)
    std = torch.exp(self.log_std)
    self.distr = torch.distributions.normal.Normal(mean, std)
    assert len(self.distr.loc.shape) == 2
    assert len(self.distr.scale.shape) == 2
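# Construction sketch (an illustration; the class name
# `TorchSquashedGaussian` and the `model` object are assumptions about the
# class this __init__ belongs to). The NN output must have 2 * action_dim
# entries on the last axis, means concatenated with log-stds (per the
# torch.chunk call above); samples are later squashed into the open
# interval (low, high):
#
#     action_dim = 3
#     nn_out = np.zeros((32, 2 * action_dim), dtype=np.float32)
#     dist = TorchSquashedGaussian(nn_out, model, low=-2.0, high=2.0)
#     assert dist._num_vars == action_dim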