import torch
import torch.nn as nn

import adf


class SoftmaxHeteroscedasticLoss(nn.Module):
    def __init__(self):
        super(SoftmaxHeteroscedasticLoss, self).__init__()
        min_variance = 1e-3
        # Clamp predicted variances to a small floor for numerical stability.
        keep_variance_fn = lambda x: keep_variance(x, min_variance=min_variance)
        self.adf_softmax = adf.Softmax(dim=1, keep_variance_fn=keep_variance_fn)
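
# The lambda above assumes a module-level `keep_variance` helper that is not
# shown in this excerpt. The nested definition inside compute_preds below
# suggests its shape; a minimal sketch under that assumption:
def keep_variance(x, min_variance):
    # Offset the variance tensor by a small floor so it never reaches zero.
    return x + min_variance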

def compute_preds(net, inputs, use_adf=False, use_mcdo=False):
    # `args` is assumed to be an in-scope namespace carrying
    # `min_variance` and `num_samples`.
    model_variance = None
    data_variance = None

    def keep_variance(x, min_variance):
        return x + min_variance

    keep_variance_fn = lambda x: keep_variance(x, min_variance=args.min_variance)
    softmax = nn.Softmax(dim=1)
    adf_softmax = adf.Softmax(dim=1, keep_variance_fn=keep_variance_fn)

    net.eval()
    if use_mcdo:
        # Re-enable dropout so repeated forward passes sample different masks.
        net = set_training_mode_for_dropout(net, True)
        outputs = [net(inputs) for _ in range(args.num_samples)]

        if use_adf:
            # ADF layers return (mean, variance) pairs; average the aleatoric
            # (data) variance over the MC-dropout samples.
            outputs = [adf_softmax(*outs) for outs in outputs]
            outputs_mean = [mean for (mean, var) in outputs]
            data_variance = [var for (mean, var) in outputs]
            data_variance = torch.stack(data_variance)
            data_variance = torch.mean(data_variance, dim=0)
        else:
            outputs_mean = [softmax(outs) for outs in outputs]

        outputs_mean = torch.stack(outputs_mean)
        # Epistemic uncertainty: variance of the softmax means across samples.
        model_variance = torch.var(outputs_mean, dim=0)

        # Compute MCDO prediction
        outputs_mean = torch.mean(outputs_mean, dim=0)
    else:
        outputs = net(inputs)
        if use_adf:  # was `if adf:`, which tested the module object and was always true
            outputs_mean, data_variance = adf_softmax(*outputs)
        else:
            outputs_mean = outputs

    net = set_training_mode_for_dropout(net, False)
    return outputs_mean, data_variance, model_variance
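
# `set_training_mode_for_dropout` is called above but not defined in this
# excerpt. A plausible sketch, assuming dropout modules are toggled to train
# mode (so they keep sampling) while the rest of the net stays in eval:
def set_training_mode_for_dropout(net, training=True):
    """Set dropout layers to the given mode so they sample at test time."""
    for m in net.modules():
        # Matches nn.Dropout, nn.Dropout2d, nn.Dropout3d, etc.
        if m.__class__.__name__.startswith('Dropout'):
            m.train(training)
    return net

# Hypothetical call site (names `net`, `images`, and `args` are assumptions):
#   probs, data_var, model_var = compute_preds(net, images,
#                                              use_adf=True, use_mcdo=True)
#   probs     - per-class softmax means, averaged over MC-dropout samples
#   data_var  - mean ADF (aleatoric) variance across samples
#   model_var - variance of the softmax means across samples (epistemic)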