Esempio n. 1
0
 def basic_collate_fn(batch, time_steps, args=args, data_type='train'):
     """Collate equally-shaped samples that share a single time grid.

     The samples in ``batch`` are stacked along a new leading dimension and,
     together with ``time_steps``, handed to ``utils.split_and_subsample_batch``,
     whose result is returned. Note that the ``args`` default is captured from
     whatever ``args`` is bound to when this function is defined.
     """
     stacked_batch = torch.stack(batch)
     return utils.split_and_subsample_batch(
         {'data': stacked_batch, 'time_steps': time_steps},
         args,
         data_type=data_type,
     )
Esempio n. 2
0
def variable_time_collate_fn_activity(batch,
                                      args,
                                      device=torch.device("cpu"),
                                      data_type="train"):
    """Collate variable-length, per-time-point-labelled records onto one time grid.

    Expects a batch of records of the form (record_id, tt, vals, mask, labels), where
        - record_id identifies the record (it is not used here).
        - tt is a 1-dimensional tensor containing T time values of observations.
        - vals is a (T, D) tensor containing observed values for D variables.
        - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.
        - labels is a (T, N) tensor of per-time-point labels. Labels are required:
          None is not supported, because N is read from the first record's labels.

    Every record is scattered onto the sorted union of all observation times in the
    batch. Grid positions a record did not observe stay zero in values, mask and
    labels. The union time grid is divided by its maximum when that maximum is
    non-zero.

    Returns:
        The result of ``utils.split_and_subsample_batch`` applied to a dict with keys
            "data": (M, T', D) float tensor of observed values,
            "time_steps": (T',) float tensor of (rescaled) union time points,
            "mask": (M, T', D) float tensor, 1 where a value was observed,
            "labels": (M, T', N) float tensor of labels,
        where M = len(batch) and T' is the number of distinct time values.
    """
    D = batch[0][2].shape[1]
    N = batch[0][-1].shape[1]  # number of label columns per time point

    # Sorted union of all observation times. inverse_indices maps every entry of the
    # concatenated per-record time vectors to its position in that union.
    combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]),
                                                sorted=True,
                                                return_inverse=True)
    combined_tt = combined_tt.to(device)

    n_records, n_times = len(batch), len(combined_tt)
    # Allocate directly on the target device instead of allocating on CPU and copying.
    combined_vals = torch.zeros([n_records, n_times, D], device=device)
    combined_mask = torch.zeros([n_records, n_times, D], device=device)
    combined_labels = torch.zeros([n_records, n_times, N], device=device)

    offset = 0
    for b, (_record_id, tt, vals, mask, labels) in enumerate(batch):
        # Record b's times occupy the next len(tt) entries of the concatenation.
        indices = inverse_indices[offset:offset + len(tt)]
        offset += len(tt)

        combined_vals[b, indices] = vals.to(device)
        combined_mask[b, indices] = mask.to(device)
        combined_labels[b, indices] = labels.to(device)

    combined_tt = combined_tt.float()

    # torch.max raises on an empty tensor, so only rescale when there are time points.
    if combined_tt.numel() > 0:
        max_tt = torch.max(combined_tt)
        if max_tt != 0.0:
            combined_tt = combined_tt / max_tt

    data_dict = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels,
    }

    data_dict = utils.split_and_subsample_batch(data_dict,
                                                args,
                                                data_type=data_type)
    return data_dict
Esempio n. 4
0
def meld_collate_fn(batch,
                    args,
                    device=torch.device("cpu"),
                    data_type="train"):
    """Collate records onto the sorted union of their end times.

    Expects a batch of records of the form (feature, label, s_time, e_time), where
        - feature is a (T, D) tensor containing observed values for D variables.
        - label holds the T labels of the record (reshaped to (T, 1) below).
        - s_time holds the T start times of the observations.
        - e_time holds the T end times of the observations. The union of all end
          times in the batch forms the shared time grid.

    Each record's features are concatenated with its start times, giving D + 1
    channels. The start times are divided by the batch's largest end time when that
    maximum is non-zero. Grid positions a record did not observe keep zero
    features, a zero mask and the padding label value 7.

    Returns:
        The result of ``utils.split_and_subsample_batch`` applied to a dict with keys
            "data": (M, T', D + 1) float tensor (features followed by start time),
            "time_steps": (T',) float tensor of end times divided by their maximum
                (left unscaled when that maximum is zero),
            "mask": (M, T', 1) float tensor, 1 where the record has an observation,
            "labels": (M, T', 1) float tensor of labels (7 where absent),
        where M = len(batch) and T' is the number of distinct end times.
    """
    D = batch[0][0].shape[1]
    N = 1  # one label per time point

    # Sorted union of all end times. inverse_indices maps every entry of the
    # concatenated per-record end-time vectors to its position in that union.
    combined_tt, inverse_indices = torch.unique(torch.cat([ex[3] for ex in batch]),
                                                sorted=True,
                                                return_inverse=True)
    combined_tt = combined_tt.to(device)

    n_records, n_times = len(batch), len(combined_tt)
    combined_vals = torch.zeros([n_records, n_times, D + 1], device=device)
    combined_mask = torch.zeros([n_records, n_times, 1], device=device)
    # NOTE(review): 7 looks like a "no label" padding class - confirm against the
    # label vocabulary.
    combined_labels = torch.full([n_records, n_times, N], 7.0, device=device)

    max_val = torch.max(combined_tt)
    # Dividing by a zero maximum would fill the start-time channel with NaN/inf.
    # Only scale when it is non-zero, matching the time_steps normalization below.
    scale_times = bool(max_val != 0)

    offset = 0
    for b, (feature, label, s_time, e_time) in enumerate(batch):
        s_time = s_time.to(device).reshape(-1, 1)
        if scale_times:
            s_time = s_time / max_val
        s_time = s_time.float()
        feature = feature.to(device)
        label = label.to(device)

        # Record b's end times occupy the next len(e_time) entries of the
        # concatenation. reshape(-1) (rather than squeeze()) keeps a 1-D index even
        # for a single observation, and is used for all three targets consistently.
        indices = inverse_indices[offset:offset + len(e_time)].reshape(-1)
        offset += len(e_time)

        combined_vals[b, indices] = torch.cat((feature, s_time), 1)
        combined_mask[b, indices] = 1
        combined_labels[b, indices] = label.reshape(-1, 1).float()

    combined_tt = combined_tt.float()
    if scale_times:
        combined_tt = combined_tt / torch.max(combined_tt)

    data_dict = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels
    }

    data_dict = utils.split_and_subsample_batch(data_dict,
                                                args,
                                                data_type=data_type)
    return data_dict
Esempio n. 5
0
    def __getitem__(self, index):
        """Read one record, or a slice of records, from the HDF5-backed storage.

        Args:
            index: an integer record index or a ``slice`` of record indices.

        Returns:
            List format (``self.list_form`` truthy):
                - slice: a list of ``(data, time_stamps, mask, labels)`` tuples, one
                  per selected record, built with ``torch.from_numpy``;
                - int: a single such tuple.
            Tensor format: a dict with keys "data", "time_steps", "mask" and
            "labels". Data, mask and labels are converted to float, and all four
            are moved to ``self.device``. The dict is passed through
            ``utils.split_and_subsample_batch(..., data_type=self.mode)``. For
            slices, "data"/"mask" are subsampled along axis 1 by ``self.step`` and
            cut to the first ``self.feature_trunc`` entries of axis 2, and
            "time_steps" is subsampled by ``self.step``.

        Raises:
            Exception: for an integer index when ``self.second`` is truthy.
        """
        if isinstance(index, slice):
            # Default each slice field on its own. The step default must test
            # ``index.step``; testing ``index.start`` passed step=None to range()
            # for ``ds[a:b]`` and ignored the step for ``ds[::k]``.
            start = 0 if index.start is None else index.start
            step = 1 if index.step is None else index.step
            store = self.hdf5dataloader

            if self.list_form:  # list format, as used by the other datasets
                # range() needs a concrete stop; an open-ended slice runs to the end.
                stop = len(store["data"]) if index.stop is None else index.stop
                time_stamps = torch.from_numpy(self.timestamps)
                return [(torch.from_numpy(store["data"][i]),
                         time_stamps,
                         torch.from_numpy(store["mask"][i]),
                         torch.from_numpy(store["labels"][i]))
                        for i in range(start, stop, step)]

            # Tensor format (more efficient): read the whole slice at once and build
            # the batch dict here instead of going through a collate function.
            rows = slice(start, index.stop, step)
            data = torch.from_numpy(store["data"][rows]).float().to(self.device)
            time_stamps = torch.from_numpy(self.timestamps).to(self.device)
            mask = torch.from_numpy(store["mask"][rows]).float().to(self.device)
            labels = torch.from_numpy(store["labels"][rows]).float().to(
                self.device)

            data_dict = {
                "data": data[:, ::self.step, :self.feature_trunc],
                "time_steps": time_stamps[::self.step],
                "mask": mask[:, ::self.step, :self.feature_trunc],
                "labels": labels
            }
            return utils.split_and_subsample_batch(data_dict,
                                                   self.args,
                                                   data_type=self.mode)

        # Plain integer index.
        if self.second:
            # NOTE(review): looks like a leftover debugging guard. The
            # ``self.second = True`` that followed this raise could never run and
            # was removed, so nothing in this method sets the flag.
            raise Exception('Tensorformat not implemented yet!')

        store = self.hdf5dataloader
        if self.list_form:
            return (torch.from_numpy(store["data"][index]),
                    torch.from_numpy(self.timestamps),
                    torch.from_numpy(store["mask"][index]),
                    torch.from_numpy(store["labels"][index]))

        # NOTE(review): unlike the slice path, no ``self.step`` subsampling or
        # ``self.feature_trunc`` truncation is applied here - confirm intended.
        data = torch.from_numpy(store["data"][index]).float().to(self.device)
        time_stamps = torch.from_numpy(self.timestamps).to(self.device)
        mask = torch.from_numpy(store["mask"][index]).float().to(self.device)
        labels = torch.from_numpy(store["labels"][index]).float().to(self.device)

        data_dict = {
            "data": data,
            "time_steps": time_stamps,
            "mask": mask,
            "labels": labels
        }
        return utils.split_and_subsample_batch(data_dict,
                                               self.args,
                                               data_type=self.mode)
Esempio n. 6
0
def variable_time_collate_fn_crop(batch,
                                  args,
                                  device=torch.device("cpu"),
                                  data_type="train",
                                  data_min=None,
                                  data_max=None,
                                  list_form=True):
    """Collate crop records that all share one time grid.

    Args:
        batch: with ``list_form`` truthy, a sequence of ``(data, tt, mask, labels)``
            records. The first record fixes the number of time steps (``len(tt)``),
            the feature count (``data.shape[1]``) and the label count
            (``labels.shape[0]``). Otherwise, a single ``(data, tt, mask, labels)``
            tuple of already-batched tensors that is used unchanged.
        args: forwarded to ``utils.split_and_subsample_batch``.
        device: device that list-format tensors are collected on.
        data_type: forwarded to ``utils.split_and_subsample_batch``.
        data_min, data_max: currently unused; the masked normalization that would
            consume them is disabled.
        list_form: selects between the two ``batch`` layouts described above.

    Returns:
        The result of ``utils.split_and_subsample_batch`` on a dict with keys
        "data", "time_steps", "mask" and "labels". In list format these are the
        (M, T, D) float values, the last record's ``tt`` (moved to ``device``), the
        (M, T, D) float mask and the (M, N) float labels, with M = len(batch).
    """
    if not list_form:
        # Tensor format (more efficient): the batch is already one stacked tuple.
        combined_vals, combined_tt, combined_mask, combined_labels = batch
    else:
        first_data, first_tt, _, first_labels = batch[0]
        grid_shape = [len(batch), len(first_tt), first_data.shape[1]]

        combined_vals = torch.zeros(grid_shape, device=device)
        combined_mask = torch.zeros(grid_shape, device=device)
        combined_labels = torch.full([len(batch), first_labels.shape[0]],
                                     float('nan'),
                                     device=device)

        for row, (record_data, _record_tt, record_mask,
                  record_labels) in enumerate(batch):
            combined_vals[row] = record_data.to(device)
            combined_mask[row] = record_mask.to(device)
            combined_labels[row] = record_labels.to(device)

        # All records are assumed to share one time grid; the last record's is used.
        combined_tt = batch[-1][1].to(device)

    # data_min / data_max would feed a masked normalization step that is disabled.
    batch_dict = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels
    }
    return utils.split_and_subsample_batch(batch_dict,
                                           args,
                                           data_type=data_type)