def basic_collate_fn(batch, time_steps, args=args, data_type='train'):
    """Stack equally-shaped samples and hand them to the split/subsample helper.

    NOTE(review): a later definition with the same name further down this file
    rebinds ``basic_collate_fn`` at import time, so this version is shadowed.

    Args:
        batch: sequence of tensors accepted by ``torch.stack``.
        time_steps: time grid shared by every sample; passed through untouched.
        args: run configuration forwarded to the helper. The default is the
            module-level ``args`` captured when this function is defined.
        data_type: split name forwarded to the helper.

    Returns:
        Whatever ``utils.split_and_subsample_batch`` returns for the dict
        ``{'data': <stacked batch>, 'time_steps': time_steps}``.
    """
    stacked = torch.stack(batch)
    return utils.split_and_subsample_batch(
        {'data': stacked, 'time_steps': time_steps},
        args,
        data_type=data_type,
    )
def variable_time_collate_fn_activity(batch, args, device=torch.device("cpu"), data_type="train"):
    """Merge irregularly-sampled series onto one shared, sorted time grid.

    Each element of ``batch`` is unpacked as
    ``(record_id, tt, vals, mask, labels)``:
        - record_id: identifier of the series (not used here).
        - tt: 1-D tensor with the T observation times of this series.
        - vals: (T, D) tensor of observed values.
        - mask: (T, D) tensor, 1 where a value was observed and 0 otherwise.
        - labels: (T, N) tensor of per-time-step labels.

    D and N are read from the first sample. The sorted union of all ``tt``
    values forms the common grid; grid positions a series never observed stay
    zero in values, mask and labels.

    Args:
        batch: list of samples as described above.
        args: run configuration forwarded to ``utils.split_and_subsample_batch``.
        device: device on which the combined tensors are allocated.
        data_type: split name forwarded to ``utils.split_and_subsample_batch``.

    Returns:
        The result of ``utils.split_and_subsample_batch`` applied to a dict with
        "data" (M, T_union, D), "time_steps" (T_union, divided by its maximum
        unless that maximum is 0), "mask" (M, T_union, D) and
        "labels" (M, T_union, N).
    """
    n_series = len(batch)
    n_features = batch[0][2].shape[1]
    n_labels = batch[0][-1].shape[1]

    # torch.unique returns the sorted union of times plus, for every original
    # time stamp, its position inside that union.
    all_times = torch.cat([sample[1] for sample in batch])
    combined_tt, inverse_indices = torch.unique(
        all_times, sorted=True, return_inverse=True)
    combined_tt = combined_tt.to(device)
    n_times = len(combined_tt)

    combined_vals = torch.zeros([n_series, n_times, n_features]).to(device)
    combined_mask = torch.zeros([n_series, n_times, n_features]).to(device)
    combined_labels = torch.zeros([n_series, n_times, n_labels]).to(device)

    # `cursor` walks through inverse_indices in the same order the time
    # tensors were concatenated above.
    cursor = 0
    for row, (_record_id, tt, vals, mask, labels) in enumerate(batch):
        span = len(tt)
        slots = inverse_indices[cursor:cursor + span]
        cursor += span

        combined_vals[row, slots] = vals.to(device)
        combined_mask[row, slots] = mask.to(device)
        combined_labels[row, slots] = labels.to(device)

    # Rescale the grid by its largest value; an all-zero grid is left alone.
    combined_tt = combined_tt.float()
    latest = torch.max(combined_tt)
    if latest != 0.0:
        combined_tt = combined_tt / latest

    collated = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels,
    }
    return utils.split_and_subsample_batch(collated, args, data_type=data_type)
def basic_collate_fn(batch, time_steps, args=args, device=device, data_type="train"):
    """Stack equally-shaped samples and hand them to the split/subsample helper.

    Args:
        batch: sequence of tensors accepted by ``torch.stack``.
        time_steps: time grid shared by every sample; passed through untouched.
        args: run configuration forwarded to the helper. The default is the
            module-level ``args`` captured when this function is defined.
        device: accepted for signature compatibility; the body never reads it.
            The default is the module-level ``device`` captured at definition.
        data_type: split name forwarded to the helper.

    Returns:
        Whatever ``utils.split_and_subsample_batch`` returns for the dict
        ``{"data": <stacked batch>, "time_steps": time_steps}``.
    """
    stacked = torch.stack(batch)
    return utils.split_and_subsample_batch(
        {"data": stacked, "time_steps": time_steps},
        args,
        data_type=data_type,
    )
def meld_collate_fn(batch, args, device=torch.device("cpu"), data_type="train"):
    """Collate variable-length samples onto the union of their end times.

    Each element of ``batch`` is unpacked as
    ``(feature, label, s_time, e_time)``:
        - feature: (T, D) tensor of observed values; D is read from the first
          sample.
        - label: tensor holding one label per observation (reshaped to (T, 1)).
        - s_time: T start times. They are divided by the largest end time in
          the batch and appended to the features as one extra column.
        - e_time: T end times. Their sorted union is the shared time grid.

    Args:
        batch: list of samples as described above.
        args: run configuration forwarded to ``utils.split_and_subsample_batch``.
        device: device on which the combined tensors are allocated.
        data_type: split name forwarded to ``utils.split_and_subsample_batch``.

    Returns:
        The result of ``utils.split_and_subsample_batch`` applied to a dict with
        "data" (M, T_union, D + 1), "time_steps" (T_union, divided by its
        maximum unless that maximum is 0), "mask" (M, T_union, 1; 1 where the
        sample has an observation) and "labels" (M, T_union, 1; 7 wherever the
        sample has no observation).
    """
    D = batch[0][0].shape[1]
    N = 1  # number of labels

    combined_tt, inverse_indices = torch.unique(
        torch.cat([ex[3] for ex in batch]), sorted=True, return_inverse=True)
    combined_tt = combined_tt.to(device)

    n_series = len(batch)
    n_times = len(combined_tt)
    combined_vals = torch.zeros([n_series, n_times, D + 1]).to(device)
    combined_mask = torch.zeros([n_series, n_times, 1]).to(device)
    # Unobserved label slots are filled with 7.
    # NOTE(review): 7 looks like a padding / "no label" class id - confirm.
    combined_labels = torch.zeros([n_series, n_times, N]).to(device) + 7

    # Start times are expressed relative to the largest end time. When that
    # maximum is 0 the division would fill the feature tensor with NaN/inf, so
    # fall back to a divisor of 1 (mirrors the guard on combined_tt below).
    max_val = torch.max(combined_tt)
    start_scale = max_val if max_val != 0 else torch.ones_like(max_val)

    # `offset` walks through inverse_indices in the same order the end-time
    # tensors were concatenated above.
    offset = 0
    for b, (feature, label, s_time, e_time) in enumerate(batch):
        s_time = s_time.to(device).reshape(-1, 1) / start_scale
        e_time = e_time.to(device)
        feature = feature.to(device)
        label = label.to(device)

        indices = inverse_indices[offset:offset + len(e_time)]
        offset += len(e_time)

        combined_vals[b, indices.squeeze()] = torch.cat((feature, s_time), 1)
        combined_mask[b, indices.squeeze()] = 1
        combined_labels[b, indices] = label.reshape(-1, 1).float()

    # Rescale the grid by its largest value; an all-zero grid is left alone.
    combined_tt = combined_tt.float()
    tt_max = torch.max(combined_tt)
    if tt_max != 0.:
        combined_tt = combined_tt / tt_max

    data_dict = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels,
    }
    data_dict = utils.split_and_subsample_batch(data_dict, args, data_type=data_type)
    return data_dict
def __getitem__(self, index):
    """Read one sample, or a slice of samples, from the HDF5-backed store.

    Args:
        index: an integer-like index or a ``slice``. Omitted slice fields
            default to ``start=0``, ``step=1`` and ``stop=<number of rows in
            the "data" dataset>``.

    Returns:
        - ``self.list_form`` true, slice: a list of
          ``(data, time_stamps, mask, labels)`` tuples, one per selected row.
        - ``self.list_form`` true, plain index: one such tuple.
        - ``self.list_form`` false: the result of
          ``utils.split_and_subsample_batch`` on a dict with keys "data",
          "time_steps", "mask" and "labels". In the slice case data and mask
          are additionally subsampled along time with ``self.step`` and
          truncated to ``self.feature_trunc`` features.

    Raises:
        Exception: on a plain-index access while ``self.second`` is already
            true.
    """
    if isinstance(index, slice):
        # Each slice field is resolved from its own value. (The step used to be
        # derived from `index.start`, which broke `ds[a:b]` and ignored the
        # step of `ds[:b:k]`.)
        start = 0 if index.start is None else index.start
        step = 1 if index.step is None else index.step
        if index.stop is None:
            stop = len(self.hdf5dataloader["data"])
        else:
            stop = index.stop

        if self.list_form:
            # List format, as used by the other datasets: one tuple per row.
            return [
                (
                    torch.from_numpy(self.hdf5dataloader["data"][i]),
                    torch.from_numpy(self.timestamps),
                    torch.from_numpy(self.hdf5dataloader["mask"][i]),
                    torch.from_numpy(self.hdf5dataloader["labels"][i]),
                )
                for i in range(start, stop, step)
            ]

        # Tensor format (more efficient): read the whole slice at once.
        data = torch.from_numpy(
            self.hdf5dataloader["data"][start:stop:step]
        ).float().to(self.device)
        time_stamps = torch.from_numpy(self.timestamps).to(self.device)
        mask = torch.from_numpy(
            self.hdf5dataloader["mask"][start:stop:step]
        ).float().to(self.device)
        labels = torch.from_numpy(
            self.hdf5dataloader["labels"][start:stop:step]
        ).float().to(self.device)

        # Build the dictionary here so no separate collate function is needed.
        data_dict = {
            "data": data[:, ::self.step, :self.feature_trunc],
            "time_steps": time_stamps[::self.step],
            "mask": mask[:, ::self.step, :self.feature_trunc],
            "labels": labels,
        }
        return utils.split_and_subsample_batch(
            data_dict, self.args, data_type=self.mode)

    # Plain (non-slice) index.
    # NOTE(review): as written, the first plain-index call sets `self.second`
    # and every later one raises - confirm this guard is intentional.
    if self.second:
        raise Exception('Tensorformat not implemented yet!')
    self.second = True

    if self.list_form:
        data = torch.from_numpy(self.hdf5dataloader["data"][index])
        time_stamps = torch.from_numpy(self.timestamps)
        mask = torch.from_numpy(self.hdf5dataloader["mask"][index])
        labels = torch.from_numpy(self.hdf5dataloader["labels"][index])
        return (data, time_stamps, mask, labels)

    data = torch.from_numpy(
        self.hdf5dataloader["data"][index]).float().to(self.device)
    time_stamps = torch.from_numpy(self.timestamps).to(self.device)
    mask = torch.from_numpy(
        self.hdf5dataloader["mask"][index]).float().to(self.device)
    labels = torch.from_numpy(
        self.hdf5dataloader["labels"][index]).float().to(self.device)

    data_dict = {
        "data": data,
        "time_steps": time_stamps,
        "mask": mask,
        "labels": labels,
    }
    return utils.split_and_subsample_batch(
        data_dict, self.args, data_type=self.mode)
def variable_time_collate_fn_crop(batch, args, device=torch.device("cpu"), data_type="train", data_min=None, data_max=None, list_form=True):
    """Collate samples that share one time grid into batch tensors.

    Args:
        batch: with ``list_form`` true, a list of
            ``(data, tt, mask, labels)`` tuples where ``data`` and ``mask`` are
            (T, D) tensors, ``tt`` has length T and ``labels`` has length
            N_labels (sizes are read from the first sample). With ``list_form``
            false, a 4-element sequence ``(data, tt, mask, labels)`` that is
            already batched.
        args: run configuration forwarded to ``utils.split_and_subsample_batch``.
        device: device for the combined tensors (list form only).
        data_type: split name forwarded to ``utils.split_and_subsample_batch``.
        data_min: unused by the current implementation.
        data_max: unused by the current implementation.
        list_form: selects which of the two input layouts is expected.

    Returns:
        The result of ``utils.split_and_subsample_batch`` applied to a dict with
        "data" (M, T, D), "time_steps", "mask" (M, T, D) and
        "labels" (M, N_labels).
    """
    if not list_form:
        # Tensor format (more efficient): must agree with what __getitem__
        # produces, so the four tensors are used exactly as given.
        combined_vals, combined_tt, combined_mask, combined_labels = batch
    else:
        # List format, as used by the other datasets.
        first_data, first_tt, _first_mask, first_labels = batch[0]
        n_samples = len(batch)
        n_times = len(first_tt)
        n_features = first_data.shape[1]
        n_labels = first_labels.shape[0]

        combined_vals = torch.zeros([n_samples, n_times, n_features]).to(device)
        combined_mask = torch.zeros([n_samples, n_times, n_features]).to(device)
        # Labels start as NaN; every row is overwritten in the loop below.
        combined_labels = torch.full(
            [n_samples, n_labels], float('nan')).to(device)

        for row, (data, tt, mask, labels) in enumerate(batch):
            combined_vals[row] = data.to(device)
            combined_mask[row] = mask.to(device)
            combined_labels[row] = labels.to(device)
            # The last sample's grid is the one that ends up in the output.
            combined_tt = tt.to(device)

    collated = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "labels": combined_labels,
    }
    return utils.split_and_subsample_batch(collated, args, data_type=data_type)