Example #1
    def get_data_loaders(self):
        data_dir = self.args.data_dir

        self.train_dataset = PickleDataset(
            os.path.join(data_dir, f'{self.args.train_set}.pkl'),
            os.path.join(data_dir, self.args.train_index_file),
            segment_size=self.config.segment_size)

        self.val_dataset = PickleDataset(
            os.path.join(data_dir, f'{self.args.val_set}.pkl'),
            os.path.join(data_dir, self.args.val_index_file),
            segment_size=self.config.segment_size)

        self.train_loader = get_data_loader(self.train_dataset,
                                            batch_size=self.config.batch_size,
                                            shuffle=self.config.shuffle,
                                            num_workers=4,
                                            drop_last=False)

        self.val_loader = get_data_loader(self.val_dataset,
                                          batch_size=self.config.batch_size,
                                          shuffle=self.config.shuffle,
                                          num_workers=4,
                                          drop_last=False)

        # wrap the training loader so batches can be drawn indefinitely during training
        self.train_iter = infinite_iter(self.train_loader)
        return
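All of the examples in this listing call get_data_loader and infinite_iter, and neither helper is shown here. The sketch below is a minimal guess at what they might look like, assuming PickleDataset is a standard PyTorch Dataset; the signature of get_data_loader (including the frame_size argument used in later examples) is inferred from the call sites, not taken from the original project.

from torch.utils.data import DataLoader


def get_data_loader(dataset, batch_size, shuffle, num_workers=0,
                    drop_last=False, frame_size=None):
    # Hypothetical wrapper around torch.utils.data.DataLoader.  The real helper
    # probably also uses frame_size to build a collate function; that detail is
    # not visible from the calls above, so it is ignored here.
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      drop_last=drop_last)


def infinite_iter(loader):
    # Yield batches forever, restarting the underlying loader whenever it is
    # exhausted, so a step-based training loop can simply call next() on it.
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)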
Example #2
    def get_data_loaders(self):
        data_dir = self.args.data_dir
        # number of visible GPUs; fall back to 1 when CUDA is unavailable
        self.gpu_num = torch.cuda.device_count() if torch.cuda.is_available() else 1
        self.train_dataset = PickleDataset(
            os.path.join(data_dir, f'{self.args.train_set}.pkl'),
            os.path.join(data_dir, self.args.train_index_file),
            segment_size=self.config['data_loader']['segment_size'])
        self.train_loader = get_data_loader(
            self.train_dataset,
            frame_size=self.config['data_loader']['frame_size'],
            batch_size=self.config['data_loader']['batch_size'] * self.gpu_num,
            num_workers=0,
            shuffle=self.config['data_loader']['shuffle'],
            drop_last=False)
        self.train_iter = infinite_iter(self.train_loader)

        if self.args.use_eval_set:
            self.eval_dataset = PickleDataset(
                os.path.join(data_dir, f'{self.args.eval_set}.pkl'),
                os.path.join(data_dir, self.args.eval_index_file),
                segment_size=self.config['data_loader']['segment_size'])

            self.eval_loader = get_data_loader(
                self.eval_dataset,
                frame_size=self.config['data_loader']['frame_size'],
                batch_size=self.config['data_loader']['batch_size'] * self.gpu_num,
                shuffle=self.config['data_loader']['shuffle'],
                num_workers=0,
                drop_last=False)
            self.eval_iter = infinite_iter(self.eval_loader)

        if self.args.use_test_set:
            self.test_dataset = PickleDataset(
                os.path.join(data_dir, f'{self.args.test_set}.pkl'),
                os.path.join(data_dir, self.args.test_index_file),
                segment_size=self.config['data_loader']['segment_size'])

            self.test_loader = get_data_loader(
                self.test_dataset,
                frame_size=self.config['data_loader']['frame_size'],
                batch_size=self.config['data_loader']['batch_size'],
                shuffle=False,
                num_workers=0,
                drop_last=False)
            self.test_iter = infinite_iter(self.test_loader)

        return
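Example #2 is the only one that scales the batch size by torch.cuda.device_count(). That pattern usually goes together with torch.nn.DataParallel, which splits each incoming batch across the visible GPUs so that every device still processes roughly the configured per-GPU batch size. A minimal sketch under that assumption (the model used here is a placeholder; the examples do not show how the real network is constructed):

import torch
import torch.nn as nn

# placeholder model; the original examples do not show the actual network
model = nn.Linear(128, 128)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # DataParallel splits each batch of size batch_size * gpu_num across the
    # GPUs, so each device sees about `batch_size` samples per step.
    model = nn.DataParallel(model)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)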
Example #3
    def get_data_loaders(self):
        data_dir = self.args.data_dir

        self.test_dataset = PickleDataset(
            os.path.join(data_dir, f'{self.args.test_set}.pkl'),
            os.path.join(data_dir, self.args.test_index_file),
            segment_size=self.config['data_loader']['segment_size'])

        self.test_loader = get_data_loader(
            self.test_dataset,
            frame_size=self.config['data_loader']['frame_size'],
            batch_size=self.config['data_loader']['batch_size'],
            shuffle=False,
            drop_last=False)
Example #4
    def get_data_loaders(self):
        data_dir = self.args.data_dir
        self.train_dataset = PickleDataset(
            os.path.join(data_dir, f'{self.args.train_set}.pkl'),
            os.path.join(data_dir, self.args.train_index_file),
            segment_size=self.config['data_loader']['segment_size'])
        self.train_loader = get_data_loader(
            self.train_dataset,
            frame_size=self.config['data_loader']['frame_size'],
            batch_size=self.config['data_loader']['batch_size'],
            shuffle=self.config['data_loader']['shuffle'],
            num_workers=4,
            drop_last=False)
        self.train_iter = infinite_iter(self.train_loader)
        return
Example #5
    def get_data_loaders(self):
        data_dir = self.args.data_dir
        self.train_dataset = PickleDataset(
            os.path.join(data_dir, f"{self.args.train_set}.pkl"),
            os.path.join(data_dir, self.args.train_index_file),
            segment_size=self.config["data_loader"]["segment_size"],
        )
        self.train_loader = get_data_loader(
            self.train_dataset,
            frame_size=self.config["data_loader"]["frame_size"],
            batch_size=self.config["data_loader"]["batch_size"],
            shuffle=self.config["data_loader"]["shuffle"],
            num_workers=0,
            drop_last=False,
        )
        self.train_iter = infinite_iter(self.train_loader)
        return
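Because infinite_iter never raises StopIteration, training loops built on these loaders are usually step-based rather than epoch-based: they pull batches with next(self.train_iter) for a fixed number of iterations. A hypothetical usage sketch (n_iterations and train_step are assumptions, not part of the examples above):

    def train(self, n_iterations):
        for step in range(n_iterations):
            # infinite_iter restarts the loader internally, so next() never fails
            batch = next(self.train_iter)
            loss = self.train_step(batch)  # hypothetical per-batch update step
            if step % 100 == 0:
                print(f'step {step}: loss = {loss:.4f}')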