def for_batch(self, num_groups: int, num_timesteps: int, predictors: Tensor, allow_extra_timesteps: bool = True) -> 'LinearModel':
    """
    Build the batch-specific version of this component, with measurement adjustments driven by ``predictors``.

    The predictor matrix is validated, trimmed to ``num_timesteps`` along axis 1, and then, for every
    measure and every state element, column ``i`` of the predictors (one column per state element, in
    ``self.state_elements`` order) is registered as a measurement adjustment.

    :param num_groups: Number of groups in the batch.
    :param num_timesteps: Number of timesteps in the batch.
    :param predictors: A 3D tensor whose axis 1 is time and whose axis 2 has one column per state element.
        Presumably shaped (num_groups, num_timesteps, num_predictors) -- confirm against
        ``_validate_predictor_mat``.
    :param allow_extra_timesteps: Forwarded to ``_validate_predictor_mat``. Whatever the validator permits,
        any timesteps beyond ``num_timesteps`` are sliced off before use.
    :return: The object returned by the parent ``for_batch``, with the adjustments registered on it.
    """
    batch_component = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    n_expected = len(self.state_elements)
    self._validate_predictor_mat(
        num_groups=num_groups,
        num_timesteps=num_timesteps,
        predictor_mat=predictors,
        expected_num_predictors=n_expected,
        allow_extra_timesteps=allow_extra_timesteps,
    )

    # drop trailing timesteps that this batch does not cover
    has_extra = predictors.shape[1] > num_timesteps
    if has_extra:
        predictors = predictors[:, :num_timesteps, :]

    for measure_name in self.measures:
        for col_idx, element in enumerate(self.state_elements):
            # split_flat presumably yields one slice per timestep along dim=1 -- TODO confirm
            per_step = split_flat(predictors[:, :, col_idx], dim=1)
            batch_component._adjust_measure(measure=measure_name, state_element=element, adjustment=per_step)

    return batch_component
def for_batch(self, num_groups: int, num_timesteps: int, start_datetimes: Optional[np.ndarray] = None):
    """
    Build the batch-specific version of this Fourier-season component, adjusting measurements over time.

    Each state element is named ``"<r>,<c>"``. For every measure and state element, the per-timestep
    values ``fourier_tens[:, :, r, c]`` are registered as a measurement adjustment.

    :param num_groups: Number of groups in the batch.
    :param num_timesteps: Number of timesteps in the batch.
    :param start_datetimes: Optional per-group start datetimes, passed to
        ``self._dt_helper.make_delta_grid``. If omitted, all groups are treated as starting at zero.
    :return: The object returned by the parent ``for_batch``, with the adjustments registered on it.
    :raises TypeError: If ``start_datetimes`` is omitted while ``self._dt_helper.start_datetime`` is truthy.
    """
    batch_component = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    if start_datetimes is None:
        # Defaulting every group to offset zero is only allowed when the helper has no start datetime.
        if self._dt_helper.start_datetime:
            raise TypeError("Missing argument `start_datetimes`.")
        start_datetimes = np.zeros(num_groups)

    # integer time offsets (accounting for groups with different start datetimes), folded into the season
    time_offsets = self._dt_helper.make_delta_grid(start_datetimes, num_timesteps)
    position_in_season = time_offsets % self.seasonal_period

    fourier_tens = fourier_tensor(
        time=Tensor(position_in_season),
        seasonal_period=self.seasonal_period,
        K=self.K,
    )

    for measure_name in self.measures:
        for element in self.state_elements:
            # element names encode the (r, c) index into the last two axes of fourier_tens
            row_idx, col_idx = map(int, element.split(sep=","))
            per_step = split_flat(fourier_tens[:, :, row_idx, col_idx], dim=1)
            batch_component._adjust_measure(measure=measure_name, state_element=element, adjustment=per_step)

    return batch_component
def for_batch(self, num_groups: int, num_timesteps: int, start_datetimes: Optional[np.ndarray] = None):
    """
    Build the batch-specific version of this component, feeding Fourier terms into the ``'position'`` state.

    Every state element other than ``'position'`` is named ``"<r>,<c>"``. For each one, a per-timestep
    transition adjustment from that element into ``'position'`` is registered, with values taken from
    ``fourier_tens[:, :, r, c]``.

    :param num_groups: Number of groups in the batch.
    :param num_timesteps: Number of timesteps in the batch.
    :param start_datetimes: Optional per-group start datetimes, passed to
        ``self._dt_helper.make_delta_grid``. If omitted, all groups are treated as starting at zero.
        No check is made here that omitting it is appropriate.
    :return: The object returned by the parent ``for_batch``, with the adjustments registered on it.
    """
    batch_component = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    # integer time offsets (accounting for groups with different start datetimes)
    starts = np.zeros(num_groups) if start_datetimes is None else start_datetimes
    time_offsets = self._dt_helper.make_delta_grid(starts, num_timesteps)

    fourier_tens = fourier_tensor(
        time=Tensor(time_offsets % self.seasonal_period),
        seasonal_period=self.seasonal_period,
        K=self.K,
    )

    fourier_elements = (el for el in self.state_elements if el != 'position')
    for element in fourier_elements:
        row_idx, col_idx = map(int, element.split(sep=","))
        per_step = split_flat(fourier_tens[:, :, row_idx, col_idx], dim=1)
        batch_component._adjust_transition(from_element=element, to_element='position', adjustment=per_step)

    return batch_component
def for_batch(self, num_groups: int, num_timesteps: int, start_datetimes: Optional[np.ndarray] = None):
    """
    Build the batch-specific version of this discrete-season component, with time-varying transitions.

    A timestep is "in transition" when ``delta % season_duration == season_duration - 1``. At those
    timesteps each seasonal state is moved into the next one; otherwise each state keeps its value.
    Transitions into ``self.measured_name`` are arranged so the measured state equals ``-sum(all others)``,
    using one less degree of freedom than the number of seasons.

    :param num_groups: Number of groups in the batch.
    :param num_timesteps: Number of timesteps in the batch.
    :param start_datetimes: Optional 1D array with one start datetime per group, passed to
        ``self._dt_helper.make_delta_grid``. If omitted, all groups are treated as starting at zero, which is
        only allowed when ``self._dt_helper.dt_unit`` is falsy.
    :return: The object returned by the parent ``for_batch``, with transition adjustments registered.
    :raises ValueError: If ``start_datetimes`` is given but is not 1D with ``num_groups`` entries.
    :raises TypeError: If ``start_datetimes`` is omitted while ``self._dt_helper.dt_unit`` is truthy.
    """
    if start_datetimes is not None:
        # Check dimensionality before length: `len()` on a 0-d array raises an uninformative TypeError,
        # and `np.ndim` (unlike `.shape`) also accepts plain sequences.
        if np.ndim(start_datetimes) != 1 or len(start_datetimes) != num_groups:
            raise ValueError(f"Expected `start_datetimes` to be 1D array of length {num_groups}.")

    for_batch = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    if start_datetimes is None:
        if self._dt_helper.dt_unit:
            raise TypeError("Missing argument `start_datetimes`.")
        start_datetimes = np.zeros(num_groups)
    # integer time offsets, accounting for groups having different start datetimes
    delta = self._dt_helper.make_delta_grid(start_datetimes, num_timesteps)

    # flags the timesteps at which the state shifts to the next season
    in_transition = (delta % self.season_duration) == (self.season_duration - 1)

    transitions = {
        'to_next_state': torch.from_numpy(in_transition.astype('float32')),
        'from_measured_to_measured': torch.from_numpy(np.where(in_transition, -1., 1.).astype('float32'))
    }
    transitions['to_self'] = 1 - transitions['to_next_state']
    transitions['to_measured'] = -transitions['to_next_state']

    # Split each transition along the time axis (dim=1), presumably into per-timestep pieces -- confirm
    # split_flat semantics. Then optionally scale by the decay factor.
    for k in transitions:
        transitions[k] = split_flat(transitions[k], dim=1, clone=True)
        if self.decay is not None:
            decay_value = self.decay.get_value()
            transitions[k] = [x * decay_value for x in transitions[k]]

    # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of freedom
    # than the number of seasons, by having the 'measured' state be equal to -sum(all others)
    for i in range(1, len(self.state_elements)):
        current = self.state_elements[i]
        prev = self.state_elements[i - 1]

        if prev == self.measured_name:
            # measured requires special-case
            to_measured = transitions['from_measured_to_measured']
        else:
            to_measured = transitions['to_measured']

        for_batch._adjust_transition(from_element=prev, to_element=current,
                                     adjustment=transitions['to_next_state'])
        for_batch._adjust_transition(from_element=prev, to_element=self.measured_name,
                                     adjustment=to_measured)

        # from state to itself:
        for_batch._adjust_transition(from_element=current, to_element=current,
                                     adjustment=transitions['to_self'])

    return for_batch