Beispiel #1
0
 def _apply_link(self, parents_values):
     """Run ``self.link`` on the parents' values and restore the sample/datapoint axes.

     Entries of ``parents_values`` that are not discrete, or that contain
     tensors, are selected via ``split_dict``. They are passed through
     ``broadcast_and_reshape_parent_value`` (recursively, via ``map_iterable``)
     and merged into the remaining entries before ``self.link`` is called.
     NOTE(review): assumes ``split_dict`` returns (matching, non-matching) in
     that order — confirm against its definition.

     Every tensor in the link's output is viewed as
     ``(number_samples, number_datapoints) + tensor.shape[1:]``. Values that
     merely contain tensors are mapped element-wise with the same view. All
     other values are passed through unchanged.

     Args:
         parents_values: dict of parent values.

     Returns:
         A new dict with the same keys as the link's output.
     """
     n_samples, n_datapoints = get_number_samples_and_datapoints(parents_values)

     def needs_reshaping(key, val):
         return not is_discrete(val) or contains_tensors(val)

     to_reshape, link_input = split_dict(parents_values, condition=needs_reshaping)
     if to_reshape:
         def flatten(parent_value):
             return broadcast_and_reshape_parent_value(parent_value, n_samples, n_datapoints)

         link_input.update(map_iterable(flatten, to_reshape, recursive=True))

     link_output = self.link(link_input)

     def restore_shape(tensor):
         # Split the leading axis into (samples, datapoints), keeping trailing dims.
         return tensor.view(size=(n_samples, n_datapoints) + tensor.shape[1:])

     output = {}
     for key, val in link_output.items():
         if is_tensor(val):
             output[key] = restore_shape(val)
         elif contains_tensors(val):
             output[key] = map_iterable(restore_shape, val)
         else:
             output[key] = val
     return output
 def _preprocess_parameters_for_sampling(self, **parameters):
     """Broadcast the sampling parameters and compute the output sample shape.

     Each parameter is passed through ``broadcast_and_reshape_parent_value``.
     The result is then processed by ``self._preproces_vector_input`` using
     ``self.vector_parameters``.

     Returns:
         Tuple ``(reshaped_parameters, shape)``. ``shape`` is
         ``(number_samples, number_datapoints)`` followed by the entries of the
         ``tensor_shape`` returned by ``_preproces_vector_input``. That value
         must be a list, since it is concatenated to a list below.
     """
     n_samples, n_datapoints = get_number_samples_and_datapoints(parameters)
     broadcast_parameters = map_iterable(
         lambda value: broadcast_and_reshape_parent_value(value, n_samples, n_datapoints),
         parameters)
     reshaped_parameters, tensor_shape = self._preproces_vector_input(
         broadcast_parameters, self.vector_parameters)
     leading_dims = [n_samples, n_datapoints]
     return reshaped_parameters, tuple(leading_dims + tensor_shape)
Beispiel #3
0
 def _preprocess_parameters_for_log_prob(self, x, **parameters):
     """Broadcast the observed data together with the parameters for log-prob evaluation.

     ``x`` is inserted under the key ``"x_data"``. It is then broadcast and
     vector-preprocessed alongside the parameters, and popped back out of the
     result.

     Args:
         x: value whose log-probability is being evaluated.
         **parameters: distribution parameters.

     Returns:
         Tuple ``(x, reshaped_parameters, number_samples, number_datapoints)``.
         ``x`` is the preprocessed data and ``reshaped_parameters`` no longer
         contains the ``"x_data"`` key.
     """
     # Build a fresh dict instead of updating ``parameters`` in place.
     parameters_and_data = {**parameters, "x_data": x}
     number_samples, number_datapoints = get_number_samples_and_datapoints(parameters_and_data)
     parameters_and_data = map_iterable(
         lambda y: broadcast_and_reshape_parent_value(y, number_samples, number_datapoints),
         parameters_and_data)
     # Copy before adding "x_data". Calling ``.add`` on ``self.vector_parameters``
     # directly would permanently alter the instance attribute on every call.
     vector_names = set(self.vector_parameters)
     vector_names.add("x_data")
     reshaped_parameters_and_data, _ = self._preproces_vector_input(parameters_and_data, vector_names)
     x = reshaped_parameters_and_data.pop("x_data")
     return x, reshaped_parameters_and_data, number_samples, number_datapoints
def reformat_value(value, index):
    """Extract the entry at ``index`` along the first axis of ``value`` as NumPy data.

    - Tensor whose ``index``-th slice has exactly one element: returned as a
      Python ``float``.
    - Tensor whose second dimension has size 1: the ``index``-th slice with
      that size-1 axis dropped, as a NumPy array.
    - Any other tensor: the ``index``-th slice as a NumPy array.
    - Other iterables: ``reformat_value`` is applied through ``map_iterable``
      with the same ``index``.
    - Anything else: returned unchanged.

    Args:
        value: tensor (indexed as ``value[index, :]``, so it must be at least
            2-D), iterable, or any other object.
        index: position along the first axis to extract.
    """
    if is_tensor(value):
        # Convert the selected slice once and reuse it in every branch.
        sliced = value[index, :].cpu().detach().numpy()
        if sliced.size == 1:
            # ``.item()`` avoids NumPy's deprecated float() conversion of
            # arrays with ndim > 0 (deprecated since NumPy 1.25).
            return float(sliced.item())
        elif value.shape[1] == 1:
            return sliced[0, :]
        else:
            # Select by ``index`` like the other branches, not the whole batch.
            return sliced
    elif isinstance(value, Iterable):
        return map_iterable(lambda x: reformat_value(x, index), value)
    else:
        return value