def _default_summary_writer(self) -> SummaryWriter:
    """Build a `SummaryWriter` with a per-class, per-call log directory.

    The directory is `<get_log_root()>/<class name>/<timestamp>`, where the
    timestamp is the current local time in ISO format with every ":" replaced
    by "_".

    Returns:
        A new `SummaryWriter` logging to that directory.
    """
    # Colons are swapped out of the timestamp, presumably so that it is usable
    # as a directory name on every platform.
    stamp = datetime.now().isoformat().replace(":", "_")
    class_name = self.__class__.__name__
    target_dir = Path(get_log_root()) / class_name / stamp
    return SummaryWriter(target_dir)
def __init__(
    self,
    simulator: Callable,
    prior,
    true_observation: Tensor,
    simulation_batch_size: int = 1,
    device: Optional[torch.device] = None,
    summary_writer: Optional[SummaryWriter] = None,
    simulator_name: Optional[str] = "simulator",
):
    """Set up the inference problem, the simulation storage and logging.

    Args:
        simulator: a regular function parameter->result. Both parameters and
            result can be multi-dimensional.
        prior: distribution-like object with `log_prob` and `sample` methods.
        true_observation: tensor containing the observation x_o. If it has
            more than one dimension, the leading dimension will be
            interpreted as a batch dimension but *currently* only the first
            batch element will be used to condition on.
        simulation_batch_size: number of parameter sets that the simulator
            accepts and converts to data x at once. If -1, we simulate all
            parameter sets at the same time. If >= 1, the simulator has to
            process data of shape
            (simulation_batch_size, parameter_dimension).
        device: torch.device on which to compute. If None, the result of
            `get_default_device()` is used.
        summary_writer: an optional SummaryWriter to control, among others,
            the log file location. If None, a new SummaryWriter is created
            that logs to
            `<get_log_root()>/<class name>/<simulator_name>/<get_timestamp()>`.
        simulator_name: name of the simulator, used only as a component of
            the default log directory. If None, "simulator" is used.
    """
    self._simulator, self._prior, self._true_observation = prepare_sbi_problem(
        simulator, prior, true_observation
    )

    self._simulation_batch_size = simulation_batch_size

    self._device = get_default_device() if device is None else device

    # Initialize roundwise (parameter, observation) storage.
    self._parameter_bank, self._observation_bank = [], []

    # XXX We could instantiate here the Posterior for all children. Two problems:
    # XXX 1. We must dispatch to right PotentialProvider for mcmc based on name
    # XXX 2. `alg_family` cannot be resolved only from `self.__class__.__name__`,
    # XXX    since SRE, AALR demand different handling but are both in SRE class.

    if summary_writer is None:
        # `simulator_name` is declared Optional, but `os.path.join` raises a
        # TypeError on None, so fall back to the parameter's default value.
        simulator_dir = "simulator" if simulator_name is None else simulator_name
        log_dir = os.path.join(
            get_log_root(),
            self.__class__.__name__,
            simulator_dir,
            get_timestamp(),
        )
        summary_writer = SummaryWriter(log_dir)
    self._summary_writer = summary_writer

    # Logging during training (by SummaryWriter).
    self._summary = dict(
        mmds=[],
        median_observation_distances=[],
        negative_log_probs_true_parameters=[],
        neural_net_fit_times=[],
        # XXX unused elsewhere (original note; exact entries it refers to are
        # not clear from this block)
        epochs=[],
        best_validation_log_probs=[],
    )