class Visualization_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.output_dir = output_dir

    def run_on_images(self, imgs, sid, mid, iid, sampled_idx):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        deprocess = imagenet_deprocess(rescale_image=False)
        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)

        mesh = cubify(generated_volume, 0.3)
        # mesh = voxel_to_world(meshes)
        save_mesh = os.path.join(dir1, "%s_%s.obj" % (iid, sampled_idx))
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save_mesh, verts, faces)

        generated_volume = generated_volume.squeeze()
        img = image_to_numpy(deprocess(imgs[0][0]))
        save_img = os.path.join(dir1, "%02d.png" % (iid))
        # cv2.imwrite(save_img, img[:, :, ::-1])
        cv2.imwrite(save_img, img)
        img1 = image_to_numpy(deprocess(imgs[0][1]))
        save_img1 = os.path.join(dir1, "%02d.png" % (sampled_idx))
        cv2.imwrite(save_img1, img1)
        # cv2.imwrite(save_img1, img1[:, :, ::-1])
        get_volume_views(generated_volume, dir1, iid, sampled_idx)
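# Usage sketch for Visualization_demo (illustrative, not from the project:
# it assumes `cfg` is the loaded project config with cfg.CHECKPOINT set and
# `imgs` is a preprocessed [batch, n_views, C, H, W] float tensor; the
# sid/mid/iid/sampled_idx identifiers below are hypothetical values):
#
#     demo = Visualization_demo(cfg, output_dir='./vis_output')
#     demo.run_on_images(imgs, sid='02691156', mid='model_0001', iid=0, sampled_idx=1)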
class Quantitative_analysis_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)
        # self.thresh = cfg.VOXEL_THRESH
        self.th = cfg.TEST.VOXEL_THRESH

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)
        self.output_dir = output_dir

    def calculate_iou(self, imgs, GT_voxels, sid, mid, iid):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)
        generated_volume = generated_volume.squeeze()

        sample_iou = []
        for th in self.th:
            _volume = torch.ge(generated_volume, th).float()
            intersection = torch.sum(_volume.mul(GT_voxels)).float()
            union = torch.sum(torch.ge(_volume.add(GT_voxels), 1)).float()
            sample_iou.append((intersection / union).item())
        return sample_iou
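# Standalone sketch of the thresholded voxel IoU computed in calculate_iou
# above (assumes only torch; the grid size and threshold are illustrative):
import torch

def voxel_iou(pred, gt, th=0.3):
    """IoU between predicted occupancy probabilities and a binary ground
    truth grid, after binarizing the prediction at threshold `th`."""
    binary = torch.ge(pred, th).float()
    intersection = torch.sum(binary.mul(gt)).float()
    union = torch.sum(torch.ge(binary.add(gt), 1)).float()
    return (intersection / union).item()

# Example on random 32^3 grids:
pred = torch.rand(32, 32, 32)
gt = (torch.rand(32, 32, 32) > 0.5).float()
print('IoU@0.3 = %.4f' % voxel_iou(pred, gt, th=0.3))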
def __init__(self, cfg, output_dir):
    self.encoder = Encoder(cfg)
    self.decoder = Decoder(cfg)
    self.refiner = Refiner(cfg)
    self.merger = Merger(cfg)
    # self.thresh = cfg.VOXEL_THRESH
    self.th = cfg.TEST.VOXEL_THRESH

    checkpoint = torch.load(cfg.CHECKPOINT)
    encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
    self.encoder.load_state_dict(encoder_state_dict)
    decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
    self.decoder.load_state_dict(decoder_state_dict)
    if cfg.NETWORK.USE_REFINER:
        refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
        self.refiner.load_state_dict(refiner_state_dict)
    if cfg.NETWORK.USE_MERGER:
        merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
        self.merger.load_state_dict(merger_state_dict)
    self.output_dir = output_dir
def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
    super().__init__()
    self.cfg_network = cfg_network
    self.cfg_tester = cfg_tester

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up networks
    self.encoder = Encoder(cfg_network)
    self.decoder = Decoder(cfg_network)
    self.refiner = Refiner(cfg_network)
    self.merger = Merger(cfg_network)

    # Initialize weights of networks
    self.encoder.apply(utils.network_utils.init_weights)
    self.decoder.apply(utils.network_utils.init_weights)
    self.refiner.apply(utils.network_utils.init_weights)
    self.merger.apply(utils.network_utils.init_weights)

    self.bce_loss = nn.BCELoss()
def __init__(self, cfg, output_dir):
    self.encoder = Encoder(cfg)
    self.decoder = Decoder(cfg)
    self.refiner = Refiner(cfg)
    self.merger = Merger(cfg)

    checkpoint = torch.load(cfg.CHECKPOINT)
    encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
    self.encoder.load_state_dict(encoder_state_dict)
    decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
    self.decoder.load_state_dict(decoder_state_dict)
    if cfg.NETWORK.USE_REFINER:
        refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
        self.refiner.load_state_dict(refiner_state_dict)
    if cfg.NETWORK.USE_MERGER:
        merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
        self.merger.load_state_dict(merger_state_dict)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    self.output_dir = output_dir
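# The `clean_state_dict` helper used in the constructors above is not shown
# in this section; a plausible minimal implementation (an assumption, not
# the project's code) strips the 'module.' prefix that torch.nn.DataParallel
# prepends to parameter names when a checkpoint is saved from a wrapped model:
from collections import OrderedDict

def clean_state_dict(state_dict):
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        cleaned[name] = v
    return cleaned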
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST, cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(utils.data_loaders.DatasetType.TRAIN,
                                                 cfg.CONST.N_VIEWS_RENDERING, train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(utils.data_loaders.DatasetType.VAL,
                                               cfg.CONST.N_VIEWS_RENDERING, val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' % (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' % (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' % (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' % (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solver
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
                                                                gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
                                                                gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                                milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
                                                                gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(merger_solver,
                                                               milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
                                                               gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])
        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
              % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()
            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss', encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(), n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                  % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches,
                     batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
              % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time,
                 encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d'
                  % (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer,
                       encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            utils.network_utils.save_checkpoints(cfg,
                                                 os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)),
                                                 epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                                                 refiner, refiner_solver, merger, merger_solver,
                                                 best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(cfg,
                                                 os.path.join(ckpt_dir, 'best-ckpt.pth'),
                                                 epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                                                 refiner, refiner_solver, merger, merger_solver,
                                                 best_iou, best_epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
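# Minimal sketch of the two-loss backward pattern used in the training loop
# above: the encoder/decoder loss is backpropagated with retain_graph=True so
# the shared graph is still available for the refiner loss (toy nn.Linear
# modules stand in for the real networks; shapes are illustrative).
import torch
import torch.nn as nn

net_a = nn.Linear(4, 4)                  # stands in for encoder/decoder
net_b = nn.Linear(4, 4)                  # stands in for the refiner
x, target = torch.randn(2, 4), torch.randn(2, 4)

mid = net_a(x)
out = net_b(mid)
loss_a = nn.functional.mse_loss(mid, target)
loss_b = nn.functional.mse_loss(out, target)

loss_a.backward(retain_graph=True)       # keep the graph for the second pass
loss_b.backward()                        # reuses the retained graph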
def test_single_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = 'D:/Pix2Vox/Pix2Vox/pretrained/Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    # Strip the 'module.' prefix that DataParallel left on every key
    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = 'D:/Pix2Vox/Pix2Vox/rand/minecraft.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)
        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = 'D:/Pix2Vox/Pix2Vox/output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    print(gv_new)
    rendering_views = utils.binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir), epoch_idx)
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None,
             test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss
            print("vox shape {}".format(generated_volume.shape))

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gv, os.path.join(img_dir, 'test'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx,
                                          rendering_views, epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gtv, os.path.join(img_dir, 'test'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx,
                                          rendering_views, epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s'
                  % (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name,
                     encoder_loss.item(), refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/a', end='\t\t')
        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
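# Standalone sketch of the aggregation at the end of test_net above:
# per-sample IoU rows (one value per threshold) are averaged within each
# taxonomy, then combined into an overall mean weighted by sample counts
# (the taxonomy names and numbers below are illustrative).
import numpy as np

test_iou = {
    'chair': {'n_samples': 2, 'iou': [[0.50, 0.60], [0.70, 0.80]]},
    'table': {'n_samples': 1, 'iou': [[0.40, 0.50]]},
}
n_samples = 3
mean_iou = []
for tax in test_iou:
    test_iou[tax]['iou'] = np.mean(test_iou[tax]['iou'], axis=0)
    mean_iou.append(test_iou[tax]['iou'] * test_iou[tax]['n_samples'])
mean_iou = np.sum(mean_iou, axis=0) / n_samples
print(mean_iou)   # one overall IoU per threshold: [0.5333... 0.6333...]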
def train_net(cfg):
    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST, cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = paddle.io.DataLoader(
        dataset=train_dataset_loader.get_dataset(utils.data_loaders.DatasetType.TRAIN,
                                                 cfg.CONST.N_VIEWS_RENDERING, train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        # num_workers=0,  # fails when cfg.TRAIN.NUM_WORKER > 0 because /dev/shm is
        #                 # too small; see https://blog.csdn.net/ctypyb2002/article/details/107914643
        # pin_memory=True,
        use_shared_memory=False,
        shuffle=True,
        drop_last=True)
    val_data_loader = paddle.io.DataLoader(
        dataset=val_dataset_loader.get_dataset(utils.data_loaders.DatasetType.VAL,
                                               cfg.CONST.N_VIEWS_RENDERING, val_transforms),
        batch_size=1,
        # num_workers=1,
        # pin_memory=True,
        shuffle=False)

    # Set up networks
    # paddle.Model prepare / fit / save
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)
    refiner = Refiner(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' % (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' % (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Merger: %d.' % (dt.now(), utils.network_utils.count_parameters(merger)))
    print('[DEBUG] %s Parameters in Refiner: %d.' % (dt.now(), utils.network_utils.count_parameters(refiner)))

    # Initialize weights of networks
    # Paddle parameter initialization works differently; see the Paddle API docs.
    # encoder.apply(utils.network_utils.init_weights)
    # decoder.apply(utils.network_utils.init_weights)
    # merger.apply(utils.network_utils.init_weights)

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.ENCODER_LEARNING_RATE,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    decoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.DECODER_LEARNING_RATE,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    merger_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.MERGER_LEARNING_RATE,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    refiner_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.REFINER_LEARNING_RATE,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)

    # Set up solver
    # if cfg.TRAIN.POLICY == 'adam':
    encoder_solver = paddle.optimizer.Adam(learning_rate=encoder_lr_scheduler,
                                           parameters=encoder.parameters())
    decoder_solver = paddle.optimizer.Adam(learning_rate=decoder_lr_scheduler,
                                           parameters=decoder.parameters())
    merger_solver = paddle.optimizer.Adam(learning_rate=merger_lr_scheduler,
                                          parameters=merger.parameters())
    refiner_solver = paddle.optimizer.Adam(learning_rate=refiner_lr_scheduler,
                                           parameters=refiner.parameters())

    # if torch.cuda.is_available():
    #     encoder = torch.nn.DataParallel(encoder).cuda()
    #     decoder = torch.nn.DataParallel(decoder).cuda()
    #     merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        encoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder_solver.pdopt"))
        encoder.set_state_dict(encoder_state_dict)
        encoder_solver.set_state_dict(encoder_solver_state_dict)

        decoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder_solver.pdopt"))
        decoder.set_state_dict(decoder_state_dict)
        decoder_solver.set_state_dict(decoder_solver_state_dict)

        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "merger_solver.pdopt"))
            merger.set_state_dict(merger_state_dict)
            merger_solver.set_state_dict(merger_solver_state_dict)

        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "refiner.pdparams"))
            refiner_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "refiner_solver.pdopt"))
            refiner.set_state_dict(refiner_state_dict)
            refiner_solver.set_state_dict(refiner_solver_state_dict)

        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
              % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer (VisualDL instead of TensorBoard)
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    # train_writer = SummaryWriter()
    # val_writer = SummaryWriter(os.path.join(log_dir, 'test'))
    train_writer = LogWriter(os.path.join(log_dir, 'train'))
    val_writer = LogWriter(os.path.join(log_dir, 'val'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (rendering_images, ground_truth_volumes) in enumerate(train_data_loader()):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(ground_truth_volumes)

            # Train the encoder, decoder, merger, and refiner
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = paddle.mean(generated_volumes, axis=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder_solver.clear_grad()
            decoder_solver.clear_grad()
            merger_solver.clear_grad()
            refiner_solver.clear_grad()
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()
            encoder_solver.step()
            decoder_solver.step()
            merger_solver.step()
            refiner_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.numpy())
            refiner_losses.update(refiner_loss.numpy())
            # Append loss to VisualDL
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar(tag='EncoderDecoder/BatchLoss', step=n_itr, value=encoder_loss.numpy())
            train_writer.add_scalar('Refiner/BatchLoss', value=refiner_loss.numpy(), step=n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            if (batch_idx % int(cfg.CONST.INFO_BATCH)) == 0:
                print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                      % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches,
                         batch_time.val, data_time.val, encoder_loss.numpy(), refiner_loss.numpy()))

        # Append epoch loss to VisualDL
        train_writer.add_scalar(tag='EncoderDecoder/EpochLoss', step=epoch_idx + 1, value=encoder_losses.avg)
        train_writer.add_scalar('Refiner/EpochLoss', value=refiner_losses.avg, step=epoch_idx + 1)

        # Update the schedulers once per epoch
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        merger_lr_scheduler.step()
        refiner_lr_scheduler.step()

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
              % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time,
                 encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d'
                  % (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer,
                      encoder, decoder, merger, refiner)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'ckpt-epoch-%04d' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                merger, merger_solver, refiner, refiner_solver, best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt'),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                merger, merger_solver, refiner, refiner_solver, best_iou, best_epoch)
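# Minimal sketch of the Paddle scheduler/optimizer wiring used above: the
# MultiStepDecay object is passed as the optimizer's learning_rate and then
# stepped once per epoch (assumes paddlepaddle is installed; the layer size,
# milestones, and loop length are illustrative).
import paddle

layer = paddle.nn.Linear(4, 4)
scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.001,
                                               milestones=[30, 60],
                                               gamma=0.5)
solver = paddle.optimizer.Adam(learning_rate=scheduler,
                               parameters=layer.parameters())

for epoch in range(2):                    # toy training loop
    loss = paddle.mean(layer(paddle.randn([2, 4])))
    loss.backward()
    solver.step()
    solver.clear_grad()
    scheduler.step()                      # decay once per epoch, as above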
def test_single_img_net(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    # Strip the 'module.' prefix that DataParallel left on every key
    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/media/caig/FECA2C89CA2C406F/dataset/ShapeNetRendering_copy/03001627/1a74a83fa6d24b3cacd67ce2c72c02e/rendering/00.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)
        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/media/caig/FECA2C89CA2C406F/sketch3D/sketch3D/test_output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    rendering_views = utils.binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir), epoch_idx)
def test_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = '/Users/pranavpomalapally/Downloads/new-Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    # This checkpoint was saved without DataParallel's 'module.' prefix, so the
    # state dicts load directly; no key rewriting such as
    # OrderedDict((k.split('module.')[1:][0], v) for k, v in ...) is needed here.
    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    print('Use refiner')
    refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    # img1_path = '/Users/pranavpomalapally/Downloads/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
    img1_path = '/Users/pranavpomalapally/Downloads/09 copy.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)
        # The refiner is applied unconditionally here; the usual USE_REFINER /
        # EPOCH_START_USE_REFINER guard is deliberately bypassed.
        generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/Users/pranavpomalapally/Downloads/outputs'
    gv = generated_volume.cpu().detach().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    rendering_views = utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             merger=None,
             refiner=None):
    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = paddle.io.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            # num_workers=1,
            shuffle=False)
        mode = 'test'
    else:
        mode = 'val'

    # paddle.io.Dataset does not support 'str' inputs, so the taxonomy id and
    # sample name of every sample are collected separately, in loader order.
    dataset_taxonomy = None
    rendering_image_path_template = cfg.DATASETS.SHAPENET.RENDERING_PATH
    volume_path_template = cfg.DATASETS.SHAPENET.VOXEL_PATH

    # Load all taxonomies of the dataset
    with open('./datasets/ShapeNet.json', encoding='utf-8') as file:
        dataset_taxonomy = json.loads(file.read())

    all_test_taxonomy_id_and_sample_name = []
    # Load data for each category
    for taxonomy in dataset_taxonomy:
        taxonomy_folder_name = taxonomy['taxonomy_id']
        samples = taxonomy[mode]
        for sample in samples:
            all_test_taxonomy_id_and_sample_name.append([taxonomy_folder_name, sample])
    print('[INFO] Collected files of %sset' % (mode))

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        merger = Merger(cfg)
        refiner = Refiner(cfg)

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        encoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder.set_state_dict(encoder_state_dict)
        decoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder.set_state_dict(decoder_state_dict)
        refiner_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "refiner.pdparams"))
        refiner.set_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger.set_state_dict(merger_state_dict)

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    merger.eval()
    refiner.eval()

    for sample_idx, (rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = all_test_taxonomy_id_and_sample_name[sample_idx][0]
        sample_name = all_test_taxonomy_id_and_sample_name[sample_idx][1]

        with paddle.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = paddle.mean(generated_volume, axis=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.numpy())
            refiner_losses.update(refiner_loss.numpy())

            # IoU per sample (Paddle equivalents of the torch.ge / mul / add version)
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = paddle.greater_equal(generated_volume, paddle.to_tensor(th)).astype("float32")
                intersection = paddle.sum(paddle.multiply(_volume, ground_truth_volume))
                union = paddle.sum(
                    paddle.greater_equal(
                        paddle.add(_volume, ground_truth_volume).astype("float32"),
                        paddle.to_tensor(1., dtype='float32')).astype("float32"))
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to VisualDL
            if output_dir and sample_idx < 1:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gv, os.path.join(img_dir, 'Reconstructed'), epoch_idx)
                test_writer.add_image(tag='Reconstructed', img=rendering_views, step=epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gtv, os.path.join(img_dir, 'GroundTruth'), epoch_idx)
                test_writer.add_image(tag='GroundTruth', img=rendering_views, step=epoch_idx)

            # # Print sample loss and IoU
            # print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
            #       (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss, refiner_loss,
            #        ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/a', end='\t\t')
        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to VisualDL
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar(tag='EncoderDecoder/EpochLoss', value=encoder_losses.avg, step=epoch_idx)
        test_writer.add_scalar(tag='Refiner/EpochLoss', value=refiner_losses.avg, step=epoch_idx)
        test_writer.add_scalar(tag='Refiner/IoU', value=max_iou, step=epoch_idx)

    return max_iou
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None,
             test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        if torch.cuda.is_available():
            checkpoint = torch.load(cfg.CONST.WEIGHTS)
        else:
            map_location = torch.device('cpu')
            checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=map_location)
        epoch_idx = checkpoint['epoch_idx']
        print('Epoch ID of the current model is {}'.format(epoch_idx))
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    print("test data loader type is {}".format(type(test_data_loader)))
    for sample_idx, (taxonomy_id, sample_name, rendering_images) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        print("sample IDx {}".format(sample_idx))
        print("taxonomy id {}".format(taxonomy_id))

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            print("Shape of the loaded images {}".format(rendering_images.shape))

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)
            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            if cfg.NETWORK.USE_REFINER:
                generated_volume = refiner(generated_volume)
            print("vox shape {}".format(generated_volume.shape))

        gv = generated_volume.cpu().numpy()
        rendering_views = utils.binvox_visualization.get_volume_views(
            gv, os.path.join('./LargeDatasets/inference_images/', 'inference'), sample_idx)
        print("gv shape is {}".format(gv.shape))
        # Return the first reconstructed volume together with its input views
        return gv, rendering_images
def test_net(cfg,
             model_type,
             dataset_type,
             results_file_name,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None,
             save_results_to_file=False,
             show_voxels=False,
             path_to_times_csv=None):
    # Pix2Vox-A and Pix2Vox++/A are the variants that include a refiner
    use_refiner = model_type in (Pix2VoxTypes.Pix2Vox_A,
                                 Pix2VoxTypes.Pix2Vox_Plus_Plus_A)

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(dataset_type,
                                               cfg.CONST.N_VIEWS_RENDERING,
                                               test_transforms),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks and restore weights from the checkpoint
    if decoder is None or encoder is None:
        encoder = Encoder(cfg, model_type)
        decoder = Decoder(cfg, model_type)
        if use_refiner:
            refiner = Refiner(cfg)
        merger = Merger(cfg, model_type)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            if use_refiner:
                refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        logging.info('Loading weights from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    if use_refiner:
        refiner_losses = AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    if use_refiner:
        refiner.eval()
    merger.eval()

    samples_names = []
    edlosses = []
    rlosses = []
    ious_dict = {}
    for iou_threshold in cfg.TEST.VOXEL_THRESH:
        ious_dict[iou_threshold] = []

    if path_to_times_csv is not None:
        n_view_list = []
        times_list = []

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volume = utils.helpers.var_or_cuda(ground_truth_volume)

            if path_to_times_csv is not None:
                start_time = time.time()

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            if path_to_times_csv is not None:
                end_time = time.time()
                n_view_list.append(rendering_images.size()[1])
                times_list.append(end_time - start_time)

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            if use_refiner:
                refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())
                ious_dict[th].append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Optionally show the predicted and ground-truth volumes in viewvox
            if show_voxels:
                with open("model.binvox", "wb") as f:
                    v = br.Voxels(torch.ge(generated_volume, 0.2).float().cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)
                subprocess.run([VIEWVOX_EXE, "model.binvox"])

                with open("model.binvox", "wb") as f:
                    v = br.Voxels(ground_truth_volume.cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)
                subprocess.run([VIEWVOX_EXE, "model.binvox"])

            # Print sample loss and IoU
            logging.info('Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s'
                         % (sample_idx + 1, n_samples, taxonomy_id, sample_name,
                            encoder_loss.item(), refiner_loss.item(),
                            ['%.4f' % si for si in sample_iou]))

            samples_names.append(sample_name)
            edlosses.append(encoder_loss.item())
            if use_refiner:
                rlosses.append(refiner_loss.item())

    if save_results_to_file:
        save_test_results_to_csv(samples_names, edlosses, rlosses, ious_dict,
                                 path_to_csv=results_file_name)

    if path_to_times_csv is not None:
        save_times_to_csv(times_list, n_view_list, path_to_csv=path_to_times_csv)

    # Output testing results: per-taxonomy mean IoU, weighted by sample count,
    # aggregated into an overall mean per threshold
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()

    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        if use_refiner:
            test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
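# Hedged usage sketch for test_net. Only the signature above is taken from this
# file; the `config` module, the checkpoint path, and the 'test' dataset_type
# value are assumptions and may differ in the surrounding repo.
if __name__ == '__main__':
    from config import cfg  # assumed: the repo-wide config object
    cfg.CONST.WEIGHTS = './output/checkpoints/best-ckpt.pth'  # assumed path
    max_iou = test_net(cfg,
                       model_type=Pix2VoxTypes.Pix2Vox_A,
                       dataset_type='test',  # assumed value; get_dataset defines the accepted types
                       results_file_name='test_results.csv',
                       save_results_to_file=True)
    print('Best mean IoU over thresholds: %.4f' % max_iou)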
class Model(pl.LightningModule):
    def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
        super().__init__()
        self.cfg_network = cfg_network
        self.cfg_tester = cfg_tester

        # training_step steps the optimizers by hand, so Lightning's automatic
        # optimization must be disabled
        self.automatic_optimization = False

        # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
        torch.backends.cudnn.benchmark = True

        # Set up networks
        self.encoder = Encoder(cfg_network)
        self.decoder = Decoder(cfg_network)
        self.refiner = Refiner(cfg_network)
        self.merger = Merger(cfg_network)

        # Initialize weights of networks
        self.encoder.apply(utils.network_utils.init_weights)
        self.decoder.apply(utils.network_utils.init_weights)
        self.refiner.apply(utils.network_utils.init_weights)
        self.merger.apply(utils.network_utils.init_weights)

        self.bce_loss = nn.BCELoss()

    def configure_optimizers(self):
        params = self.cfg_network.optimization

        # Set up solvers, one per sub-network
        if params.policy == 'adam':
            encoder_solver = optim.Adam(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                        lr=params.encoder_lr,
                                        betas=params.betas)
            decoder_solver = optim.Adam(self.decoder.parameters(),
                                        lr=params.decoder_lr,
                                        betas=params.betas)
            refiner_solver = optim.Adam(self.refiner.parameters(),
                                        lr=params.refiner_lr,
                                        betas=params.betas)
            merger_solver = optim.Adam(self.merger.parameters(),
                                       lr=params.merger_lr,
                                       betas=params.betas)
        elif params.policy == 'sgd':
            encoder_solver = optim.SGD(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                       lr=params.encoder_lr,
                                       momentum=params.momentum)
            decoder_solver = optim.SGD(self.decoder.parameters(),
                                       lr=params.decoder_lr,
                                       momentum=params.momentum)
            refiner_solver = optim.SGD(self.refiner.parameters(),
                                       lr=params.refiner_lr,
                                       momentum=params.momentum)
            merger_solver = optim.SGD(self.merger.parameters(),
                                      lr=params.merger_lr,
                                      momentum=params.momentum)
        else:
            raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), params.policy))

        # Set up learning rate schedulers to decay learning rates dynamically
        encoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                              milestones=params.encoder_lr_milestones,
                                                              gamma=params.gamma)
        decoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                              milestones=params.decoder_lr_milestones,
                                                              gamma=params.gamma)
        refiner_lr_scheduler = optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                              milestones=params.refiner_lr_milestones,
                                                              gamma=params.gamma)
        merger_lr_scheduler = optim.lr_scheduler.MultiStepLR(merger_solver,
                                                             milestones=params.merger_lr_milestones,
                                                             gamma=params.gamma)

        return [encoder_solver, decoder_solver, refiner_solver, merger_solver], \
               [encoder_lr_scheduler, decoder_lr_scheduler, refiner_lr_scheduler, merger_lr_scheduler]

    def _fwd(self, batch):
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch

        image_features = self.encoder(rendering_images)
        raw_features, generated_volumes = self.decoder(image_features)

        if self.cfg_network.use_merger and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_merger:
            generated_volumes = self.merger(raw_features, generated_volumes)
        else:
            generated_volumes = torch.mean(generated_volumes, dim=1)
        encoder_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10

        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            generated_volumes = self.refiner(generated_volumes)
            refiner_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        else:
            refiner_loss = encoder_loss

        return generated_volumes, encoder_loss, refiner_loss

    def training_step(self, batch, batch_idx):
        # Manual optimization: Lightning does not pass optimizer_idx here
        (opt_enc, opt_dec, opt_ref, opt_merg) = self.optimizers()
        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)

        self.log('loss/EncoderDecoder', encoder_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('loss/Refiner', refiner_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        # Backprop the encoder/decoder loss first, then the refiner loss through
        # the whole pipeline (gradients accumulate). The optimizer argument to
        # manual_backward is the pre-1.2 Lightning signature.
        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            self.manual_backward(encoder_loss, opt_enc, retain_graph=True)
            self.manual_backward(refiner_loss, opt_ref)
        else:
            self.manual_backward(encoder_loss, opt_enc)

        for opt in self.optimizers():
            opt.step()
            opt.zero_grad()

    def training_epoch_end(self, outputs) -> None:
        # Update the number of rendering views for the next epoch
        if self.cfg_network.update_n_views_rendering:
            n_views_rendering = self.trainer.datamodule.update_n_views_rendering()
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d'
                  % (dt.now(), self.current_epoch + 2, self.trainer.max_epochs, n_views_rendering))

    def _eval_step(self, batch, batch_idx):
        # SUPPORTS ONLY BATCH_SIZE=1
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch
        taxonomy_id = taxonomy_names[0]
        sample_name = sample_names[0]

        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)

        self.log('val_loss/EncoderDecoder', encoder_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('val_loss/Refiner', refiner_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        # IoU per sample
        sample_iou = []
        for th in self.cfg_tester.voxel_thresh:
            _volume = torch.ge(generated_volumes, th).float()
            intersection = torch.sum(_volume.mul(ground_truth_volumes)).float()
            union = torch.sum(torch.ge(_volume.add(ground_truth_volumes), 1)).float()
            sample_iou.append((intersection / union).item())

        # Print sample loss and IoU (the sample count is unknown per step, hence -1)
        n_samples = -1
        print('\n[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s'
              % (dt.now(), batch_idx + 1, n_samples, taxonomy_id, sample_name,
                 encoder_loss.item(), refiner_loss.item(),
                 ['%.4f' % si for si in sample_iou]))

        return {
            'taxonomy_id': taxonomy_id,
            'sample_name': sample_name,
            'sample_iou': sample_iou
        }

    def _eval_epoch_end(self, outputs):
        # Load taxonomies of dataset
        taxonomy_path = self.trainer.datamodule.get_test_taxonomy_file_path()
        with open(taxonomy_path, encoding='utf-8') as file:
            taxonomies = json.loads(file.read())
        taxonomies = {t['taxonomy_id']: t for t in taxonomies}

        # Collect per-sample IoUs by taxonomy
        test_iou = {}
        for output in outputs:
            taxonomy_id, sample_name, sample_iou = \
                output['taxonomy_id'], output['sample_name'], output['sample_iou']
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

        # Per-taxonomy mean IoU, weighted into an overall mean per threshold
        mean_iou = []
        for taxonomy_id in test_iou:
            test_iou[taxonomy_id]['iou'] = torch.mean(torch.tensor(test_iou[taxonomy_id]['iou']), dim=0)
            mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
        n_samples = len(outputs)
        mean_iou = torch.stack(mean_iou)
        mean_iou = torch.sum(mean_iou, dim=0) / n_samples

        # Print header
        print('============================ TEST RESULTS ============================')
        print('Taxonomy', end='\t')
        print('#Sample', end='\t')
        print(' Baseline', end='\t')
        for th in self.cfg_tester.voxel_thresh:
            print('t=%.2f' % th, end='\t')
        print()

        # Print body
        for taxonomy_id in test_iou:
            print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
            print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
            if 'baseline' in taxonomies[taxonomy_id]:
                n_views_rendering = self.trainer.datamodule.get_n_views_rendering()
                print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % n_views_rendering],
                      end='\t\t')
            else:
                print('N/a', end='\t\t')

            for ti in test_iou[taxonomy_id]['iou']:
                print('%.4f' % ti, end='\t')
            print()

        # Print mean IoU for each threshold
        print('Overall ', end='\t\t\t\t')
        for mi in mean_iou:
            print('%.4f' % mi, end='\t')
        print('\n')

        max_iou = torch.max(mean_iou)
        self.log('Refiner/IoU', max_iou, prog_bar=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)

    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)

    def get_progress_bar_dict(self):
        # don't show the loss as it's None under manual optimization
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items
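# Hedged wiring sketch for the LightningModule above. The YAML layout, the
# ShapeNetDataModule name, and all paths are hypothetical; only Model's
# constructor signature and the (old-style) Lightning APIs it relies on are
# taken from the code itself.
from omegaconf import OmegaConf
import pytorch_lightning as pl

if __name__ == '__main__':
    cfg = OmegaConf.load('configs/pix2vox.yaml')   # hypothetical config file
    model = Model(cfg.network, cfg.tester)
    datamodule = ShapeNetDataModule(cfg.dataset)   # hypothetical LightningDataModule
    trainer = pl.Trainer(gpus=1,                   # PL 1.x-style arguments, matching the hooks used above
                         max_epochs=cfg.network.optimization.num_epochs)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)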