def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask):
    """Classify one masked image region and return top-k predictions.

    Args:
        label_to_mat_id: mapping from class label (int) to material id.
            Label 0 is the background class and has no material id.
        model: trained RendNet3 classifier, assumed to already be on the GPU.
        image: HxWx3 image, either uint8 or float in [0, 1].
        seg_mask: segmentation mask selecting the region to classify.

    Returns:
        Dict with a 'material' top-10 list of {'score', 'id'} entries, plus
        'substance' and 'roughness' lists when the model has those heads.
    """
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    # Nearest-neighbor (order=0) resize so the mask stays binary.
    seg_mask = skimage.transform.resize(
        seg_mask, (224, 224), order=0, anti_aliasing=False, mode='reflect')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(224)(image)
    mask_tensor = transforms.inference_mask_transform(224)(seg_mask)
    # The mask is stacked onto the image as an extra input channel.
    input_tensor = torch.cat((image_tensor, mask_tensor), dim=0).unsqueeze(0)

    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        output = model.forward(input_tensor.cuda())

    topk_mat_scores, topk_mat_labels = torch.topk(
        F.softmax(output['material'], dim=1), k=10)
    topk_dict = {'material': []}
    for score, label in zip(topk_mat_scores.squeeze().tolist(),
                            topk_mat_labels.squeeze().tolist()):
        # Label 0 is background and absent from label_to_mat_id; including it
        # would raise a KeyError (the other compute_topk variant skips it too).
        if int(label) == 0:
            continue
        topk_dict['material'].append({
            'score': score,
            'id': label_to_mat_id[int(label)],
        })

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            }
            for score, label in zip(topk_subst_scores.squeeze().tolist(),
                                    topk_subst_labels.squeeze().tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoint of each of the nrc equal-width roughness bins on [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1), k=5)
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            }
            for score, label in zip(topk_roughness_scores.squeeze().tolist(),
                                    topk_roughness_labels.squeeze().tolist())
        ]

    return topk_dict
def main(checkpoint_path, batch_size, normalized, visdom_port):
    """Evaluate a trained checkpoint on the validation split and log a
    confusion matrix heatmap to visdom.

    Args:
        checkpoint_path: path to a model checkpoint; the dataset snapshot is
            expected at ``<checkpoint>/../../../snapshot.json``.
        batch_size: validation batch size.
        normalized: whether the confusion matrix rows are normalized.
        visdom_port: port of the visdom server to log to.
    """
    checkpoint_path = Path(checkpoint_path)
    # snapshot.json lives three directories above the checkpoint file.
    snapshot_path = checkpoint_path.parent.parent.parent / 'snapshot.json'
    with snapshot_path.open('r') as f:
        snapshot_dict = json.load(f)

    mat_id_to_label = snapshot_dict['mat_id_to_label']
    label_to_mat_id = {int(v): int(k) for k, v in mat_id_to_label.items()}
    # +1 for the background class (label 0).
    num_classes = len(label_to_mat_id) + 1

    print(f'Loading model checkpoint from {checkpoint_path!r}')
    checkpoint = torch.load(checkpoint_path)
    model = RendNet3(num_classes=num_classes,
                     num_roughness_classes=20,
                     num_substances=len(SUBSTANCES),
                     base_model=resnet.resnet18(pretrained=False))
    model.load_state_dict(checkpoint['state_dict'])
    # Single switch to evaluate mode (the original called both
    # model.train(False) and model.eval(), which are equivalent).
    model.eval()
    model = model.cuda()

    validation_dataset = rendering_dataset.MaterialRendDataset(
        snapshot_dict,
        snapshot_dict['examples']['validation'],
        shape=(384, 384),
        image_transform=transforms.inference_image_transform(INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(INPUT_SIZE))
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        num_workers=8,
        shuffle=False,
        pin_memory=True,
        collate_fn=rendering_dataset.collate_fn)

    confusion_meter = tnt.meter.ConfusionMeter(
        k=num_classes, normalized=normalized)

    pbar = tqdm(validation_loader)
    # Evaluation only: no_grad avoids accumulating the autograd graph
    # (and its memory) across batches.
    with torch.no_grad():
        for batch_idx, batch_dict in enumerate(pbar):
            input_tensor = batch_dict['image'].cuda()
            labels = batch_dict['material_label'].cuda()
            output = model.forward(input_tensor)
            pbar.set_description(f"{output['material'].size()}")
            confusion_meter.add(output['material'].cpu(), labels.cpu())

    with session_scope() as sess:
        materials = sess.query(models.Material).filter_by(enabled=True).all()
        mat_by_id = {m.id: m for m in materials}

    # Class 0 is background; labels 1..num_classes-1 map to material names.
    class_names = ['background']
    class_names.extend([
        mat_by_id[label_to_mat_id[i]].name for i in range(1, num_classes)
    ])
    print(len(class_names))

    confusion_matrix = confusion_meter.value()
    confusion_logger = VisdomLogger(
        'heatmap',
        opts={
            'title': 'Confusion matrix',
            'columnnames': class_names,
            'rownames': class_names,
            'xtickfont': {'size': 8},
            'ytickfont': {'size': 8},
        },
        env='brdf-classifier-confusion',
        port=visdom_port)
    confusion_logger.log(confusion_matrix)
def compute_topk(label_to_mat_id, model: RendNet3, image, seg_mask, *,
                 color_binner, minc_substance, mat_by_id):
    """Classify one masked image region and return top-k predictions,
    annotated with substance information.

    Args:
        label_to_mat_id: mapping from class label to material id; label 0 is
            the background class and is skipped.
        model: trained RendNet3 classifier, assumed to already be on the GPU.
        image: HxWx3 image, either uint8 or float in [0, 1].
        seg_mask: segmentation mask selecting the region to classify.
        color_binner: binner used to visualize the predicted color histogram
            (only used when the model has a 'color' head).
        minc_substance: MINC substance prediction for this region; attached to
            every material entry for downstream comparison.
        mat_by_id: material id -> material record (provides ``.substance``).

    Returns:
        Dict with 'material' entries ({'score', 'id', 'pred_substance',
        'minc_substance'}) for all non-background classes, plus 'substance'
        and 'roughness' lists when the model has those heads.
    """
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    # Nearest-neighbor (order=0) resize so the mask stays binary.
    seg_mask = skimage.transform.resize(
        seg_mask, (224, 224), order=0, anti_aliasing=False, mode='constant')
    seg_mask = seg_mask[:, :, np.newaxis].astype(dtype=np.uint8) * 255

    image_tensor = transforms.inference_image_transform(
        input_size=224, output_size=224, pad=0, to_pil=True)(image)
    mask_tensor = transforms.inference_mask_transform(
        input_size=224, output_size=224, pad=0)(seg_mask)
    # The mask is stacked onto the image as an extra input channel.
    input_tensor = (torch.cat((image_tensor, mask_tensor), dim=0)
                    .unsqueeze(0)
                    .cuda())

    input_vis = utils.visualize_input({'image': input_tensor})
    vis.image(input_vis, win='input')

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        output = model.forward(input_tensor)

    if 'color' in output:
        color_output = output['color']
        color_hist_vis = color_binner.visualize(
            F.softmax(color_output[0].cpu().detach(), dim=0))
        vis.heatmap(color_hist_vis,
                    win='color-hist',
                    opts=dict(title='Color Histogram'))

    # Rank every class, then drop background below.
    topk_mat_scores, topk_mat_labels = torch.topk(
        F.softmax(output['material'], dim=1), k=output['material'].size(1))
    topk_dict = {'material': list()}
    for score, label in zip(topk_mat_scores.squeeze().tolist(),
                            topk_mat_labels.squeeze().tolist()):
        # Label 0 is the background class: no corresponding material.
        if int(label) == 0:
            continue
        mat_id = int(label_to_mat_id[int(label)])
        material = mat_by_id[mat_id]
        topk_dict['material'].append({
            'score': score,
            'id': mat_id,
            'pred_substance': material.substance,
            'minc_substance': minc_substance,
        })

    if 'substance' in output:
        topk_subst_scores, topk_subst_labels = torch.topk(
            F.softmax(output['substance'].cpu(), dim=1),
            k=output['substance'].size(1))
        topk_dict['substance'] = [
            {
                'score': score,
                'id': label,
                'name': SUBSTANCES[int(label)],
            }
            for score, label in zip(topk_subst_scores.squeeze().tolist(),
                                    topk_subst_labels.squeeze().tolist())
        ]

    if 'roughness' in output:
        nrc = model.num_roughness_classes
        # Midpoint of each of the nrc equal-width roughness bins on [0, 1].
        roughness_midpoints = np.linspace(1 / nrc / 2, 1 - 1 / nrc / 2, nrc)
        topk_roughness_scores, topk_roughness_labels = torch.topk(
            F.softmax(output['roughness'].cpu(), dim=1), k=5)
        topk_dict['roughness'] = [
            {
                'score': score,
                'value': roughness_midpoints[int(label)],
            }
            for score, label in zip(topk_roughness_scores.squeeze().tolist(),
                                    topk_roughness_labels.squeeze().tolist())
        ]

    return topk_dict
def main():
    """Train a RendNet3 material classifier on the rendered-snapshot dataset.

    Reads hyperparameters from module-level globals (``args``,
    ``color_binner``, ``vis``), builds datasets/model/optimizer, then runs the
    train/validate loop, checkpointing and plotting stats to visdom.
    """
    snapshot_dir = args.snapshot_dir
    checkpoint_dir = args.checkpoint_dir / args.model_name

    train_path = Path(snapshot_dir, 'train')
    validation_path = Path(snapshot_dir, 'validation')
    meta_path = Path(snapshot_dir, 'meta.json')

    print(f' * train_path = {train_path!s}\n'
          f' * validation_path = {validation_path!s}\n'
          f' * meta_path = {meta_path!s}\n'
          f' * checkpoint_dir = {checkpoint_dir!s}')

    with meta_path.open('r') as f:
        meta_dict = json.load(f)

    with session_scope() as sess:
        materials = sess.query(models.Material).all()
        mat_by_id = {m.id: m for m in materials}

    # JSON object keys are strings; convert material ids back to ints.
    mat_id_to_label = {
        int(k): v for k, v in meta_dict['mat_id_to_label'].items()
    }

    # Group material labels by substance, then re-key by substance index so
    # the loss can restrict material logits to one substance.
    subst_mat_labels = defaultdict(list)
    for mat_id, mat_label in mat_id_to_label.items():
        material = mat_by_id[mat_id]
        subst_mat_labels[material.substance].append(mat_label)
    subst_mat_labels = {
        SUBSTANCES.index(k): torch.LongTensor(v).cuda()
        for k, v in subst_mat_labels.items()
    }
    # Inverse mapping: material label -> substance name.
    # NOTE(review): not referenced below — presumably kept for debugging.
    mat_label_to_subst_label = {
        label: mat_by_id[mat_id].substance
        for mat_id, label in mat_id_to_label.items()
    }

    num_classes = max(mat_id_to_label.values()) + 1

    if args.roughness_loss:
        # Cross-entropy bins roughness into classes; other losses regress a
        # single scalar output.
        num_roughness_classes = (args.num_roughness_classes
                                 if args.roughness_loss == 'cross_entropy'
                                 else 1)
        output_roughness = True
    else:
        num_roughness_classes = 0
        output_roughness = False

    output_substance = (True
                        if args.substance_loss == 'fc'
                        else False)

    # Constructor kwargs for RendNet3 (this name is later rebound to the full
    # hyperparameter record — see below).
    model_params = dict(
        num_classes=num_classes,
        num_substances=len(SUBSTANCES),
        num_roughness_classes=num_roughness_classes,
        output_roughness=output_roughness,
        output_substance=output_substance,
        output_color=color_binner is not None,
        num_color_bins=color_binner.size if color_binner else 0,
    )

    base_model_fn = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
    }[args.base_model]

    if args.from_scratch:
        base_model = base_model_fn(pretrained=False)
    else:
        base_model = base_model_fn(pretrained=True)

    model = RendNet3(
        **model_params,
        base_model=base_model,
    ).train().cuda()

    train_cum_stats = defaultdict(list)
    val_cum_stats = defaultdict(list)
    if args.resume:
        print(f" * Loading weights from {args.resume!s}")
        checkpoint = torch.load(args.resume)
        # strict=False: tolerate head mismatches between checkpoints.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # Restore cumulative stats so plots continue where they left off.
        stats_dict_path = checkpoint_dir / 'model_train_stats.json'
        if stats_dict_path.exists():
            with stats_dict_path.open('r') as f:
                stats_dict = json.load(f)
            train_cum_stats = stats_dict['train']
            val_cum_stats = stats_dict['validation']

    print(' * Loading datasets')
    train_dataset = rendering_dataset.MaterialRendDataset(
        train_path,
        meta_dict,
        color_binner=color_binner,
        shape=SHAPE,
        lmdb_name=snapshot_dir.name,
        image_transform=transforms.train_image_transform(INPUT_SIZE, pad=0),
        mask_transform=transforms.train_mask_transform(INPUT_SIZE, pad=0),
        mask_noise_p=args.mask_noise_p)

    validation_dataset = rendering_dataset.MaterialRendDataset(
        validation_path,
        meta_dict,
        color_binner=color_binner,
        shape=SHAPE,
        lmdb_name=snapshot_dir.name,
        image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE))

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
    )

    loss_weights = {
        'material': args.material_loss_weight,
        'roughness': args.roughness_loss_weight,
        'substance': args.substance_loss_weight,
        'color': args.color_loss_weight,
    }

    # Rebinds model_params: now the full hyperparameter record saved with
    # checkpoints; the constructor kwargs are nested under 'model_params'.
    model_params = {
        'init_lr': args.init_lr,
        'lr_decay_epochs': args.lr_decay_epochs,
        'lr_decay_frac': args.lr_decay_frac,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'from_scratch': args.from_scratch,
        'base_model': args.base_model,
        'resumed_from': str(args.resume),
        'mask_noise_p': args.mask_noise_p,
        'use_substance_loss': args.substance_loss is not None,
        'substance_loss': args.substance_loss,
        'roughness_loss': args.roughness_loss,
        'material_variance_init': args.material_variance_init,
        'substance_variance_init': args.substance_variance_init,
        'color_variance_init': args.color_variance_init,
        'color_loss': args.color_loss,
        'num_classes': num_classes,
        'num_roughness_classes': (args.num_roughness_classes
                                  if args.roughness_loss else None),
        'use_variance': args.use_variance,
        'loss_weights': loss_weights,
        'model_params': model_params,
    }

    if color_binner:
        model_params = {
            **model_params,
            'color_hist_name': color_binner.name,
            'color_hist_shape': color_binner.shape,
            'color_hist_space': color_binner.space,
            'color_hist_sigma': color_binner.sigma,
        }

    vis.text(f'<pre>{json.dumps(model_params, indent=2)}</pre>', win='params')

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_params_path = (checkpoint_dir / 'model_params.json')
    if not args.dry_run and not model_params_path.exists():
        with model_params_path.open('w') as f:
            print(f' * Saving model params to {model_params_path!s}')
            json.dump(model_params, f, indent=2)

    mat_criterion = nn.CrossEntropyLoss().cuda()
    # 'from_material' derives substance log-probs from material output, so it
    # needs NLLLoss rather than CrossEntropyLoss (which applies log-softmax).
    if args.substance_loss == 'from_material':
        subst_criterion = nn.NLLLoss().cuda()
    else:
        subst_criterion = nn.CrossEntropyLoss().cuda()
    if args.color_loss == 'cross_entropy':
        color_criterion = nn.BCEWithLogitsLoss().cuda()
    elif args.color_loss == 'kl_divergence':
        color_criterion = nn.KLDivLoss().cuda()
    else:
        color_criterion = None

    # Learned per-task uncertainty weights, optimized jointly with the
    # network parameters below.
    loss_variances = {
        'material': torch.tensor([args.material_variance_init],
                                 requires_grad=True),
        'substance': torch.tensor([args.substance_variance_init],
                                  requires_grad=True),
        'color': torch.tensor([args.color_variance_init],
                              requires_grad=True),
    }

    optimizer = optim.SGD([*model.parameters(), *loss_variances.values()],
                          args.init_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    best_prec1 = 0

    pbar = tqdm(total=args.epochs, dynamic_ncols=True)
    for epoch in range(args.start_epoch, args.epochs):
        pbar.set_description(f'Epoch {epoch}')
        decay_learning_rate(optimizer, epoch, args.init_lr,
                            args.lr_decay_epochs, args.lr_decay_frac)
        train_stats = train_epoch(
            train_loader,
            model,
            epoch,
            mat_criterion,
            subst_criterion,
            color_criterion,
            optimizer,
            subst_mat_labels,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # evaluate on validation set
        val_stats = validate_epoch(
            validation_loader,
            model,
            epoch,
            mat_criterion,
            subst_criterion,
            color_criterion,
            subst_mat_labels,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # remember best prec@1 and save checkpoint
        is_best = val_stats['mat_prec1'] > best_prec1
        best_prec1 = max(val_stats['mat_prec1'], best_prec1)

        for stat_name, stat_val in train_stats.items():
            train_cum_stats[stat_name].append(stat_val)
        for stat_name, stat_val in val_stats.items():
            val_cum_stats[stat_name].append(stat_val)

        stats_dict = {
            'epoch': epoch + 1,
            'train': train_cum_stats,
            'validation': val_cum_stats,
        }

        if not args.dry_run:
            with (checkpoint_dir / 'model_train_stats.json').open('w') as f:
                json.dump(stats_dict, f)
            if is_best:
                # Snapshot per-epoch (non-cumulative) stats of the best epoch;
                # NaN stats (unused heads) are dropped for valid JSON.
                with (checkpoint_dir / 'model_best_stats.json').open('w') as f:
                    json.dump(
                        {
                            'epoch': epoch + 1,
                            'train': {
                                k: v
                                for k, v in train_stats.items()
                                if not math.isnan(v)
                            },
                            'validation': {
                                k: v
                                for k, v in val_stats.items()
                                if not math.isnan(v)
                            },
                        }, f)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_name': args.model_name,
                    'state_dict': model.state_dict(),
                    'loss_variances': loss_variances,
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                    'params': model_params,
                },
                checkpoint_dir=checkpoint_dir,
                filename=f'model.{args.model_name}.epoch{epoch:03d}.pth.tar',
                is_best=is_best)

        # Plot graphs.
        # NOTE(review): visdom windows are updated every epoch; only stats
        # present in both train and validation are plotted.
        for stat_name, train_stat_vals in train_cum_stats.items():
            if stat_name not in val_cum_stats:
                continue
            val_stat_vals = val_cum_stats[stat_name]
            vis.line(Y=np.column_stack((train_stat_vals, val_stat_vals)),
                     X=np.array(list(range(0, len(train_stat_vals)))),
                     win=f'plot-{stat_name}',
                     name=stat_name,
                     opts={
                         'legend': ['training', 'validation'],
                         'title': stat_name,
                     })
def main():
    """Train a RendNet3 substance/color classifier on OpenSurfaces photos.

    Reads hyperparameters from module-level globals (``args``,
    ``color_binner``, ``vis``); the material and roughness heads are disabled
    for this dataset. Checkpoints and visdom plots are written every epoch.
    """
    checkpoint_dir = args.checkpoint_dir / args.model_name

    split_path = args.opensurfaces_dir / args.split_file
    with split_path.open('r') as f:
        split_dict = json.load(f)

    print(f' * opensurface_dir = {args.opensurfaces_dir!s}\n'
          f' * split_path = {split_path!s}\n'
          f' * checkpoint_dir = {checkpoint_dir!s}')

    print(' * Loading datasets')
    output_substance = (True
                        if args.substance_loss == 'fc'
                        else False)
    output_color = color_binner is not None

    train_dataset = dataset.OpenSurfacesDataset(
        base_dir=args.opensurfaces_dir,
        color_binner=color_binner,
        photo_ids=split_dict['train'],
        image_transform=transforms.train_image_transform(
            INPUT_SIZE,
            crop_scales=CROP_RANGE,
            max_rotation=MAX_ROTATION,
            max_brightness_jitter=0.2,
            max_contrast_jitter=0.2,
            max_saturation_jitter=0.2,
            max_hue_jitter=0.1,
        ),
        # Mask gets the same geometric augmentation but no color jitter.
        mask_transform=transforms.train_mask_transform(
            INPUT_SIZE,
            crop_scales=CROP_RANGE,
            max_rotation=MAX_ROTATION),
        # Cropped variants use padding and their own crop/rotation ranges.
        cropped_image_transform=transforms.train_image_transform(
            INPUT_SIZE,
            pad=200,
            crop_scales=CROPPED_CROP_RANGE,
            max_rotation=CROPPED_MAX_ROTATION,
            max_brightness_jitter=0.2,
            max_contrast_jitter=0.2,
            max_saturation_jitter=0.2,
            max_hue_jitter=0.1,
        ),
        cropped_mask_transform=transforms.train_mask_transform(
            INPUT_SIZE,
            pad=200,
            crop_scales=CROPPED_CROP_RANGE,
            max_rotation=CROPPED_MAX_ROTATION),
        use_cropped=args.use_cropped,
        p_cropped=args.p_cropped,
    )

    validation_dataset = dataset.OpenSurfacesDataset(
        base_dir=args.opensurfaces_dir,
        color_binner=color_binner,
        photo_ids=split_dict['validation'],
        image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE),
        cropped_image_transform=transforms.inference_image_transform(
            INPUT_SIZE, INPUT_SIZE),
        cropped_mask_transform=transforms.inference_mask_transform(
            INPUT_SIZE, INPUT_SIZE),
        use_cropped=args.use_cropped,
        p_cropped=args.p_cropped,
    )

    # Constructor kwargs for RendNet3 (this name is later rebound to the full
    # hyperparameter record — see below).
    model_params = dict(
        num_substances=len(SUBSTANCES),
        output_material=False,
        output_roughness=False,
        output_substance=output_substance,
        output_color=output_color,
        num_color_bins=color_binner.size if color_binner else 0,
    )

    base_model_fn = {
        'resnet18': resnet.resnet18,
        'resnet34': resnet.resnet34,
        'resnet50': resnet.resnet50,
    }[args.base_model]

    if args.from_scratch:
        base_model = base_model_fn(pretrained=False)
    else:
        base_model = base_model_fn(pretrained=True)

    model = RendNet3(**model_params,
                     base_model=base_model).train().cuda()

    train_cum_stats = defaultdict(list)
    val_cum_stats = defaultdict(list)
    if args.resume:
        print(f" * Loading weights from {args.resume!s}")
        checkpoint = torch.load(args.resume)
        # strict=False: tolerate head mismatches between checkpoints.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # Restore cumulative stats so plots continue where they left off.
        stats_dict_path = checkpoint_dir / 'model_train_stats.json'
        if stats_dict_path.exists():
            with stats_dict_path.open('r') as f:
                stats_dict = json.load(f)
            train_cum_stats = stats_dict['train']
            val_cum_stats = stats_dict['validation']

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    validation_loader = DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=False,
    )

    loss_weights = {
        # 'material': args.material_loss_weight,
        # 'roughness': args.roughness_loss_weight,
        'substance': args.substance_loss_weight,
        'color': args.color_loss_weight,
    }

    # Rebinds model_params: now the full hyperparameter record saved with
    # checkpoints; the constructor kwargs are nested under 'model_params'.
    model_params = {
        'init_lr': args.init_lr,
        'lr_decay_epochs': args.lr_decay_epochs,
        'lr_decay_frac': args.lr_decay_frac,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'from_scratch': args.from_scratch,
        'base_model': args.base_model,
        'resumed_from': str(args.resume),
        'use_substance_loss': args.substance_loss is not None,
        'substance_loss': args.substance_loss,
        'color_loss': args.color_loss,
        'color_hist_name': args.color_hist_name,
        'use_variance': args.use_variance,
        'loss_weights': loss_weights,
        'use_cropped': args.use_cropped,
        'p_cropped': args.p_cropped,
        'model_params': model_params,
    }

    if color_binner:
        # Overrides the args-provided color_hist_name with the binner's own.
        model_params = {
            **model_params,
            'color_hist_name': color_binner.name,
            'color_hist_shape': color_binner.shape,
            'color_hist_space': color_binner.space,
            'color_hist_sigma': color_binner.sigma,
        }

    vis.text(f'<pre>{json.dumps(model_params, indent=2)}</pre>', win='params')

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_params_path = (checkpoint_dir / 'model_params.json')
    if not model_params_path.exists():
        with model_params_path.open('w') as f:
            print(f' * Saving model params to {model_params_path!s}')
            json.dump(model_params, f, indent=2)

    subst_criterion = nn.CrossEntropyLoss().cuda()
    if args.color_loss == 'cross_entropy':
        color_criterion = nn.BCEWithLogitsLoss().cuda()
    elif args.color_loss == 'kl_divergence':
        color_criterion = nn.KLDivLoss().cuda()
    else:
        color_criterion = None

    # Learned per-task uncertainty weights, optimized jointly with the
    # network parameters below.
    loss_variances = {
        'substance': torch.tensor([args.substance_variance_init],
                                  requires_grad=True),
        'color': torch.tensor([args.color_variance_init],
                              requires_grad=True),
    }

    optimizer = optim.SGD([*model.parameters(), *loss_variances.values()],
                          args.init_lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # When resuming, recover the best precision seen so far from the stats.
    if 'subst_fc_prec1' in val_cum_stats:
        best_prec1 = max(val_cum_stats['subst_fc_prec1'])
    else:
        best_prec1 = 0

    pbar = tqdm(total=args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        pbar.set_description(f'Epoch {epoch}')
        decay_learning_rate(optimizer, epoch, args.init_lr,
                            args.lr_decay_epochs, args.lr_decay_frac)
        train_stats = train_epoch(
            train_loader,
            model,
            epoch,
            subst_criterion,
            color_criterion,
            optimizer,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # evaluate on validation set
        val_stats = validate_epoch(
            validation_loader,
            model,
            epoch,
            subst_criterion,
            color_criterion,
            loss_variances=loss_variances,
            loss_weights=loss_weights,
        )

        # remember best prec@1 and save checkpoint
        is_best = val_stats['subst_fc_prec1'] > best_prec1
        best_prec1 = max(val_stats['subst_fc_prec1'], best_prec1)

        for stat_name, stat_val in train_stats.items():
            train_cum_stats[stat_name].append(stat_val)
        for stat_name, stat_val in val_stats.items():
            val_cum_stats[stat_name].append(stat_val)

        stats_dict = {
            'epoch': epoch + 1,
            'train': train_cum_stats,
            'validation': val_cum_stats,
        }

        with (checkpoint_dir / 'model_train_stats.json').open('w') as f:
            json.dump(stats_dict, f)
        if is_best:
            # Snapshot per-epoch (non-cumulative) stats of the best epoch;
            # NaN stats (unused heads) are dropped for valid JSON.
            with (checkpoint_dir / 'model_best_stats.json').open('w') as f:
                json.dump(
                    {
                        'epoch': epoch + 1,
                        'train': {
                            k: v
                            for k, v in train_stats.items()
                            if not math.isnan(v)
                        },
                        'validation': {
                            k: v
                            for k, v in val_stats.items()
                            if not math.isnan(v)
                        },
                    }, f)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model_name': args.model_name,
                'state_dict': model.state_dict(),
                'loss_variances': loss_variances,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'params': model_params,
            },
            checkpoint_dir=checkpoint_dir,
            filename=f'model.{args.model_name}.epoch{epoch:03d}.pth.tar',
            is_best=is_best)

        # Plot graphs.
        # NOTE(review): visdom windows are updated every epoch; only stats
        # present in both train and validation are plotted.
        for stat_name, train_stat_vals in train_cum_stats.items():
            if stat_name not in val_cum_stats:
                continue
            tqdm.write(f"Plotting {stat_name}")
            val_stat_vals = val_cum_stats[stat_name]
            vis.line(Y=np.column_stack((train_stat_vals, val_stat_vals)),
                     X=np.array(list(range(0, len(train_stat_vals)))),
                     win=f'plot-{stat_name}',
                     name=stat_name,
                     opts={
                         'legend': ['training', 'validation'],
                         'title': stat_name,
                     })