from typing import Optional

import numpy as np
import torch
from torch import Tensor

# NOTE: `fourier_tensor`, `split_flat`, `ProcessForBatch`, and `DesignMatAdjustment`
# are helpers/types from the surrounding library.


def for_batch(self,
              num_groups: int,
              num_timesteps: int,
              start_datetimes: Optional[np.ndarray] = None):
    for_batch = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    # determine the delta (integer time, accounting for different groups having different start datetimes):
    delta = self._get_delta(for_batch.num_groups, for_batch.num_timesteps, start_datetimes=start_datetimes)

    # determine the position within the season:
    season = delta % self.seasonal_period

    # generate the fourier tensor:
    fourier_tens = fourier_tensor(time=Tensor(season), seasonal_period=self.seasonal_period, K=self.K)

    # each non-'position' state element is named "r,c", indexing a (harmonic, sin/cos) pair;
    # its transition into 'position' is adjusted by the corresponding fourier term:
    for state_element in self.state_elements:
        if state_element == 'position':
            continue
        r, c = (int(x) for x in state_element.split(sep=","))
        for_batch._adjust_transition(from_element=state_element,
                                     to_element='position',
                                     adjustment=split_flat(fourier_tens[:, :, r, c], dim=1))

    return for_batch
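# A minimal, self-contained sketch of the computation `fourier_tensor` above is
# assumed to perform (the name `fourier_tensor_sketch` and the (..., K, 2) shape
# layout are assumptions, not the library's confirmed internals): for each
# harmonic k = 1..K, emit the sin and cos of the within-season time.
import math


def fourier_tensor_sketch(time: torch.Tensor, seasonal_period: float, K: int) -> torch.Tensor:
    # time: (num_groups, num_timesteps) -> out: (num_groups, num_timesteps, K, 2)
    out = torch.empty(time.shape + (K, 2))
    for k in range(K):
        val = 2. * math.pi * (k + 1) * time / seasonal_period
        out[..., k, 0] = torch.sin(val)  # c == 0 -> sin term
        out[..., k, 1] = torch.cos(val)  # c == 1 -> cos term
    return out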
def for_batch(self,
              num_groups: int,
              num_timesteps: int,
              start_datetimes: Optional[np.ndarray] = None) -> ProcessForBatch:
    for_batch = super().for_batch(num_groups=num_groups, num_timesteps=num_timesteps)

    delta = self.dt_tracker.get_delta(for_batch.num_groups, for_batch.num_timesteps, start_datetimes=start_datetimes)

    in_transition = (delta % self.season_duration) == (self.season_duration - 1)

    transitions = dict()
    transitions['to_next_state'] = torch.from_numpy(in_transition.astype('float32'))
    transitions['to_self'] = 1 - transitions['to_next_state']
    transitions['to_measured'] = -transitions['to_next_state']
    transitions['from_measured_to_measured'] = torch.from_numpy(np.where(in_transition, -1., 1.).astype('float32'))

    for k in transitions.keys():
        transitions[k] = split_flat(transitions[k], dim=1, clone=True)
        if self.decay is not None:
            decay_value = self.decay.get_value()
            transitions[k] = [x * decay_value for x in transitions[k]]

    # this is convoluted, but the idea is to manipulate the transitions so that we use one less degree of
    # freedom than the number of seasons, by having the 'measured' state be equal to -sum(all others)
    for i in range(1, len(self.state_elements)):
        current = self.state_elements[i]
        prev = self.state_elements[i - 1]

        if prev == self.measured_name:  # measured requires special-case
            to_measured = transitions['from_measured_to_measured']
        else:
            to_measured = transitions['to_measured']

        for_batch.adjust_transition(from_element=prev, to_element=current, adjustment=transitions['to_next_state'])
        for_batch.adjust_transition(from_element=prev, to_element=self.measured_name, adjustment=to_measured)

        # from state to itself:
        for_batch.adjust_transition(from_element=current, to_element=current, adjustment=transitions['to_self'])

    return for_batch
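# A small runnable illustration of the `in_transition` mask above, assuming
# `delta` is an integer (num_groups x num_timesteps) array of elapsed
# timesteps: the season-to-season transition fires on the last timestep of
# each season_duration-long block.
delta_example = np.arange(12).reshape(1, 12)  # one group, twelve timesteps
season_duration_example = 4
in_transition_example = (delta_example % season_duration_example) == (season_duration_example - 1)
# -> [[False, False, False, True, False, False, False, True, False, False, False, True]]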
def for_batch(self,
              num_groups: int,
              num_timesteps: int,
              predictors: Tensor,
              allow_extra_timesteps: bool = False) -> 'LinearModel':
    for_batch = super().for_batch(num_groups,
                                  num_timesteps,
                                  expected_num_predictors=len(self.state_elements),
                                  allow_extra_timesteps=allow_extra_timesteps)

    # if extra timesteps were passed, truncate the predictors to the batch length:
    if predictors.shape[1] > num_timesteps:
        predictors = predictors[:, 0:num_timesteps, :]

    # each state element corresponds to one predictor; its loading onto each
    # measure is adjusted by that predictor's (group x time) values:
    for measure in self.measures:
        for i, cov in enumerate(self.state_elements):
            for_batch.adjust_measure(measure=measure,
                                     state_element=cov,
                                     adjustment=split_flat(predictors[:, :, i], dim=1))

    return for_batch
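# A sketch of the `split_flat` helper used throughout (an assumption about the
# library's internal utility, inferred from the error message in
# `_validate_assignment` below): split a (num_groups, num_timesteps) tensor
# along `dim` into a list of len(times) 1D tensors, each of len(groups).
def split_flat_sketch(tens: torch.Tensor, dim: int, clone: bool = False):
    if clone:
        tens = tens.clone()
    # unbind removes `dim`, yielding one tensor per timestep
    return list(tens.unbind(dim))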
def _validate_assignment(self,
                         design_mat_assignment: DesignMatAdjustment,
                         check_slow_grad: bool = True) -> DesignMatAdjustment:
    if isinstance(design_mat_assignment, Tensor):
        if list(design_mat_assignment.shape) == [self.num_groups, self.num_timesteps]:
            if design_mat_assignment.requires_grad:
                raise ValueError(
                    "Cannot use a group X time tensor as an adjustment unless it does not `require_grad`. "
                    "To make adjustments that are group- and time-specific, pass a list of len(times), each "
                    "element a 1D tensor of len(groups) (or each a scalar tensor)."
                )
            design_mat_assignment = split_flat(design_mat_assignment, dim=1, clone=True)
        else:
            self._check_tens(design_mat_assignment, in_list=False, check_slow_grad=check_slow_grad)

    if isinstance(design_mat_assignment, (list, tuple)):
        assert len(design_mat_assignment) == self.num_timesteps
        for tens in design_mat_assignment:
            self._check_tens(tens, in_list=True, check_slow_grad=check_slow_grad)
    else:
        raise ValueError("Expected `design_mat_assignment` to be a list/tuple or tensor")

    return design_mat_assignment
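# Runnable illustration of the adjustment formats the validator above accepts
# (shapes follow its checks; the num_groups/num_timesteps values are arbitrary):
num_groups_example, num_timesteps_example = 3, 5

# a (groups x times) tensor is accepted only without grad, and is converted
# into a per-timestep list via split_flat:
group_by_time = torch.randn(num_groups_example, num_timesteps_example)

# a list of len(times), each entry a 1D tensor of len(groups) (or a scalar
# tensor), is accepted directly:
per_timestep = [torch.randn(num_groups_example) for _ in range(num_timesteps_example)]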