Exemplo n.º 1
0
    def collate(cls, batch: Sequence['TimeSeriesDataset']) -> 'TimeSeriesDataset':
        """
        Combine a batch of TimeSeriesDatasets into a single TimeSeriesDataset.

        Tensors at the same position in each element are combined with ``ragged_cat`` (ragged along dim 1);
        ``group_names`` and ``start_times`` are concatenated in batch order. ``dt_unit`` and ``measures`` are
        taken from the first element and must be identical for every other element.

        :param batch: A non-empty sequence of TimeSeriesDatasets.
        :return: A new TimeSeriesDataset built from all elements of the batch.
        :raises ValueError: If any element's ``dt_unit`` or ``measures`` differs from element 0's.
        """
        first = batch[0]
        # attributes that are gathered from every element and concatenated at the end:
        to_concat = {
            'tensors': [first.tensors],
            'group_names': [first.group_names],
            'start_times': [first.start_times]
        }
        # attributes that must agree across the whole batch:
        fixed = {'dt_unit': first.dt_unit, 'measures': first.measures}
        for i, ts_dataset in enumerate(batch[1:], 1):
            for attr, appendlist in to_concat.items():
                appendlist.append(getattr(ts_dataset, attr))
            for attr, required_val in fixed.items():
                new_val = getattr(ts_dataset, attr)
                if new_val != required_val:
                    raise ValueError(f"Element {i} has `{attr}` = {new_val}, but for element 0 it's {required_val}.")

        # zip(*...) groups the k-th tensor of every element together, so each output tensor is the
        # ragged concatenation of the corresponding tensors across the batch.
        tensors = tuple(ragged_cat(t, ragged_dim=1) for t in zip(*to_concat['tensors']))

        return cls(
            *tensors,
            group_names=np.concatenate(to_concat['group_names']),
            start_times=np.concatenate(to_concat['start_times']),
            measures=fixed['measures'],
            dt_unit=fixed['dt_unit']
        )
Exemplo n.º 2
0
    def with_new_start_times(self, start_times: Union[np.ndarray, Sequence]) -> 'TimeSeriesDataset':
        """
        Subset a TimeSeriesDataset so that some/all of the groups have later start times.

        For every tensor and every group, the time-steps before that group's new start time are dropped;
        the per-group results are re-combined with ``ragged_cat`` (ragged along the time dimension).

        :param start_times: An array/list of new datetimes, one per group, in the same order as
            ``group_names``.
        :return: A new TimeSeriesDataset.
        :raises ValueError: If the number of start times doesn't match the number of groups, or if a new
            start time is at/after all of a group's times or before all of them.
        """
        n_groups = len(self.group_names)
        if len(start_times) != n_groups:
            # `zip` below would otherwise silently truncate to the shorter input, leaving the tensors
            # out of sync with `group_names` / `start_times`.
            raise ValueError(
                f"Expected {n_groups} start times (one per group), but got {len(start_times)}."
            )
        new_tensors = []
        for i, tens in enumerate(self.tensors):
            times = self.times(i)
            new_tens = []
            for g, (new_time, old_times) in enumerate(zip(start_times, times)):
                if (old_times <= new_time).all():
                    raise ValueError(f"{new_time} is later than all the times for group {self.group_names[g]}")
                elif (old_times > new_time).all():
                    raise ValueError(f"{new_time} is earlier than all the times for group {self.group_names[g]}")
                # keep only time-steps at/after the new start; unsqueeze restores the group dim for ragged_cat
                new_tens.append(tens[g, true1d_idx(old_times >= new_time), :].unsqueeze(0))
            new_tens = ragged_cat(new_tens, ragged_dim=1, cat_dim=0)
            new_tensors.append(new_tens)
        return type(self)(
            *new_tensors,
            group_names=self.group_names,
            start_times=start_times,
            measures=self.measures,
            dt_unit=self.dt_unit
        )
Exemplo n.º 3
0
    def to_dataframe(self,
                     group_colname: str = 'group',
                     time_colname: str = 'time') -> 'DataFrame':
        """
        Convert this dataset into a DataFrame via ``tensor_to_dataframe``.

        All of the dataset's tensors are first merged into one with ``ragged_cat`` (concatenated along
        dim 2, ragged along dim 1); that merged tensor is passed on together with this dataset's times,
        group names and ``all_measures``.

        :param group_colname: Passed through to ``tensor_to_dataframe`` as the group column name.
        :param time_colname: Passed through to ``tensor_to_dataframe`` as the time column name.
        :return: Whatever ``tensor_to_dataframe`` returns for the merged tensor.
        """
        to_frame = self.tensor_to_dataframe
        merged = ragged_cat(self.tensors, ragged_dim=1, cat_dim=2)
        frame_kwargs = dict(
            times=self.times(),
            group_names=self.group_names,
            group_colname=group_colname,
            time_colname=time_colname,
            measures=self.all_measures,
        )
        return to_frame(tensor=merged, **frame_kwargs)
Exemplo n.º 4
0
    def with_new_start_times(
            self, start_times: Union[np.ndarray,
                                     Sequence]) -> 'TimeSeriesDataset':
        """
        Subset a TimeSeriesDataset so that some/all of the groups have later start times.

        For every tensor and every group, time-steps before that group's new start time are dropped, and
        then trailing time-steps in which every measure is ``nan`` are dropped too. Instead of raising,
        a warning is emitted and the group gets a zero-length series when:

        - the new start time is at/after all of the group's times,
        - it is before all of them, or
        - only ``nan`` values remain from the new start time on.

        :param start_times: An array/list of new datetimes, one per group, in the same order as
            ``group_names``.
        :return: A new TimeSeriesDataset.
        :raises ValueError: If the number of start times doesn't match the number of groups.
        """
        n_groups = len(self.group_names)
        if len(start_times) != n_groups:
            # `zip` below would otherwise silently truncate to the shorter input, leaving the tensors
            # out of sync with `group_names` / `start_times`.
            raise ValueError(
                f"Expected {n_groups} start times (one per group), but got {len(start_times)}."
            )
        new_tensors = []
        for i, tens in enumerate(self.tensors):
            times = self.times(i)
            new_tens = []
            for g, (new_time, old_times) in enumerate(zip(start_times, times)):
                if (old_times <= new_time).all():
                    warn(
                        f"{new_time} is later than all the times for group {self.group_names[g]}"
                    )
                    # `[g]` (a list) keeps the group dim, so this is a zero-length series for ragged_cat
                    new_tens.append(tens[[g], 0:0])
                    continue
                elif (old_times > new_time).all():
                    warn(
                        f"{new_time} is earlier than all the times for group {self.group_names[g]}"
                    )
                    new_tens.append(tens[[g], 0:0])
                    continue
                # drop if before new_time:
                g_tens = tens[g, true1d_idx(old_times >= new_time)]
                # drop if after last nan: a time-step counts as missing only when *every* measure is nan
                all_nan = torch.isnan(g_tens).all(dim=1)
                if all_nan.all():
                    warn(
                        f"Group '{self.group_names[g]}' (tensor {i}) has only `nans` after {new_time}"
                    )
                    end_idx = 0
                else:
                    end_idx = true1d_idx(~all_nan).max() + 1
                new_tens.append(g_tens[:end_idx].unsqueeze(0))
            new_tens = ragged_cat(new_tens, ragged_dim=1, cat_dim=0)
            new_tensors.append(new_tens)
        return type(self)(*new_tensors,
                          group_names=self.group_names,
                          start_times=start_times,
                          measures=self.measures,
                          dt_unit=self.dt_unit)