def main():
    """Run the test pass over the files listed in ``args.test_file_list``.

    Configuration comes from the module-level ``args``.
    - At most ``args.max_to_vis`` files are kept. If fewer are available,
      ``args.max_to_vis`` is lowered to that count.
    - The kept files are shuffled with a fixed seed.
    - They are wrapped in a batch-size-1 DataLoader.
    - ``args.output`` is created if needed, after a prompt when it exists.
    - ``test`` is then called with all-ones loss weights of length
      ``args.num_hierarchy_levels + 1``.
    """
    # ``raw_input`` exists only on Python 2; on Python 3 it raises NameError,
    # so fall back to the builtin ``input`` for the overwrite prompt.
    try:
        read_line = raw_input  # noqa: F821
    except NameError:
        read_line = input

    # data files
    test_files, _ = data_util.get_train_files(
        args.input_data_path, args.test_file_list, '')
    if len(test_files) > args.max_to_vis:
        test_files = test_files[:args.max_to_vis]
    else:
        args.max_to_vis = len(test_files)
    # NOTE(review): the shuffle runs after truncation. It only reorders the
    # first max_to_vis files; it does not sample a random subset. Confirm
    # this is intended.
    random.seed(42)
    random.shuffle(test_files)
    print('#test files = ', len(test_files))

    test_dataset = scene_dataloader.SceneDataset(
        test_files, args.input_dim, args.truncation, args.num_hierarchy_levels,
        args.max_input_height, 0, args.target_data_path)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=2,
        collate_fn=scene_dataloader.collate)

    if os.path.exists(args.output):
        # Blocks until the user confirms. Existing files are not removed here.
        read_line(
            'warning: output dir %s exists, press key to overwrite and continue'
            % args.output)
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    # start testing
    print('starting testing...')
    # One weight per hierarchy level plus one extra entry, all set to 1.
    loss_weights = np.ones(args.num_hierarchy_levels + 1, dtype=np.float32)
    test(loss_weights, test_dataloader, args.output, args.max_to_vis)
def val_dataloader(self):
    """Build the validation DataLoader from the configured file lists.

    ``input_dim`` is read from ``self.hparams.model``. The hierarchy levels,
    truncation, batch size, validation worker count and overfit count are
    read from ``self.hparams.train``. The validation files returned by
    ``self._get_train_files()`` are wrapped in a ``SceneDataset`` and an
    unshuffled DataLoader.

    Returns:
        A ``torch.utils.data.DataLoader`` over the validation scenes, or
        ``None`` when there are no validation files.
    """
    log.info('Validation data loader called.')
    # Only the validation file list is used here.
    _, _, val_files = self._get_train_files()
    input_dim = tuple(self.hparams.model.input_dim)
    num_hierarchy_levels = self.hparams.train.num_hierarchy_levels
    truncation = self.hparams.train.truncation
    batch_size = self.hparams.train.batch_size
    num_workers_valid = self.hparams.train.num_workers_valid
    num_overfit_val = self.hparams.train.num_overfit_val

    if len(val_files) == 0:
        # Without this guard, the final return referenced a local that was
        # never assigned and raised UnboundLocalError.
        return None

    val_dataset = scene_dataloader.SceneDataset(
        val_files, input_dim, truncation, num_hierarchy_levels, 0,
        num_overfit_val)
    print('val_dataset', len(val_dataset))
    loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers_valid, collate_fn=scene_dataloader.collate)
    return loader
def train_dataloader(self):
    """Build the shuffled training DataLoader.

    Dataset and loader settings are read from ``self.hparams``. When exactly
    one training file is present, the run is treated as an overfit run and
    ``SceneDataset`` receives an overfit count of 640 (otherwise 0).

    Side effect: sets ``self._iter_counter`` to ``start_epoch`` times the
    number of full batches per epoch. This is presumably so an iteration
    counter resumes consistently; confirm against its readers.

    Returns:
        A shuffled ``torch.utils.data.DataLoader`` over the training scenes.
    """
    log.info('Training data loader called.')
    _, train_files, val_files = self._get_train_files()
    input_dim = self.hparams.input_dim
    num_hierarchy_levels = self.hparams.num_hierarchy_levels
    truncation = self.hparams.truncation
    batch_size = self.hparams.batch_size
    num_workers_train = self.hparams.num_workers_train

    # A single training file signals an overfitting/debug run.
    overfit = len(train_files) == 1
    # TODO: overfit runs previously also disabled loss masking
    # (``args.use_loss_masking = False``); that is not applied here.
    num_overfit_train = 640 if overfit else 0

    print('#train files = ', len(train_files))
    print('#val files = ', len(val_files))

    train_dataset = scene_dataloader.SceneDataset(
        train_files, input_dim, truncation, num_hierarchy_levels, 0,
        num_overfit_train)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers_train, collate_fn=scene_dataloader.collate)

    # Number of full batches per epoch times the starting epoch.
    self._iter_counter = self.hparams.start_epoch * (len(train_dataset) // batch_size)
    return train_dataloader
def main():
    """Run the test pass over up to ``args.max_to_process`` files.

    Configuration comes from the module-level ``args``.
    - Test files are loaded and truncated to ``args.max_to_process``. If
      fewer are available, ``args.max_to_process`` is lowered to that count.
    - The kept files are shuffled with a fixed seed.
    - They are wrapped in a batch-size-1 DataLoader using
      ``scene_dataloader.collate_voxels``.
    - ``args.output`` and ``args.output/vis`` are prepared.
    - ``test`` is called with the vis path and ``args.num_to_vis``.
    """
    # data files
    test_files, _, _ = data_util.get_train_files(
        args.input_data_path, args.test_file_list, '', 0)
    if len(test_files) > args.max_to_process:
        test_files = test_files[:args.max_to_process]
    else:
        args.max_to_process = len(test_files)
    random.seed(42)
    random.shuffle(test_files)
    print('#test files = ', len(test_files))

    # NOTE(review): the positional ``True`` and the augmentation arguments are
    # interpreted by SceneDataset; confirm their meaning there.
    test_dataset = scene_dataloader.SceneDataset(
        test_files, args.input_dim, args.truncation, True,
        args.augment_rgb_scaling,
        (args.augment_scale_min, args.augment_scale_max),
        args.color_truncation, args.color_space,
        target_path=args.target_data_path,
        max_input_height=args.max_input_height)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=2,
        collate_fn=scene_dataloader.collate_voxels)

    if os.path.exists(args.output):
        if args.vis_only:
            # The existing directory is kept; only a warning is printed.
            # The path is now interpolated (the old message printed a literal '%s').
            print('warning: output dir %s exists, will overwrite any existing files'
                  % args.output)
        else:
            # Wait for confirmation, then delete the whole output directory.
            input('warning: output dir %s exists, press key to delete and continue'
                  % args.output)
            shutil.rmtree(args.output)
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    output_vis_path = os.path.join(args.output, 'vis')
    if not os.path.exists(output_vis_path):
        os.makedirs(output_vis_path)

    # start testing
    print('starting testing...')
    test(test_dataloader, output_vis_path, args.num_to_vis)
last_epoch=last_epoch) # data files train_files, val_files = data_util.get_train_files(args.data_path, args.train_file_list, args.val_file_list) _OVERFIT = False if len(train_files) == 1: _OVERFIT = True args.use_loss_masking = False num_overfit_train = 0 if not _OVERFIT else 640 num_overfit_val = 0 if not _OVERFIT else 160 print('#train files = ', len(train_files)) print('#val files = ', len(val_files)) train_dataset = scene_dataloader.SceneDataset(train_files, args.input_dim, args.truncation, args.num_hierarchy_levels, 0, num_overfit_train) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, collate_fn=scene_dataloader.collate) if len(val_files) > 0: val_dataset = scene_dataloader.SceneDataset(val_files, args.input_dim, args.truncation, args.num_hierarchy_levels, 0, num_overfit_val) print('val_dataset', len(val_dataset)) val_dataloader = torch.utils.data.DataLoader( val_dataset,