Example #1
    def load_data(self,sample_percent,batch_size,data_type="train"):
        self.batch_size = batch_size
        self.sample_percent = sample_percent
        if data_type == "train":
            cut_num = 20000
        else:
            cut_num = 5000


        # Loading Data
        loc = np.load(self.dataset_path + '/loc_' + data_type + self.suffix + '.npy', allow_pickle=True)[:cut_num]
        vel = np.load(self.dataset_path + '/vel_' + data_type + self.suffix + '.npy', allow_pickle=True)[:cut_num]
        edges = np.load(self.dataset_path + '/edges_' + data_type + self.suffix + '.npy', allow_pickle=True)[:cut_num]  # [500,5,5]
        times = np.load(self.dataset_path + '/times_' + data_type + self.suffix + '.npy', allow_pickle=True)[:cut_num]  # [500,5]

        self.num_graph = loc.shape[0]
        self.num_atoms = loc.shape[1]
        self.feature = loc[0][0][0].shape[0] + vel[0][0][0].shape[0]
        print("number graph in   "+data_type+"   is %d" % self.num_graph)
        print("number atoms in   " + data_type + "   is %d" % self.num_atoms)

        if self.suffix == "_springs5" or self.suffix == "_charged5":
            # Normalize features to [-1, 1], across test and train dataset

            if self.max_loc is None:
                loc, max_loc, min_loc = self.normalize_features(loc,
                                                                self.num_atoms)  # [num_sims,num_atoms, (timestamps,2)]
                vel, max_vel, min_vel = self.normalize_features(vel, self.num_atoms)
                self.max_loc = max_loc
                self.min_loc = min_loc
                self.max_vel = max_vel
                self.min_vel = min_vel
            else:
                loc = (loc - self.min_loc) * 2 / (self.max_loc - self.min_loc) - 1
                vel = (vel - self.min_vel) * 2 / (self.max_vel - self.min_vel) - 1
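                # Reuse the min/max computed on the training split, so train and test data share the same [-1, 1] scaling.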

        else:
            self.timelength = 49



        # split data w.r.t interp and extrap, also normalize times
        if self.mode=="interp":
            loc_en,vel_en,times_en = self.interp_extrap(loc,vel,times,self.mode,data_type)
            loc_de = loc_en
            vel_de = vel_en
            times_de = times_en
        elif self.mode == "extrap":
            loc_en,vel_en,times_en,loc_de,vel_de,times_de = self.interp_extrap(loc,vel,times,self.mode,data_type)

        #Encoder dataloader
        series_list_observed, loc_observed, vel_observed, times_observed = self.split_data(loc_en, vel_en, times_en)
        if self.mode == "interp":
            time_begin = 0
        else:
            time_begin = 1
        encoder_data_loader, graph_data_loader = self.transfer_data(loc_observed, vel_observed, edges,
                                                                    times_observed, time_begin=time_begin)


        # Graph Dataloader --USING NRI
        edges = np.reshape(edges, [-1, self.num_atoms ** 2])
        edges = np.array((edges + 1) / 2, dtype=np.int64)
        edges = torch.LongTensor(edges)
        # Exclude self edges
        off_diag_idx = np.ravel_multi_index(
            np.where(np.ones((self.num_atoms, self.num_atoms)) - np.eye(self.num_atoms)),
            [self.num_atoms, self.num_atoms])
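        # off_diag_idx holds the flattened indices of all non-diagonal entries of the num_atoms x num_atoms
        # adjacency matrix, i.e. every edge except self-loops; e.g. for num_atoms = 3 it is [1, 2, 3, 5, 6, 7].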

        edges = edges[:, off_diag_idx]
        graph_data_loader = Loader(edges, batch_size=self.batch_size)


        # Decoder Dataloader
        if self.mode=="interp":
            series_list_de = series_list_observed
        elif self.mode == "extrap":
            series_list_de = self.decoder_data(loc_de,vel_de,times_de)
        decoder_data_loader = Loader(series_list_de, batch_size=self.batch_size * self.num_atoms, shuffle=False,
                                     collate_fn=lambda batch: self.variable_time_collate_fn_activity(
                                         batch))  # num_graph*num_ball [tt,vals,masks]


        num_batch = len(decoder_data_loader)
        encoder_data_loader = utils.inf_generator(encoder_data_loader)
        graph_data_loader = utils.inf_generator(graph_data_loader)
        decoder_data_loader = utils.inf_generator(decoder_data_loader)

        return encoder_data_loader, decoder_data_loader, graph_data_loader, num_batch
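
# Usage sketch (hypothetical caller, not part of the original class; assumes utils.inf_generator
# wraps each DataLoader into an endless iterator, so batches are pulled with next()):
#
#     enc_loader, dec_loader, graph_loader, num_batch = data.load_data(sample_percent=0.8, batch_size=128)
#     for _ in range(num_batch):
#         enc_batch = next(enc_loader)
#         dec_batch = next(dec_loader)
#         graph_batch = next(graph_loader)
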
Example #2
def parse_datasets(args, device, test_batch_size = 50):
    

    def basic_collate_fn(batch, time_steps, args = args, device = device, data_type = "train"):
        batch = torch.stack(batch)
        data_dict = {
            "data": batch, 
            "time_steps": time_steps}

        data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type)
        return data_dict


    dataset_name = args.dataset

    n_total_tp = args.timepoints + args.extrap
    max_t_extrap = args.max_t / args.timepoints * n_total_tp

    ##################################################################
    # MuJoCo dataset
    if dataset_name == "hopper":
        dataset_obj = HopperPhysics(root='data', download=True, generate=False, device = device)
        dataset = dataset_obj.get_dataset()[:args.n]
        dataset = dataset.to(device)


        n_tp_data = dataset[:].shape[1]

        # Time steps that are used later on for extrapolation
        time_steps = torch.arange(start=0, end = n_tp_data, step=1).float().to(device)
        time_steps = time_steps / len(time_steps)
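        # time_steps is now an evenly spaced grid in [0, 1) with step 1 / n_tp_data.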

        dataset = dataset.to(device)
        time_steps = time_steps.to(device)

        if not args.extrap:
            # Creating dataset for interpolation
            # sample time points from different parts of the timeline, 
            # so that the model learns from different parts of hopper trajectory
            n_traj = len(dataset)
            n_tp_data = dataset.shape[1]
            n_reduced_tp = args.timepoints

            start_ind = np.random.randint(0, high=n_tp_data - n_reduced_tp +1, size=n_traj)
            end_ind = start_ind + n_reduced_tp
            sliced = []
            for i in range(n_traj):
                sliced.append(dataset[i, start_ind[i] : end_ind[i], :])
            dataset = torch.stack(sliced).to(device)
            time_steps = time_steps[:n_reduced_tp]

        # Split into train and test by the time sequences
        train_y, test_y = utils.split_train_test(dataset, train_fraq = 0.8)

        n_samples = len(dataset)
        input_dim = dataset.size(-1)

        batch_size = min(args.batch_size, args.n)
        train_dataloader = DataLoader(train_y, batch_size = batch_size, shuffle=False,
            collate_fn= lambda batch: basic_collate_fn(batch, time_steps, data_type = "train"))
        test_dataloader = DataLoader(test_y, batch_size = n_samples, shuffle=False,
            collate_fn= lambda batch: basic_collate_fn(batch, time_steps, data_type = "test"))
        
        data_objects = {"dataset_obj": dataset_obj, 
                    "train_dataloader": utils.inf_generator(train_dataloader), 
                    "test_dataloader": utils.inf_generator(test_dataloader),
                    "input_dim": input_dim,
                    "n_train_batches": len(train_dataloader),
                    "n_test_batches": len(test_dataloader)}
        return data_objects

    ##################################################################
    # Physionet dataset

    if dataset_name == "physionet":
        train_dataset_obj = PhysioNet('data/physionet', train=True, 
                                        quantization = args.quantization,
                                        download=True, n_samples = min(10000, args.n), 
                                        device = device)
        # Use custom collate_fn to combine samples with arbitrary time observations.
        # Returns the dataset along with mask and time steps
        test_dataset_obj = PhysioNet('data/physionet', train=False, 
                                        quantization = args.quantization,
                                        download=True, n_samples = min(10000, args.n), 
                                        device = device)

        # Combine and shuffle samples from physionet Train and physionet Test
        total_dataset = train_dataset_obj[:len(train_dataset_obj)]

        if not args.classif:
            # Concatenate samples from the original Train and Test sets.
            # Only 'training' PhysioNet samples have labels, so for the classification task the PhysioNet 'test' samples are not needed.
            total_dataset = total_dataset + test_dataset_obj[:len(test_dataset_obj)]

        # Shuffle and split
        train_data, test_data = model_selection.train_test_split(total_dataset, train_size= 0.8, 
            random_state = 42, shuffle = True)

        record_id, tt, vals, mask, labels = train_data[0]

        n_samples = len(total_dataset)
        input_dim = vals.size(-1)

        batch_size = min(min(len(train_dataset_obj), args.batch_size), args.n)
        data_min, data_max = get_data_min_max(total_dataset)

        train_dataloader = DataLoader(train_data, batch_size= batch_size, shuffle=False, 
            collate_fn= lambda batch: physionet_collate_fn(batch, args, device, data_type = "train",
                data_min = data_min, data_max = data_max))
        test_dataloader = DataLoader(test_data, batch_size = 100, shuffle=False, 
            collate_fn= lambda batch: physionet_collate_fn(batch, args, device, data_type = "test",
                data_min = data_min, data_max = data_max))

        attr_names = train_dataset_obj.params
        data_objects = {"dataset_obj": train_dataset_obj, 
                    "train_dataloader": utils.inf_generator(train_dataloader), 
                    "test_dataloader": utils.inf_generator(test_dataloader),
                    "input_dim": input_dim,
                    "n_train_batches": len(train_dataloader),
                    "n_test_batches": len(test_dataloader),
                    "attr": attr_names, #optional
                    "classif_per_tp": False, #optional
                    "n_labels": 1} #optional
        return data_objects

    ##################################################################
    # Human activity dataset

    if dataset_name == "activity":
        n_samples =  min(10000, args.n)
        dataset_obj = PersonActivity('data/PersonActivity', 
                            download=True, n_samples =  n_samples, device = device)
        print(dataset_obj)
        # Use custom collate_fn to combine samples with arbitrary time observations.
        # Returns the dataset along with mask and time steps

        # Shuffle and split
        train_data, test_data = model_selection.train_test_split(dataset_obj, train_size= 0.8, 
            random_state = 42, shuffle = True)

        train_data = [train_data[i] for i in np.random.choice(len(train_data), len(train_data), replace=False)]
        test_data = [test_data[i] for i in np.random.choice(len(test_data), len(test_data), replace=False)]

        record_id, tt, vals, mask, labels = train_data[0]
        input_dim = vals.size(-1)

        batch_size = min(min(len(dataset_obj), args.batch_size), args.n)
        train_dataloader = DataLoader(train_data, batch_size= batch_size, shuffle=False, 
            collate_fn= lambda batch: activity_collate_fn(batch, args, device, data_type = "train"))
        test_dataloader = DataLoader(test_data, batch_size=n_samples, shuffle=False, 
            collate_fn= lambda batch: activity_collate_fn(batch, args, device, data_type = "test"))

        data_objects = {"dataset_obj": dataset_obj, 
                    "train_dataloader": utils.inf_generator(train_dataloader), 
                    "test_dataloader": utils.inf_generator(test_dataloader),
                    "input_dim": input_dim,
                    "n_train_batches": len(train_dataloader),
                    "n_test_batches": len(test_dataloader),
                    "classif_per_tp": True, #optional
                    "n_labels": labels.size(-1)}

        return data_objects
        
    ##################################################################
    # MELD dataset

    if dataset_name == "meld":
        n_samples =  min(10000, args.n)
        dataset_obj = MELD('data', 
                            download=True, n_samples =  n_samples, device = device)
        print(dataset_obj)
        # Use custom collate_fn to combine samples with arbitrary time observations.
        # Returns the dataset along with mask and time steps

        # Shuffle and split
        #train_data, test_data = model_selection.train_test_split(dataset_obj, train_size= 0.8, 
        #    random_state = 42, shuffle = True)
        train_data = dataset_obj.train_data
        #dev_data = dataset_obj.dev_data
        test_data = dataset_obj.test_data
        
        train_data = [train_data[i] for i in np.random.choice(len(train_data), len(train_data), replace=False)]
        #dev_data = [dev_data[i] for i in np.random.choice(len(dev_data), len(dev_data))]
        test_data = [test_data[i] for i in np.random.choice(len(test_data), len(test_data), replace=False)]

        vals, _, _, _  = train_data[0]
        input_dim = vals.size(-1)

        batch_size = min(min(len(dataset_obj), args.batch_size), args.n)
        train_dataloader = DataLoader(train_data, batch_size= batch_size, shuffle=False, 
            collate_fn= lambda batch: meld_collate_fn(batch, args, device, data_type = "train"))
        #dev_dataloader = DataLoader(dev_data, batch_size= n_samples, shuffle=False, 
        #    collate_fn= lambda batch: meld_collate_fn(batch, args, device, data_type = "test"))
        test_dataloader = DataLoader(test_data, batch_size = n_samples, shuffle=False, 
            collate_fn= lambda batch: meld_collate_fn(batch, args, device, data_type = "test"))

        data_objects = {"dataset_obj": dataset_obj, 
                    "train_dataloader": utils.inf_generator(train_dataloader), 
                    #"dev_dataloader": utils.inf_generator(dev_dataloader), 
                    "test_dataloader": utils.inf_generator(test_dataloader),
                    "input_dim": input_dim+1,
                    "n_train_batches": len(train_dataloader),
                    #"n_dev_batches": len(dev_dataloader),
                    "n_test_batches": len(test_dataloader),
                    "classif_per_tp": True,
                    "n_labels": 7}

        return data_objects

    ########### 1d datasets ###########

    # Sampling args.timepoints time points in the interval [0, args.max_t]
    # Sample points for both the training sequence and extrapolation (test)
    distribution = uniform.Uniform(torch.Tensor([0.0]),torch.Tensor([max_t_extrap]))
    time_steps_extrap =  distribution.sample(torch.Size([n_total_tp-1]))[:,0]
    time_steps_extrap = torch.cat((torch.Tensor([0.0]), time_steps_extrap))
    time_steps_extrap = torch.sort(time_steps_extrap)[0]
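    # time_steps_extrap now holds n_total_tp sorted time stamps in [0, max_t_extrap), always starting at 0.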

    dataset_obj = None
    ##################################################################
    # Sample a periodic function
    if dataset_name == "periodic":
        dataset_obj = Periodic_1d(
            init_freq = None, init_amplitude = 1.,
            final_amplitude = 1., final_freq = None, 
            z0 = 1.)

    ##################################################################

    if dataset_obj is None:
        raise Exception("Unknown dataset: {}".format(dataset_name))

    dataset = dataset_obj.sample_traj(time_steps_extrap, n_samples = args.n, 
        noise_weight = args.noise_weight)

    # Process small datasets
    dataset = dataset.to(device)
    time_steps_extrap = time_steps_extrap.to(device)

    train_y, test_y = utils.split_train_test(dataset, train_fraq = 0.8)

    n_samples = len(dataset)
    input_dim = dataset.size(-1)

    batch_size = min(args.batch_size, args.n)
    train_dataloader = DataLoader(train_y, batch_size = batch_size, shuffle=False,
        collate_fn= lambda batch: basic_collate_fn(batch, time_steps_extrap, data_type = "train"))
    test_dataloader = DataLoader(test_y, batch_size = args.n, shuffle=False,
        collate_fn= lambda batch: basic_collate_fn(batch, time_steps_extrap, data_type = "test"))
    
    data_objects = {#"dataset_obj": dataset_obj, 
                "train_dataloader": utils.inf_generator(train_dataloader), 
                "test_dataloader": utils.inf_generator(test_dataloader),
                "input_dim": input_dim,
                "n_train_batches": len(train_dataloader),
                "n_test_batches": len(test_dataloader)}

    return data_objects
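
# Usage sketch (hypothetical; the concrete args values are illustrative):
#
#     data_obj = parse_datasets(args, device)
#     for _ in range(data_obj["n_train_batches"]):
#         batch_dict = next(data_obj["train_dataloader"])
#         # on the basic_collate_fn path, batch_dict contains "data" and "time_steps"
#         # plus whatever fields utils.split_and_subsample_batch adds
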
Example #3
def parse_datasets(args, device):
    def basic_collate_fn(batch,
                         time_steps,
                         args=args,
                         device=device,
                         data_type="train"):

        batch = torch.stack(batch)
        data_dict = {"data": batch, "time_steps": time_steps}

        data_dict = utils.split_and_subsample_batch(data_dict,
                                                    args,
                                                    data_type=data_type)
        return data_dict

    # def custom_collate_fn(batch, time_steps, args=args, device = device,  data)

    dataset_name = args.dataset

    n_total_tp = args.timepoints + args.extrap
    max_t_extrap = args.max_t / args.timepoints * n_total_tp

    ##################################################################
    # MuJoCo dataset
    if dataset_name == "hopper":
        dataset_obj = HopperPhysics(root='data',
                                    download=True,
                                    generate=False,
                                    device=device)
        dataset = dataset_obj.get_dataset()[:args.n]
        dataset = dataset.to(device)

        n_tp_data = dataset[:].shape[1]

        # Time steps that are used later on for extrapolation
        time_steps = torch.arange(start=0, end=n_tp_data,
                                  step=1).float().to(device)
        time_steps = time_steps / len(time_steps)

        dataset = dataset.to(device)
        time_steps = time_steps.to(device)

        if not args.extrap:
            # Creating dataset for interpolation
            # sample time points from different parts of the timeline,
            # so that the model learns from different parts of hopper trajectory
            n_traj = len(dataset)
            n_tp_data = dataset.shape[1]
            n_reduced_tp = args.timepoints

            start_ind = np.random.randint(0,
                                          high=n_tp_data - n_reduced_tp + 1,
                                          size=n_traj)
            end_ind = start_ind + n_reduced_tp
            sliced = []
            for i in range(n_traj):
                sliced.append(dataset[i, start_ind[i]:end_ind[i], :])
            dataset = torch.stack(sliced).to(device)
            time_steps = time_steps[:n_reduced_tp]

        # Split into train and test by the time sequences
        train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

        n_samples = len(dataset)
        input_dim = dataset.size(-1)
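        # With args.mcar / args.mnar, data_ampute_batch_collate presumably introduces artificial
        # missingness into each batch (MCAR: missing completely at random, MNAR: missing not at random).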

        if args.mcar or args.mnar:
            collate_fn_train = lambda batch: data_ampute_batch_collate(
                batch, time_steps, args, device, data_type="train")
            collate_fn_test = lambda batch: data_ampute_batch_collate(
                batch, time_steps, args, device, data_type="test")
        else:
            collate_fn_train = lambda batch: basic_collate_fn(
                batch, time_steps, data_type="train")
            collate_fn_test = lambda batch: basic_collate_fn(
                batch, time_steps, data_type="test")

        batch_size = min(args.batch_size, args.n)
        train_dataloader = DataLoader(train_y,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      collate_fn=collate_fn_train)
        test_dataloader = DataLoader(test_y,
                                     batch_size=n_samples,
                                     shuffle=False,
                                     collate_fn=collate_fn_test)

        data_objects = {
            "dataset_obj": dataset_obj,
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader)
        }
        return data_objects

    ##################################################################
    # Physionet dataset

    if dataset_name == "physionet":
        train_dataset_obj = PhysioNet('data/physionet',
                                      train=True,
                                      quantization=args.quantization,
                                      download=True,
                                      n_samples=min(10000, args.n),
                                      device=device)
        # Use custom collate_fn to combine samples with arbitrary time observations.
        # Returns the dataset along with mask and time steps
        test_dataset_obj = PhysioNet('data/physionet',
                                     train=False,
                                     quantization=args.quantization,
                                     download=True,
                                     n_samples=min(10000, args.n),
                                     device=device)

        # Combine and shuffle samples from physionet Train and physionet Test
        total_dataset = train_dataset_obj[:len(train_dataset_obj)]

        if not args.classif:
            # Concatenate samples from the original Train and Test sets.
            # Only 'training' PhysioNet samples have labels, so for the classification task the PhysioNet 'test' samples are not needed.
            total_dataset = total_dataset + test_dataset_obj[:len(
                test_dataset_obj)]

        # Shuffle and split
        train_data, test_data = model_selection.train_test_split(
            total_dataset, train_size=0.8, random_state=42, shuffle=True)

        record_id, tt, vals, mask, labels = train_data[0]

        n_samples = len(total_dataset)
        input_dim = vals.size(-1)

        batch_size = min(min(len(train_dataset_obj), args.batch_size), args.n)
        data_min, data_max = get_data_min_max(total_dataset)

        train_dataloader = DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: variable_time_collate_fn(
                batch,
                args,
                device,
                data_type="train",
                data_min=data_min,
                data_max=data_max))
        test_dataloader = DataLoader(
            test_data,
            batch_size=n_samples,
            shuffle=False,
            collate_fn=lambda batch: variable_time_collate_fn(
                batch,
                args,
                device,
                data_type="test",
                data_min=data_min,
                data_max=data_max))

        attr_names = train_dataset_obj.params
        data_objects = {
            "dataset_obj": train_dataset_obj,
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader),
            "attr": attr_names,  #optional
            "classif_per_tp": False,  #optional
            "n_labels": 1
        }  #optional
        return data_objects

    ##################################################################
    # Human activity dataset

    if dataset_name == "activity":
        n_samples = min(10000, args.n)
        dataset_obj = PersonActivity('data/PersonActivity',
                                     download=True,
                                     n_samples=n_samples,
                                     device=device)
        print(dataset_obj)
        # Use custom collate_fn to combine samples with arbitrary time observations.
        # Returns the dataset along with mask and time steps

        # Shuffle and split
        train_data, test_data = model_selection.train_test_split(
            dataset_obj, train_size=0.8, random_state=42, shuffle=True)

        train_data = [
            train_data[i] for i in np.random.choice(
                len(train_data), len(train_data), replace=False)
        ]
        test_data = [
            test_data[i] for i in np.random.choice(
                len(test_data), len(test_data), replace=False)
        ]

        record_id, tt, vals, mask, labels = train_data[0]
        input_dim = vals.size(-1)

        batch_size = min(min(len(dataset_obj), args.batch_size), args.n)
        train_dataloader = DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: variable_time_collate_fn_activity(
                batch, args, device, data_type="train"))
        test_dataloader = DataLoader(
            test_data,
            batch_size=n_samples,
            shuffle=False,
            collate_fn=lambda batch: variable_time_collate_fn_activity(
                batch, args, device, data_type="test"))

        data_objects = {
            "dataset_obj": dataset_obj,
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader),
            "classif_per_tp": True,  #optional
            "n_labels": labels.size(-1)
        }

        return data_objects

    if dataset_name == "markov_chain":

        def generate_data(num_samples, length):
            from hmmlearn import hmm

            model = hmm.GaussianHMM(n_components=3, covariance_type="full")
            model.startprob_ = np.array([0.6, 0.3, 0.1])
            model.transmat_ = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2],
                                        [0.3, 0.3, 0.4]])

            model.means_ = np.array([[0.0, 0.0], [3.0, -3.0], [5.0, -5.0]])
            model.covars_ = np.tile(np.identity(2), (3, 1, 1))

            dataset = []

            for i in range(num_samples):
                X, Z = model.sample(length)
                X = np.reshape(X[:, 0], (length, 1))
                dataset.append(torch.tensor(X))
            return torch.stack(dataset)

        dataset = generate_data(args.n, n_total_tp).float()
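        # dataset has shape [args.n, n_total_tp, 1]: only the first of the two HMM observation dimensions is kept.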
        time_steps_extrap = torch.arange(n_total_tp).float()

        dataset = dataset.to(device)
        time_steps_extrap = time_steps_extrap.to(device)

        train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

        n_samples = len(dataset)
        input_dim = dataset.size(-1)

        if args.mcar or args.mnar:
            collate_fn_train = lambda batch: data_ampute_batch_collate(
                batch, time_steps_extrap, args, device, data_type="train")
            collate_fn_test = lambda batch: data_ampute_batch_collate(
                batch, time_steps_extrap, args, device, data_type="test")
        else:
            collate_fn_train = lambda batch: basic_collate_fn(
                batch, time_steps_extrap, data_type="train")
            collate_fn_test = lambda batch: basic_collate_fn(
                batch, time_steps_extrap, data_type="test")

        batch_size = min(args.batch_size, args.n)
        train_dataloader = DataLoader(train_y,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      collate_fn=collate_fn_train)
        test_dataloader = DataLoader(test_y,
                                     batch_size=args.n,
                                     shuffle=False,
                                     collate_fn=collate_fn_test)

        data_objects = {  #"dataset_obj": dataset_obj, 
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader)
        }

        return data_objects

    ########### 1d datasets ###########

    # Sampling args.timepoints time points in the interval [0, args.max_t]
    # Sample points for both the training sequence and extrapolation (test)
    distribution = uniform.Uniform(torch.Tensor([0.0]),
                                   torch.Tensor([max_t_extrap]))
    time_steps_extrap = distribution.sample(torch.Size([n_total_tp - 1]))[:, 0]
    time_steps_extrap = torch.cat((torch.Tensor([0.0]), time_steps_extrap))
    time_steps_extrap = torch.sort(time_steps_extrap)[0]

    dataset_obj = None
    ##################################################################
    # Sample a periodic function
    if dataset_name == "periodic":
        dataset_obj = Periodic_1d(init_freq=None,
                                  init_amplitude=1.,
                                  final_amplitude=1.,
                                  final_freq=None,
                                  z0=1.)

    ##################################################################

    if dataset_obj is None:
        raise Exception("Unknown dataset: {}".format(dataset_name))

    # def make_quick_plot(data,time_steps_extrap, index):
    # 	import matplotlib.pyplot as plt

    # 	plt.figure()
    # 	plt.scatter(time_steps_extrap, data[index, :, 0])
    # 	plt.title(f"Dataset at index {index}")
    # 	plt.savefig(f"dataset_at_index_{index}.png")

    dataset = dataset_obj.sample_traj(time_steps_extrap,
                                      n_samples=args.n,
                                      noise_weight=args.noise_weight)

    # for i in range(3):
    # 	make_quick_plot(dataset, time_steps_extrap, i)

    # Process small datasets
    dataset = dataset.to(device)
    time_steps_extrap = time_steps_extrap.to(device)

    train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

    n_samples = len(dataset)
    input_dim = dataset.size(-1)

    if args.mcar or args.mnar:
        collate_fn_train = lambda batch: data_ampute_batch_collate(
            batch, time_steps_extrap, args, device, data_type="train")
        collate_fn_test = lambda batch: data_ampute_batch_collate(
            batch, time_steps_extrap, args, device, data_type="test")
    else:
        collate_fn_train = lambda batch: basic_collate_fn(
            batch, time_steps_extrap, data_type="train")
        collate_fn_test = lambda batch: basic_collate_fn(
            batch, time_steps_extrap, data_type="test")

    batch_size = min(args.batch_size, args.n)
    train_dataloader = DataLoader(train_y,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  collate_fn=collate_fn_train)
    test_dataloader = DataLoader(test_y,
                                 batch_size=args.n,
                                 shuffle=False,
                                 collate_fn=collate_fn_test)

    data_objects = {  #"dataset_obj": dataset_obj, 
        "train_dataloader": utils.inf_generator(train_dataloader),
        "test_dataloader": utils.inf_generator(test_dataloader),
        "input_dim": input_dim,
        "n_train_batches": len(train_dataloader),
        "n_test_batches": len(test_dataloader)
    }

    return data_objects
Example #4
def parse_datasets(args, device):
    def basic_collate_fn(batch, time_steps, args=args, data_type='train'):
        batch = torch.stack(batch)
        data_dict = {'data': batch, 'time_steps': time_steps}
        data_dict = utils.split_and_subsample_batch(data_dict,
                                                    args,
                                                    data_type=data_type)
        return data_dict

    dataset_name = args.dataset
    n_total_tp = args.timepoints + args.extrap
    max_t_extrap = args.max_t / args.timepoints * n_total_tp

    if dataset_name == 'hopper':
        dataset_obj = HopperPhysics(root='data',
                                    download=True,
                                    generate=False,
                                    device=device)
        dataset = dataset_obj.get_dataset()[:args.n]
        dataset = dataset.to(device)

        n_tp_data = dataset[:].shape[1]

        time_steps = torch.arange(start=0, end=n_tp_data,
                                  step=1).float().to(device)
        time_steps = time_steps / len(time_steps)
        dataset = dataset.to(device)
        time_steps = time_steps.to(device)

        if not args.extrap:

            n_traj = len(dataset)
            n_tp_data = dataset.shape[1]
            n_reduced_tp = args.timepoints
            start_ind = np.random.randint(0,
                                          high=n_tp_data - n_reduced_tp + 1,
                                          size=n_traj)
            end_ind = start_ind + n_reduced_tp
            sliced = []
            for i in range(n_traj):
                sliced.append(dataset[i, start_ind[i]:end_ind[i], :])
            dataset = torch.stack(sliced).to(device)
            time_steps = time_steps[:n_reduced_tp]

        train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

        n_samples = len(dataset)
        input_dim = dataset.size(-1)
        batch_size = min(args.batch_size, args.n)
        train_dataloader = DataLoader(
            train_y,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: basic_collate_fn(
                batch, time_steps, data_type="train"))
        test_dataloader = DataLoader(test_y,
                                     batch_size=n_samples,
                                     shuffle=False,
                                     collate_fn=lambda batch: basic_collate_fn(
                                         batch, time_steps, data_type="test"))

        data_objects = {
            "dataset_obj": dataset_obj,
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader)
        }
        return data_objects

    ########### 1d datasets ###########

    # Sampling args.timepoints time points in the interval [0, args.max_t]
    # Sample points for both the training sequence and extrapolation (test)
    distribution = uniform.Uniform(torch.Tensor([0.0]),
                                   torch.Tensor([max_t_extrap]))
    time_steps_extrap = distribution.sample(torch.Size([n_total_tp - 1]))[:, 0]
    time_steps_extrap = torch.cat((torch.Tensor([0.0]), time_steps_extrap))
    time_steps_extrap = torch.sort(time_steps_extrap)[0]

    dataset_obj = None
    ##################################################################
    # Sample a periodic function
    if dataset_name == "periodic":
        dataset_obj = Periodic_1d(init_freq=None,
                                  init_amplitude=1.,
                                  final_amplitude=1.,
                                  final_freq=None,
                                  z0=1.)

        dataset = dataset_obj.sample_traj(time_steps_extrap,
                                          n_samples=args.n,
                                          noise_weight=args.noise_weight)

        # Process small datasets
        dataset = dataset.to(device)
        time_steps_extrap = time_steps_extrap.to(device)

        train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

        n_samples = len(dataset)
        input_dim = dataset.size(-1)

        batch_size = min(args.batch_size, args.n)
        train_dataloader = DataLoader(
            train_y,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: basic_collate_fn(
                batch, time_steps_extrap, data_type="train"))
        test_dataloader = DataLoader(
            test_y,
            batch_size=args.n,
            shuffle=False,
            collate_fn=lambda batch: basic_collate_fn(
                batch, time_steps_extrap, data_type="test"))

        data_objects = {  #"dataset_obj": dataset_obj,
            "train_dataloader": utils.inf_generator(train_dataloader),
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader)
        }

        return data_objects
    time_steps_extrap = time_steps_extrap.to(device)

    train_y, test_y = utils.split_train_test(dataset, train_fraq=0.8)

    n_samples = len(dataset)
    input_dim = dataset.size(-1)

    batch_size = min(args.batch_size, args.n)
    train_dataloader = DataLoader(train_y, batch_size=batch_size, shuffle=False,
        collate_fn=lambda batch: variable_time_collate_fn_activity(batch, args, device, data_type="train"))
    test_dataloader = DataLoader(test_y, batch_size=n_samples, shuffle=False,
        collate_fn=lambda batch: variable_time_collate_fn_activity(batch, args, device, data_type="test"))

    data_obj = {"dataset_obj": dataset,
        "train_dataloader": utils.inf_generator(train_dataloader),
        "test_dataloader": utils.inf_generator(test_dataloader),
        "input_dim": input_dim,
        "n_train_batches": len(train_dataloader),
        "n_test_batches": len(test_dataloader)}

    # Create the model
    obsrv_std = 0.01
    if args.dataset == "hopper":
        obsrv_std = 1e-3

    obsrv_std = torch.Tensor([obsrv_std]).to(device)

    z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))

    if args.rnn_vae: