def main():
    """Debug entry point: stream training batches from the snapshot to visdom.

    Reads ``meta.json`` from ``args.snapshot_dir``, builds the training
    ``MaterialRendDataset`` with the training image/mask transforms and shows
    every batch with ``vis.image(visualize_input(batch))``.
    """
    train_path = Path(args.snapshot_dir, 'train')
    meta_path = Path(args.snapshot_dir, 'meta.json')
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with meta_path.open('r', encoding='utf-8') as f:
        meta_dict = json.load(f)

    print("Opening LMDB datasets.")
    train_dataset = rendering_dataset.MaterialRendDataset(
        train_path,
        meta_dict,
        shape=(500, 500),
        image_transform=transforms.train_image_transform(INPUT_SIZE),
        mask_transform=transforms.train_mask_transform(INPUT_SIZE))

    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=1,
        shuffle=True,
        pin_memory=False,
        collate_fn=rendering_dataset.collate_fn)

    print("Woohoo")
    for batch in train_loader:
        vis.image(visualize_input(batch))
def main():
    """Debug entry point: sanity-check the training LMDB, then show batches.

    Opens the training LMDB environment under ``args.snapshot_dir`` just long
    enough to print its entry count, then builds the training
    ``MaterialRendDataset``/``DataLoader`` and streams every batch to visdom
    with ``vis.image(visualize_input(batch))``.
    """
    train_path = Path(args.snapshot_dir, 'train')
    meta_path = Path(args.snapshot_dir, 'meta.json')
    with meta_path.open('r', encoding='utf-8') as f:
        meta_dict = json.load(f)

    # Nothing is written here, so open read-only; this also avoids silently
    # creating an empty database when the path is wrong. The environment is
    # closed before the dataset is built because LMDB does not allow the same
    # environment to be open twice in one process, and the dataset presumably
    # opens it itself (TODO confirm).
    env = lmdb.open(str(train_path), readonly=True)
    try:
        print(f"Training LMDB has {env.stat()['entries']} entries.")
    finally:
        env.close()

    train_dataset = rendering_dataset.MaterialRendDataset(
        train_path,
        meta_dict,
        shape=(500, 500),
        image_transform=transforms.train_image_transform(INPUT_SIZE),
        mask_transform=transforms.train_mask_transform(INPUT_SIZE))
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=1,
        shuffle=True,
        pin_memory=False,
        collate_fn=rendering_dataset.collate_fn)

    print("Woohoo")
    for batch in train_loader:
        vis.image(visualize_input(batch))
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask):
    """Run the classifier on one image/mask pair and return ranked predictions.

    Args:
        label_to_mat_id: Mapping from material class label (int) to material id.
        model: Network whose ``forward`` returns a dict with ``'material'``
            logits and optionally ``'substance'`` / ``'roughness'`` logits.
        image: Input image array. Non-uint8 images are scaled by 255 and cast
            to uint8, so float images are assumed to lie in [0, 1] (TODO
            confirm with callers).
        seg_mask: Segmentation mask, resized to 224x224 with nearest-neighbour
            interpolation and stacked as an extra input channel.

    Returns:
        Dict with:
        * ``'material'``: up to 10 ``{'score', 'id'}`` entries, best first.
        * ``'substance'`` (if predicted): every substance class as
          ``{'score', 'id', 'name'}``, best first.
        * ``'roughness'`` (if predicted): up to 5 ``{'score', 'value'}``
          entries, where ``value`` is the midpoint of the roughness bin.
    """
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    seg_mask = skimage.transform.resize(
        seg_mask, (224, 224), order=0, anti_aliasing=False, mode='reflect')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(224)(image)
    mask_tensor = transforms.inference_mask_transform(224)(seg_mask)
    # Batch of one: image channels followed by the mask channel.
    input_tensor = torch.cat((image_tensor, mask_tensor), dim=0).unsqueeze(0)
    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        output = model.forward(input_tensor.cuda())

    # Clamp k so torch.topk does not raise when there are fewer classes.
    # Index [0] rather than squeeze() so k == 1 still yields a list.
    mat_k = min(10, output['material'].size(1))
    topk_mat_scores, topk_mat_labels = torch.topk(
        F.softmax(output['material'], dim=1), k=mat_k)
    topk_dict = {
        'material': [
            {
                'score': score,
                'id': label_to_mat_id[int(label)],
            }
            for score, label in zip(topk_mat_scores[0].tolist(),
                                    topk_mat_labels[0].tolist())
        ]
    }

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            }
            for score, label in zip(topk_subst_scores[0].tolist(),
                                    topk_subst_labels[0].tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Centre value of each of the nrc equal-width bins over [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        roughness_k = min(5, output['roughness'].size(1))
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1), k=roughness_k)
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            }
            for score, label in zip(topk_roughness_scores[0].tolist(),
                                    topk_roughness_labels[0].tolist())
        ]

    return topk_dict
def validate_epoch(val_loader, model, epoch, mat_criterion, subst_criterion,
                   color_criterion, subst_mat_labels, loss_variances,
                   loss_weights):
    """Run one validation pass and return the mean of every tracked meter.

    The model is put in eval mode and evaluated without gradient tracking.
    The returned dict maps each meter name (losses, precisions, batch_time)
    to the mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # switch to evaluate mode
    model.eval()

    pbar = tqdm(total=len(val_loader), desc='Validating Epoch',
                dynamic_ncols=True)
    last_end_time = time.time()
    # Validation never calls backward(), so skip autograd bookkeeping.
    with torch.no_grad():
        for batch_idx, batch_dict in enumerate(val_loader):
            input_var = batch_dict['image'].cuda()
            mat_labels = batch_dict['material_label'].cuda()
            subst_labels = batch_dict['substance_label'].cuda()

            output = model.forward(input_var)
            mat_output = output['material']

            losses = {}

            if args.substance_loss is not None:
                if args.substance_loss == 'fc':
                    subst_output = output['substance']
                    subst_fc_prec1 = compute_precision(subst_output,
                                                       subst_labels)
                    meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
                elif args.substance_loss == 'from_material':
                    subst_output = material_to_substance_output(
                        mat_output, subst_mat_labels)
                else:
                    raise ValueError('Invalid value for substance_loss')
                losses['substance'] = subst_criterion(subst_output,
                                                      subst_labels)

            mat_subst_output = material_to_substance_output(
                mat_output, subst_mat_labels)
            subst_prec1 = compute_precision(mat_subst_output, subst_labels)
            meter_dict['subst_from_mat_prec1'].add(subst_prec1)

            if args.roughness_loss:
                roughness_output = output['roughness']
                losses['roughness'], roughness_prec1 = compute_roughness_loss(
                    batch_dict['roughness'], roughness_output, mat_criterion)
                # NOTE(review): train_epoch adds roughness_prec1[0] here;
                # confirm which form compute_roughness_loss returns.
                meter_dict['roughness_prec1'].add(roughness_prec1)

            if args.color_loss is not None:
                color_output = output['color']
                color_target = batch_dict['color_hist'].cuda()
                if isinstance(color_criterion, nn.KLDivLoss):
                    # Explicit dim=1 (the class/bin axis of the 2-D output);
                    # the implicit-dim form is deprecated.
                    color_output = F.log_softmax(color_output, dim=1)
                losses['color'] = color_criterion(color_output, color_target)
            else:
                color_output = None
                color_target = None

            losses['material'] = mat_criterion(mat_output, mat_labels)
            loss = combine_losses(losses, loss_weights, loss_variances,
                                  args.use_variance)

            # Add losses to meters.
            meter_dict['loss'].add(loss.item())
            for loss_name, loss_tensor in losses.items():
                meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

            # measure accuracy and record loss
            mat_prec1, mat_prec5 = compute_precision(mat_output, mat_labels,
                                                     topk=(1, 5))
            meter_dict['mat_prec1'].add(mat_prec1)
            meter_dict['mat_prec5'].add(mat_prec5)

            # measure elapsed time: exactly one sample per batch
            meter_dict['batch_time'].add(time.time() - last_end_time)

            pbar.update()

            if batch_idx % args.show_freq == 0:
                vis.image(visualize_input(batch_dict),
                          win='validation-batch-example',
                          opts=dict(title='Validation Batch Example'))
                if 'color' in losses:
                    color_output_vis = color_binner.visualize(
                        F.softmax(color_output[0].cpu().detach(), dim=0))
                    color_target_vis = color_binner.visualize(color_target[0])
                    color_hist_vis = np.vstack(
                        (color_output_vis, color_target_vis))
                    vis.heatmap(color_hist_vis,
                                win='validation-color-hist',
                                opts=dict(title='Validation Color Histogram'))

                meters = [(k, v) for k, v in meter_dict.items()]
                meter_table = MeterTable(
                    f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}",
                    meters)

                vis.text(meter_table.render(),
                         win='validation-status',
                         opts=dict(title='Validation Status'))

            # Reset after visualization so the next batch's time excludes it.
            last_end_time = time.time()
    pbar.close()

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
def train_epoch(train_loader, model, epoch, mat_criterion, subst_criterion,
                color_criterion, optimizer, subst_mat_labels,
                loss_variances, loss_weights):
    """Run one training epoch and return the mean of every tracked meter.

    For each batch: forward pass, combine the enabled losses with
    ``combine_losses``, take one optimizer step, and record losses,
    precisions, timings and loss-variance terms. Every ``args.show_freq``
    batches, the input, color histograms and a meter table are pushed to
    visdom.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # switch to train mode; validate_epoch puts the model in eval mode and
    # nothing else visible here switches it back.
    model.train()

    pbar = tqdm(total=len(train_loader), desc='Training Epoch',
                dynamic_ncols=True)
    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        mat_labels = batch_dict['material_label'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model.forward(input_var)
        mat_output = output['material']

        losses = {}

        if args.substance_loss is not None:
            if args.substance_loss == 'fc':
                subst_output = output['substance']
                subst_fc_prec1 = compute_precision(subst_output, subst_labels)
                meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            elif args.substance_loss == 'from_material':
                subst_output = material_to_substance_output(
                    mat_output, subst_mat_labels)
            else:
                raise ValueError('Invalid value for substance_loss')
            losses['substance'] = subst_criterion(subst_output, subst_labels)

        mat_subst_output = material_to_substance_output(
            mat_output, subst_mat_labels)
        mat_subst_prec1 = compute_precision(mat_subst_output, subst_labels)
        meter_dict['subst_from_mat_prec1'].add(mat_subst_prec1)

        if args.roughness_loss is not None:
            roughness_output = output['roughness']
            losses['roughness'], roughness_prec1 = compute_roughness_loss(
                batch_dict['roughness'], roughness_output, mat_criterion)
            meter_dict['roughness_prec1'].add(roughness_prec1[0])

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # Explicit dim=1 (the class/bin axis of the 2-D output);
                # the implicit-dim form is deprecated.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        losses['material'] = mat_criterion(mat_output, mat_labels)
        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        mat_prec1, mat_prec5 = compute_precision(mat_output, mat_labels,
                                                 topk=(1, 5))
        meter_dict['mat_prec1'].add(mat_prec1)
        meter_dict['mat_prec5'].add(mat_prec5)

        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()

        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))

            # Visualize color prediction.
            if 'color' in losses:
                color_output_vis = color_binner.visualize(
                    F.softmax(color_output[0].cpu().detach(), dim=0))
                color_target_vis = color_binner.visualize(color_target[0])
                color_hist_vis = np.vstack(
                    (color_output_vis, color_target_vis))
                vis.heatmap(color_hist_vis,
                            win='train-color-hist',
                            opts=dict(title='Training Color Histogram'))

            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))
    pbar.close()

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask, *,
                 color_binner, minc_substance, mat_by_id):
    """Run the classifier on one image/mask pair and return ranked predictions.

    Args:
        label_to_mat_id: Mapping from material class label (int) to material id.
        model: Network whose ``forward`` returns a dict with ``'material'``
            logits and optionally ``'color'``, ``'substance'`` and
            ``'roughness'`` outputs.
        image: Input image array. Non-uint8 images are scaled by 255 and cast
            to uint8, so float images are assumed to lie in [0, 1] (TODO
            confirm with callers).
        seg_mask: Segmentation mask, resized to 224x224 with nearest-neighbour
            interpolation and stacked as an extra input channel.
        color_binner: Used to visualize the predicted color histogram when
            the model outputs ``'color'``.
        minc_substance: Copied unchanged into every material entry.
        mat_by_id: Mapping from material id to an object with a
            ``substance`` attribute.

    Returns:
        Dict with:
        * ``'material'``: every material class except label 0, best first, as
          ``{'score', 'id', 'pred_substance', 'minc_substance'}``.
        * ``'substance'`` (if predicted): every substance class as
          ``{'score', 'id', 'name'}``, best first.
        * ``'roughness'`` (if predicted): up to 5 ``{'score', 'value'}``
          entries, where ``value`` is the midpoint of the roughness bin.
    """
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    seg_mask = skimage.transform.resize(
        seg_mask, (224, 224), order=0, anti_aliasing=False, mode='constant')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(
        input_size=224, output_size=224, pad=0, to_pil=True)(image)
    mask_tensor = transforms.inference_mask_transform(
        input_size=224, output_size=224, pad=0)(seg_mask)
    # Batch of one: image channels followed by the mask channel.
    input_tensor = (torch.cat((image_tensor, mask_tensor), dim=0)
                    .unsqueeze(0)
                    .cuda())
    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        output = model.forward(input_tensor)

    if 'color' in output:
        color_output = output['color']
        color_hist_vis = color_binner.visualize(
            F.softmax(color_output[0].cpu().detach(), dim=0))
        vis.heatmap(color_hist_vis, win='color-hist',
                    opts=dict(title='Color Histogram'))

    topk_mat_scores, topk_mat_labels = torch.topk(
        F.softmax(output['material'], dim=1), k=output['material'].size(1))

    # Index [0] rather than squeeze() so a single class still yields a list.
    topk_dict = {'material': list()}
    for score, label in zip(topk_mat_scores[0].tolist(),
                            topk_mat_labels[0].tolist()):
        # Label 0 is excluded; presumably a background/"no material" class
        # (TODO confirm).
        if int(label) == 0:
            continue
        mat_id = int(label_to_mat_id[int(label)])
        material = mat_by_id[mat_id]
        topk_dict['material'].append({
            'score': score,
            'id': mat_id,
            'pred_substance': material.substance,
            'minc_substance': minc_substance,
        })

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            }
            for score, label in zip(topk_subst_scores[0].tolist(),
                                    topk_subst_labels[0].tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Centre value of each of the nrc equal-width bins over [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        # Clamp k so torch.topk does not raise when there are fewer classes.
        roughness_k = min(5, output['roughness'].size(1))
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1), k=roughness_k)
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            }
            for score, label in zip(topk_roughness_scores[0].tolist(),
                                    topk_roughness_labels[0].tolist())
        ]

    return topk_dict
def validate_epoch(val_loader, model, epoch, subst_criterion, color_criterion,
                   loss_variances, loss_weights):
    """Run one validation pass and return the mean of every tracked meter.

    The model is put in eval mode and evaluated without gradient tracking.
    The returned dict maps each meter name (losses, substance precision,
    batch_time) to the mean value over the epoch.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # switch to evaluate mode
    model.eval()

    pbar = tqdm(total=len(val_loader), desc='Validating Epoch')
    last_end_time = time.time()
    # Validation never calls backward(), so skip autograd bookkeeping.
    with torch.no_grad():
        for batch_idx, batch_dict in enumerate(val_loader):
            input_var = batch_dict['image'].cuda()
            subst_labels = batch_dict['substance_label'].cuda()

            output = model.forward(input_var)

            losses = {}

            subst_output = output['substance']
            subst_fc_prec1 = compute_precision(subst_output, subst_labels)
            meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
            losses['substance'] = subst_criterion(subst_output, subst_labels)

            if args.color_loss is not None:
                color_output = output['color']
                color_target = batch_dict['color_hist'].cuda()
                if isinstance(color_criterion, nn.KLDivLoss):
                    # Explicit dim=1 (the class/bin axis of the 2-D output);
                    # the implicit-dim form is deprecated.
                    color_output = F.log_softmax(color_output, dim=1)
                losses['color'] = color_criterion(color_output, color_target)
            else:
                color_output = None
                color_target = None

            loss = combine_losses(losses, loss_weights, loss_variances,
                                  args.use_variance)

            # Add losses to meters.
            meter_dict['loss'].add(loss.item())
            for loss_name, loss_tensor in losses.items():
                meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

            # measure elapsed time: exactly one sample per batch
            meter_dict['batch_time'].add(time.time() - last_end_time)

            pbar.update()

            if batch_idx % args.show_freq == 0:
                vis.image(visualize_input(batch_dict),
                          win='validation-batch-example',
                          opts=dict(title='Validation Batch Example'))
                if 'color' in losses:
                    visualize_color_hist_output(
                        color_output[0],
                        color_target[0],
                        'validation-color-hist',
                        'Validation Color Histogram',
                        color_hist_shape=color_binner.shape,
                        color_hist_space=color_binner.space,
                    )

                meters = [(k, v) for k, v in meter_dict.items()]
                meter_table = MeterTable(
                    f"Epoch {epoch} Iter {batch_idx+1}/{len(val_loader)}",
                    meters)

                vis.text(meter_table.render(),
                         win='validation-status',
                         opts=dict(title='Validation Status'))

            # Reset after visualization so the next batch's time excludes it.
            last_end_time = time.time()
    pbar.close()

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict
def train_epoch(train_loader, model, epoch, subst_criterion, color_criterion,
                optimizer, loss_variances, loss_weights):
    """Run one training epoch and return the mean of every tracked meter.

    For each batch: forward pass, substance loss plus optional color loss
    combined with ``combine_losses``, one optimizer step, and recording of
    losses, precision, timings and loss-variance terms. Every
    ``args.show_freq`` batches, the input, color histograms and a meter table
    are pushed to visdom.
    """
    meter_dict = defaultdict(AverageValueMeter)

    # switch to train mode; validate_epoch puts the model in eval mode and
    # nothing else visible here switches it back.
    model.train()

    pbar = tqdm(total=len(train_loader), desc='Training Epoch')
    last_end_time = time.time()
    for batch_idx, batch_dict in enumerate(train_loader):
        meter_dict['data_time'].add(time.time() - last_end_time)

        input_var = batch_dict['image'].cuda()
        subst_labels = batch_dict['substance_label'].cuda()

        batch_start_time = time.time()
        output = model.forward(input_var)

        losses = {}

        subst_output = output['substance']
        subst_fc_prec1 = compute_precision(subst_output, subst_labels)
        meter_dict['subst_fc_prec1'].add(subst_fc_prec1)
        losses['substance'] = subst_criterion(subst_output, subst_labels)

        if args.color_loss is not None:
            color_output = output['color']
            color_target = batch_dict['color_hist'].cuda()
            if isinstance(color_criterion, nn.KLDivLoss):
                # Explicit dim=1 (the class/bin axis of the 2-D output);
                # the implicit-dim form is deprecated.
                color_output = F.log_softmax(color_output, dim=1)
            losses['color'] = color_criterion(color_output, color_target)
        else:
            color_output = None
            color_target = None

        loss = combine_losses(losses, loss_weights, loss_variances,
                              args.use_variance)

        # Add losses to meters.
        meter_dict['loss'].add(loss.item())
        for loss_name, loss_tensor in losses.items():
            meter_dict[f'loss_{loss_name}'].add(loss_tensor.item())

        for loss_name, var_tensor in loss_variances.items():
            meter_dict[f'{loss_name}_variance(s_hat)'].add(var_tensor.item())
            meter_dict[f'{loss_name}_variance(exp[-s_hat])'].add(
                torch.exp(-var_tensor).item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meter_dict['batch_time'].add(time.time() - batch_start_time)
        last_end_time = time.time()

        pbar.update()

        if batch_idx % args.show_freq == 0:
            # Visualize batch input.
            vis.image(visualize_input(batch_dict),
                      win='train-batch-example',
                      opts=dict(title='Train Batch Example'))

            # Visualize color prediction.
            if 'color' in losses:
                visualize_color_hist_output(
                    color_output[0],
                    color_target[0],
                    win='train-color-hist',
                    title='Training Color Histogram',
                    color_hist_shape=color_binner.shape,
                    color_hist_space=color_binner.space,
                )

            # Show meters.
            meters = [(k, v) for k, v in meter_dict.items()]
            meter_table = MeterTable(
                f"Epoch {epoch} Iter {batch_idx+1}/{len(train_loader)}",
                meters)

            vis.text(meter_table.render(),
                     win='train-status',
                     opts=dict(title='Training Status'))
    pbar.close()

    mean_stat_dict = {k: v.mean for k, v in meter_dict.items()}
    return mean_stat_dict