def forward(self, inputs, ret_logvar=False):
    """Normalize ``inputs``, apply four affine layers, and split the output.

    Each of the first three layers is followed by ``swish``; the fourth
    (output) layer is left linear. The last axis of the output is split at
    ``self.out_features // 2`` into a mean part and a log-variance part, and
    the log-variance part is softly bounded by ``self.max_logvar`` and
    ``self.min_logvar``.

    Args:
        inputs: raw input tensor. The output is indexed as ``[:, :, k]``, so
            it must be at least 3-D -- presumably (num_nets, batch,
            input_dim); TODO confirm with callers.
        ret_logvar: if true, return the bounded log-variance itself;
            otherwise return ``torch.exp`` of it.

    Returns:
        ``(mean, logvar)`` when ``ret_logvar`` is true, otherwise
        ``(mean, torch.exp(logvar))``.
    """
    # Standardize using the stored input statistics.
    hidden = (inputs - self.inputs_mu) / self.inputs_sigma

    layers = (
        (self.lin0_w, self.lin0_b),
        (self.lin1_w, self.lin1_b),
        (self.lin2_w, self.lin2_b),
        (self.lin3_w, self.lin3_b),
    )
    final_idx = len(layers) - 1
    for idx, (weight, bias) in enumerate(layers):
        hidden = hidden.matmul(weight) + bias
        if idx != final_idx:  # the output layer gets no activation
            hidden = swish(hidden)

    split = self.out_features // 2
    mean = hidden[:, :, :split]
    raw_logvar = hidden[:, :, split:]

    # Smooth bounding: ``max - softplus(max - x)`` is always below
    # max_logvar, then ``min + softplus(y - min)`` is always above
    # min_logvar. softplus has a nonzero slope everywhere, so unlike a hard
    # clamp the gradient never vanishes at the bounds.
    capped = self.max_logvar - F.softplus(self.max_logvar - raw_logvar)
    logvar = self.min_logvar + F.softplus(capped - self.min_logvar)

    return (mean, logvar) if ret_logvar else (mean, torch.exp(logvar))
def forward(self, inputs, ret_logvar=False):
    """Run the layer stack and split its output into three heads.

    ``self.linear_layers`` is read as alternating ``[w0, b0, w1, b1, ...]``;
    each pair is applied as ``x.matmul(w) + b``. ``swish`` follows every pair
    whose position is below ``self.num_layers - 1``.

    The last axis of the final output is partitioned as:
      * ``[:out_features // 2]``    -> mean
      * ``[out_features // 2:-1]``  -> log-variance (softly bounded)
      * ``[-1:]``                   -> catastrophe prediction, returned raw
        (no activation is applied here)

    NOTE(review): mean and log-variance only have equal length when the final
    layer width is ``2 * (out_features // 2) + 1`` -- confirm against the
    constructor.

    Args:
        inputs: tensor laid out as NUM_NETS x BATCH_SIZE x INPUT_LENGTH; the
            ``[:, :, k]`` indexing below requires at least 3 dimensions.
        ret_logvar: if true, return the bounded log-variance itself;
            otherwise return ``torch.exp`` of it.

    Returns:
        ``(mean, logvar, catastrophe_pred)`` when ``ret_logvar`` is true,
        otherwise ``(mean, torch.exp(logvar), catastrophe_pred)``.
    """
    # Standardize using the stored input statistics.
    act = (inputs - self.inputs_mu) / self.inputs_sigma

    weights = self.linear_layers[::2]
    biases = self.linear_layers[1::2]
    for depth, (w, b) in enumerate(zip(weights, biases)):
        act = act.matmul(w) + b
        if depth + 1 < self.num_layers:  # no activation on the last layer
            act = swish(act)

    half = self.out_features // 2
    mean = act[:, :, :half]
    raw_logvar = act[:, :, half:-1]

    # Smooth bounding: the first step keeps values below max_logvar, the
    # second keeps them above min_logvar, without zeroing gradients.
    capped = self.max_logvar - F.softplus(self.max_logvar - raw_logvar)
    logvar = self.min_logvar + F.softplus(capped - self.min_logvar)

    # The final column is the catastrophe head.
    catastrophe_pred = act[..., -1:]

    second = logvar if ret_logvar else torch.exp(logvar)
    return mean, second, catastrophe_pred