def evaluate_one_epoch(test_loader, DATASET_CONFIG, CONFIG_DICT, AP_IOU_THRESHOLDS, model, criterion, config):
    """Run one evaluation pass over ``test_loader`` and compute detection mAP.

    For every batch: all batch entries are moved to the GPU, the model is run
    under ``torch.no_grad()`` on ``'point_clouds'``, the batch entries are merged
    into the model outputs, and ``criterion`` is evaluated so that every output
    key containing ``'loss'``, ``'acc'`` or ``'ratio'`` can be accumulated.
    Predictions and ground truths are parsed for every decoder prefix
    (``'last_'``, ``'proposal_'`` and the intermediate ``'{i}head_'`` outputs).
    After the loop, each prefix's collected batches are fed to one
    ``APCalculator`` per IoU threshold and the resulting metrics are logged.

    Args:
        test_loader: Iterable of batch dicts; every value must support
            ``.cuda(non_blocking=True)`` and ``'point_clouds'`` must be present.
        DATASET_CONFIG: Dataset config; ``class2type`` is given to
            ``APCalculator`` and the object itself is forwarded to ``criterion``.
        CONFIG_DICT: Config forwarded to ``parse_predictions`` /
            ``parse_groundtruths``.
        AP_IOU_THRESHOLDS: Sequence of IoU thresholds, one AP calculator each.
        model: Network called as ``model({'point_clouds': ...})``.
        criterion: Callable returning ``(loss, end_points)``.
        config: Options namespace (``num_decoder_layers``, loss coefficients
            and loss types, ``size_cls_agnostic``, ``print_freq``, ...).

    Returns:
        tuple: ``(mAP, mAPs)``.
            ``mAP`` is the ``'last_'`` mAP of the last AP calculator whose
            ``ap_iou_thresh`` exceeds 0.3, or ``0.0`` if there is none.
            ``mAPs`` is a list of ``[iou_thresh, {prefix: mAP}]`` pairs,
            one per threshold.
    """
    if config.num_decoder_layers > 0:
        prefixes = ['last_', 'proposal_'] + [f'{i}head_' for i in range(config.num_decoder_layers - 1)]
    else:
        prefixes = ['proposal_']  # only proposal

    ap_calculator_list = [APCalculator(iou_thresh, DATASET_CONFIG.class2type)
                          for iou_thresh in AP_IOU_THRESHOLDS]
    mAPs = [[iou_thresh, {k: 0 for k in prefixes}] for iou_thresh in AP_IOU_THRESHOLDS]

    stat_dict = {}  # key -> sum over batches of the per-batch statistic
    batch_pred_map_cls_dict = {k: [] for k in prefixes}
    batch_gt_map_cls_dict = {k: [] for k in prefixes}

    def _format_stats(num_batches, keep):
        # Running average of every accumulated statistic whose key passes ``keep``.
        return ''.join(f'{key} {stat_dict[key] / num_batches:.4f} \t'
                       for key in sorted(stat_dict) if keep(key))

    model.eval()  # set model to eval mode (for bn and dp)
    for batch_idx, batch_data_label in enumerate(test_loader):
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].cuda(non_blocking=True)

        # Forward pass
        inputs = {'point_clouds': batch_data_label['point_clouds']}
        with torch.no_grad():
            end_points = model(inputs)

        # Compute loss; the criterion reads labels from the same dict as the
        # predictions, so model outputs must never shadow a batch key.
        for key in batch_data_label:
            assert (key not in end_points)
            end_points[key] = batch_data_label[key]
        _, end_points = criterion(
            end_points, DATASET_CONFIG,
            num_decoder_layers=config.num_decoder_layers,
            query_points_generator_loss_coef=config.query_points_generator_loss_coef,
            obj_loss_coef=config.obj_loss_coef,
            box_loss_coef=config.box_loss_coef,
            sem_cls_loss_coef=config.sem_cls_loss_coef,
            query_points_obj_topk=config.query_points_obj_topk,
            center_loss_type=config.center_loss_type,
            center_delta=config.center_delta,
            size_loss_type=config.size_loss_type,
            size_delta=config.size_delta,
            heading_loss_type=config.heading_loss_type,
            heading_delta=config.heading_delta,
            size_cls_agnostic=config.size_cls_agnostic)

        # Accumulate scalar statistics across batches.
        for key, value in end_points.items():
            if 'loss' in key or 'acc' in key or 'ratio' in key:
                if not isinstance(value, (int, float)):
                    value = value.item()  # presumably a one-element tensor
                stat_dict[key] = stat_dict.get(key, 0) + value

        for prefix in prefixes:
            batch_pred_map_cls = parse_predictions(
                end_points, CONFIG_DICT, prefix, size_cls_agnostic=config.size_cls_agnostic)
            batch_gt_map_cls = parse_groundtruths(
                end_points, CONFIG_DICT, size_cls_agnostic=config.size_cls_agnostic)
            batch_pred_map_cls_dict[prefix].append(batch_pred_map_cls)
            batch_gt_map_cls_dict[prefix].append(batch_gt_map_cls)

        if (batch_idx + 1) % config.print_freq == 0:
            num_batches = float(batch_idx + 1)
            logger.info(f'Eval: [{batch_idx + 1}/{len(test_loader)}] '
                        + _format_stats(num_batches, lambda k: 'loss' not in k))
            logger.info(_format_stats(
                num_batches,
                lambda k: ('loss' in k and 'proposal_' not in k
                           and 'last_' not in k and 'head_' not in k)))
            logger.info(_format_stats(num_batches, lambda k: 'last_' in k))
            logger.info(_format_stats(num_batches, lambda k: 'proposal_' in k))
            for ihead in range(config.num_decoder_layers - 2, -1, -1):
                head_tag = f'{ihead}head_'
                logger.info(_format_stats(num_batches, lambda k: head_tag in k))

    mAP = 0.0
    for prefix in prefixes:
        for batch_pred_map_cls, batch_gt_map_cls in zip(batch_pred_map_cls_dict[prefix],
                                                        batch_gt_map_cls_dict[prefix]):
            for ap_calculator in ap_calculator_list:
                ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)
        # Evaluate average precision
        for i, ap_calculator in enumerate(ap_calculator_list):
            metrics_dict = ap_calculator.compute_metrics()
            logger.info(
                f'=====================>{prefix} IOU THRESH: {AP_IOU_THRESHOLDS[i]}<====================='
            )
            for key in metrics_dict:
                logger.info(f'{key} {metrics_dict[key]}')
            if prefix == 'last_' and ap_calculator.ap_iou_thresh > 0.3:
                mAP = metrics_dict['mAP']
            mAPs[i][1][prefix] = metrics_dict['mAP']
            ap_calculator.reset()

    # Distinct loop names: rebinding ``mAP`` here would replace the scalar
    # above with the last ``[iou_thresh, dict]`` pair in the return value.
    for iou_thresh, prefix_maps in mAPs:
        logger.info(f'IoU[{iou_thresh}]:\t' + ''.join(
            f'{key}: {prefix_maps[key]:.4f} \t' for key in sorted(prefix_maps)))

    return mAP, mAPs
def _concat_layer_outputs(end_points, target_prefix, source_prefixes, size_cls_agnostic):
    """Merge several layers' box predictions into one pseudo-layer ``target_prefix``.

    For each prediction field, the tensors ``end_points[f'{p}{field}']`` for
    every ``p`` in ``source_prefixes`` are joined with ``torch.cat(..., 1)`` and
    stored as ``end_points[f'{target_prefix}{field}']``; ``end_points`` is
    mutated in place. Dim 1 is presumably the proposal dimension -- TODO confirm.
    """
    size_fields = ['pred_size'] if size_cls_agnostic else ['size_scores', 'size_residuals']
    fields = ['center', 'heading_scores', 'heading_residuals', *size_fields,
              'sem_cls_scores', 'objectness_scores']
    for field in fields:
        end_points[f'{target_prefix}{field}'] = torch.cat(
            [end_points[f'{src}{field}'] for src in source_prefixes], 1)


def evaluate_one_time(test_loader, DATASET_CONFIG, CONFIG_DICT, AP_IOU_THRESHOLDS, model, criterion, args,
                      time=0, print_freq=10):
    """Run one evaluation round over ``test_loader`` and compute mAP per prefix.

    Besides the per-layer prefixes (``'last_'``, ``'proposal_'``,
    ``'{i}head_'``), combined pseudo-prefixes are evaluated by concatenating
    several layers' predictions:

    * ``'all_layers_'`` concatenates all per-layer prefixes (sunrgbd and scannet).
    * ``'last_three_'`` concatenates the final layer and up to two heads
      directly below it (scannet only).

    Args:
        test_loader: Iterable of batch dicts; every value must support
            ``.cuda(non_blocking=True)`` and ``'point_clouds'`` must be present.
        DATASET_CONFIG: Dataset config; ``class2type`` is given to
            ``APCalculator`` and the object itself is forwarded to ``criterion``.
        CONFIG_DICT: Config forwarded to ``parse_predictions`` /
            ``parse_groundtruths``.
        AP_IOU_THRESHOLDS: Sequence of IoU thresholds, one AP calculator each.
        model: Network called as ``model({'point_clouds': ...})``.
        criterion: Callable returning ``(loss, end_points)``.
        args: Options namespace (``num_decoder_layers``, ``dataset``, loss
            coefficients and loss types, ``size_cls_agnostic``, ...).
        time: Evaluation round index; only used to tag log messages.
        print_freq: Log running statistics every ``print_freq`` batches
            (must be positive; defaults to the previous hard-coded 10).

    Returns:
        list: ``[iou_thresh, {prefix: mAP}]`` pairs, one per threshold.

    Raises:
        ValueError: if ``args.num_decoder_layers > 0`` and ``args.dataset`` is
            neither ``'sunrgbd'`` nor ``'scannet'``.
    """
    num_layers = args.num_decoder_layers
    if num_layers > 0:
        layer_prefixes = ['last_', 'proposal_'] + [f'{i}head_' for i in range(num_layers - 1)]
        if args.dataset == 'sunrgbd':
            prefixes = layer_prefixes + ['all_layers_']
        elif args.dataset == 'scannet':
            prefixes = layer_prefixes + ['last_three_', 'all_layers_']
        else:
            # Fail fast: no prefix configuration exists for other datasets.
            raise ValueError(f'Unsupported dataset for evaluation: {args.dataset!r}')
    else:
        prefixes = ['proposal_']  # only proposal
        layer_prefixes = prefixes

    # Final layer plus up to two intermediate heads directly below it,
    # e.g. n=4 -> ['last_', '2head_', '1head_'], n=2 -> ['last_', '0head_'].
    if num_layers > 0:
        last_three_prefixes = ['last_'] + [f'{i}head_' for i in range(num_layers - 2, max(num_layers - 4, -1), -1)]
    else:
        last_three_prefixes = []

    # Pseudo-prefixes whose predictions are built from other layers' outputs.
    combined_sources = {'last_three_': last_three_prefixes, 'all_layers_': layer_prefixes}

    ap_calculator_list = [APCalculator(iou_thresh, DATASET_CONFIG.class2type)
                          for iou_thresh in AP_IOU_THRESHOLDS]
    mAPs = [[iou_thresh, {k: 0 for k in prefixes}] for iou_thresh in AP_IOU_THRESHOLDS]

    stat_dict = {}  # key -> sum over batches of the per-batch statistic
    batch_pred_map_cls_dict = {k: [] for k in prefixes}
    batch_gt_map_cls_dict = {k: [] for k in prefixes}

    def _format_stats(num_batches, keep):
        # Running average of every accumulated statistic whose key passes ``keep``.
        return ''.join(f'{key} {stat_dict[key] / num_batches:.4f} \t'
                       for key in sorted(stat_dict) if keep(key))

    model.eval()  # set model to eval mode (for bn and dp)
    for batch_idx, batch_data_label in enumerate(test_loader):
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].cuda(non_blocking=True)

        # Forward pass
        inputs = {'point_clouds': batch_data_label['point_clouds']}
        with torch.no_grad():
            end_points = model(inputs)

        # Compute loss; the criterion reads labels from the same dict as the
        # predictions, so model outputs must never shadow a batch key.
        for key in batch_data_label:
            assert (key not in end_points)
            end_points[key] = batch_data_label[key]
        _, end_points = criterion(
            end_points, DATASET_CONFIG,
            num_decoder_layers=args.num_decoder_layers,
            query_points_generator_loss_coef=args.query_points_generator_loss_coef,
            obj_loss_coef=args.obj_loss_coef,
            box_loss_coef=args.box_loss_coef,
            sem_cls_loss_coef=args.sem_cls_loss_coef,
            query_points_obj_topk=args.query_points_obj_topk,
            center_loss_type=args.center_loss_type,
            center_delta=args.center_delta,
            size_loss_type=args.size_loss_type,
            size_delta=args.size_delta,
            heading_loss_type=args.heading_loss_type,
            heading_delta=args.heading_delta,
            size_cls_agnostic=args.size_cls_agnostic)

        # Accumulate scalar statistics across batches.
        for key, value in end_points.items():
            if 'loss' in key or 'acc' in key or 'ratio' in key:
                if not isinstance(value, (int, float)):
                    value = value.item()  # presumably a one-element tensor
                stat_dict[key] = stat_dict.get(key, 0) + value

        for prefix in prefixes:
            sources = combined_sources.get(prefix)
            if sources is not None:
                _concat_layer_outputs(end_points, prefix, sources, args.size_cls_agnostic)
            batch_pred_map_cls = parse_predictions(
                end_points, CONFIG_DICT, prefix, size_cls_agnostic=args.size_cls_agnostic)
            batch_gt_map_cls = parse_groundtruths(
                end_points, CONFIG_DICT, size_cls_agnostic=args.size_cls_agnostic)
            batch_pred_map_cls_dict[prefix].append(batch_pred_map_cls)
            batch_gt_map_cls_dict[prefix].append(batch_gt_map_cls)

        if (batch_idx + 1) % print_freq == 0:
            num_batches = float(batch_idx + 1)
            logger.info(f'T[{time}] Eval: [{batch_idx + 1}/{len(test_loader)}] '
                        + _format_stats(num_batches, lambda k: 'loss' not in k))
            logger.info(_format_stats(
                num_batches,
                lambda k: ('loss' in k and 'proposal_' not in k
                           and 'last_' not in k and 'head_' not in k)))
            logger.info(_format_stats(num_batches, lambda k: 'last_' in k))
            logger.info(_format_stats(num_batches, lambda k: 'proposal_' in k))
            for ihead in range(args.num_decoder_layers - 2, -1, -1):
                head_tag = f'{ihead}head_'
                logger.info(_format_stats(num_batches, lambda k: head_tag in k))

    for prefix in prefixes:
        for batch_pred_map_cls, batch_gt_map_cls in zip(batch_pred_map_cls_dict[prefix],
                                                        batch_gt_map_cls_dict[prefix]):
            for ap_calculator in ap_calculator_list:
                ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)
        # Evaluate average precision
        for i, ap_calculator in enumerate(ap_calculator_list):
            metrics_dict = ap_calculator.compute_metrics()
            logger.info(f'===================>T{time} {prefix} IOU THRESH: {AP_IOU_THRESHOLDS[i]}<==================')
            for key in metrics_dict:
                logger.info(f'{key} {metrics_dict[key]}')
            mAPs[i][1][prefix] = metrics_dict['mAP']
            ap_calculator.reset()

    for iou_thresh, prefix_maps in mAPs:
        logger.info(f'T[{time}] IoU[{iou_thresh}]: ' + ''.join(
            f'{key}: {prefix_maps[key]:.4f} \t' for key in sorted(prefix_maps)))

    return mAPs