def collate(cls, batch: Sequence['TimeSeriesDataset']) -> 'TimeSeriesDataset':
    """
    Combine several datasets into a single one.

    The ``tensors``, ``group_names`` and ``start_times`` of every element are collected and joined
    (tensors via ``ragged_cat(..., ragged_dim=1)``, the other two via ``np.concatenate``). The
    ``dt_unit`` and ``measures`` attributes are taken from the first element and must be equal on
    every other element.

    :param batch: A non-empty sequence of datasets to combine.
    :return: A new dataset created with ``cls(...)``.
    :raises ValueError: If ``batch`` is empty, or if an element's ``dt_unit`` or ``measures``
      differs from that of the first element.
    """
    if len(batch) == 0:
        # without this guard the failure would be an uninformative IndexError from `batch[0]`
        raise ValueError("Cannot collate an empty batch.")

    first = batch[0]
    to_concat = {
        'tensors': [first.tensors],
        'group_names': [first.group_names],
        'start_times': [first.start_times],
    }
    fixed = {'dt_unit': first.dt_unit, 'measures': first.measures}

    for i, ts_dataset in enumerate(batch[1:], 1):
        for attr, appendlist in to_concat.items():
            appendlist.append(getattr(ts_dataset, attr))
        for attr, required_val in fixed.items():
            new_val = getattr(ts_dataset, attr)
            if new_val != required_val:
                raise ValueError(f"Element {i} has `{attr}` = {new_val}, but for element 0 it's {required_val}.")

    # `zip(*...)` regroups by tensor position: the k-th tensor of every element is joined together.
    tensors = tuple(ragged_cat(t, ragged_dim=1) for t in zip(*to_concat['tensors']))

    return cls(
        *tensors,
        group_names=np.concatenate(to_concat['group_names']),
        start_times=np.concatenate(to_concat['start_times']),
        measures=fixed['measures'],
        dt_unit=fixed['dt_unit'],
    )
def with_new_start_times(self, start_times: Union[np.ndarray, Sequence]) -> 'TimeSeriesDataset':
    """
    Subset a TimeSeriesDataset so that some/all of the groups have later start times.

    :param start_times: An array/list of new datetimes, one per group, in the same order as
      ``self.group_names``.
    :return: A new TimeSeriesDataset.
    :raises ValueError: If ``start_times`` does not have one entry per group, or if a new start
      time is at-or-after all of a group's times, or before all of a group's times.
    """
    num_groups = len(self.group_names)
    if len(start_times) != num_groups:
        # `zip` below would otherwise silently drop the unmatched groups/start-times
        raise ValueError(
            f"Expected `start_times` to have one entry per group ({num_groups}), got {len(start_times)}."
        )

    new_tensors = []
    for i, tens in enumerate(self.tensors):
        times = self.times(i)
        new_tens = []
        for g, (new_time, old_times) in enumerate(zip(start_times, times)):
            if (old_times <= new_time).all():
                raise ValueError(f"{new_time} is later than all the times for group {self.group_names[g]}")
            elif (old_times > new_time).all():
                raise ValueError(f"{new_time} is earlier than all the times for group {self.group_names[g]}")
            # keep only the entries at-or-after the new start; `unsqueeze(0)` restores the group dim
            new_tens.append(tens[g, true1d_idx(old_times >= new_time), :].unsqueeze(0))
        # groups now have differing lengths along dim 1, so use the ragged concatenation
        new_tens = ragged_cat(new_tens, ragged_dim=1, cat_dim=0)
        new_tensors.append(new_tens)

    return type(self)(
        *new_tensors,
        group_names=self.group_names,
        start_times=start_times,
        measures=self.measures,
        dt_unit=self.dt_unit,
    )
def to_dataframe(self, group_colname: str = 'group', time_colname: str = 'time') -> 'DataFrame':
    """
    Convert this dataset into a dataframe by delegating to ``self.tensor_to_dataframe``.

    All of ``self.tensors`` are first joined into one tensor with
    ``ragged_cat(..., ragged_dim=1, cat_dim=2)``; that tensor is passed on together with
    ``self.times()``, ``self.group_names`` and ``self.all_measures``.

    :param group_colname: Forwarded to ``tensor_to_dataframe`` as ``group_colname``.
    :param time_colname: Forwarded to ``tensor_to_dataframe`` as ``time_colname``.
    :return: Whatever ``tensor_to_dataframe`` returns.
    """
    build_frame = self.tensor_to_dataframe
    combined = ragged_cat(self.tensors, ragged_dim=1, cat_dim=2)
    all_times = self.times()
    names = self.group_names
    measure_names = self.all_measures
    return build_frame(
        tensor=combined,
        times=all_times,
        group_names=names,
        group_colname=group_colname,
        time_colname=time_colname,
        measures=measure_names,
    )
def with_new_start_times(self, start_times: Union[np.ndarray, Sequence]) -> 'TimeSeriesDataset':
    """
    Subset a TimeSeriesDataset so that some/all of the groups have later start times.

    For each group, entries before the new start time are dropped, as are any trailing entries
    after the last row that is not entirely nan. A group whose new start time is at-or-after all of
    its times, or before all of its times, is not an error: a warning is issued and the group is
    kept as an empty slice.

    :param start_times: An array/list of new datetimes, one per group, in the same order as
      ``self.group_names``.
    :return: A new TimeSeriesDataset.
    :raises ValueError: If ``start_times`` does not have one entry per group.
    """
    num_groups = len(self.group_names)
    if len(start_times) != num_groups:
        # `zip` below would otherwise silently drop the unmatched groups/start-times
        raise ValueError(
            f"Expected `start_times` to have one entry per group ({num_groups}), got {len(start_times)}."
        )

    new_tensors = []
    for i, tens in enumerate(self.tensors):
        times = self.times(i)
        new_tens = []
        for g, (new_time, old_times) in enumerate(zip(start_times, times)):
            if (old_times <= new_time).all():
                warn(f"{new_time} is later than all the times for group {self.group_names[g]}")
                # `[g]` keeps the leading group dim; `0:0` leaves nothing along dim 1
                new_tens.append(tens[[g], 0:0])
                continue
            elif (old_times > new_time).all():
                warn(f"{new_time} is earlier than all the times for group {self.group_names[g]}")
                new_tens.append(tens[[g], 0:0])
                continue

            # drop if before new_time:
            g_tens = tens[g, true1d_idx(old_times >= new_time)]

            # drop everything after the last row that is not entirely nan:
            all_nan = torch.isnan(g_tens).all(dim=1)
            if all_nan.all():
                warn(f"Group '{self.group_names[g]}' (tensor {i}) has only `nans` after {new_time}")
                end_idx = 0
            else:
                end_idx = true1d_idx(~all_nan).max() + 1
            new_tens.append(g_tens[:end_idx].unsqueeze(0))

        # groups now have differing lengths along dim 1, so use the ragged concatenation
        new_tens = ragged_cat(new_tens, ragged_dim=1, cat_dim=0)
        new_tensors.append(new_tens)

    return type(self)(
        *new_tensors,
        group_names=self.group_names,
        start_times=start_times,
        measures=self.measures,
        dt_unit=self.dt_unit,
    )