def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)  # yaml.load() without a Loader is deprecated/unsafe
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:' % (k), v)
    print("\n**************************\n")

    try:
        os.makedirs(args.save_path)
    except OSError:
        pass

    train_transforms = transforms.Compose([d_utils.PointcloudToTensor()])
    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])

    train_dataset = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                 split='trainval', normalize=True,
                                 transforms=train_transforms)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers),
                                  pin_memory=True)

    global test_dataset
    test_dataset = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                split='test', normalize=True,
                                transforms=test_transforms)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True)

    model = RSCNN_MSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.base_lr,
                           weight_decay=args.weight_decay)

    # Step-decay schedules for the learning rate and BN momentum, clipped from below.
    lr_lbmd = lambda e: max(args.lr_decay ** (e // args.decay_step),
                            args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay ** (e // args.decay_step),
                             args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    if args.checkpoint != '':  # 'is not' compares identity, not string equality
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) / args.batch_size

    # training
    train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch)
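# pt_utils.BNMomentumScheduler is not defined in this file. The sketch below is an
# assumption about its behavior based on how it is used above: each epoch it walks the
# model's BatchNorm layers and sets their momentum to bn_lambda(epoch). A minimal
# sketch, not the library's actual implementation:

import torch.nn as nn

class BNMomentumSchedulerSketch(object):
    """Hypothetical stand-in for pt_utils.BNMomentumScheduler; details may differ."""

    def __init__(self, model, bn_lambda, last_epoch=-1):
        self.model = model
        self.bn_lambda = bn_lambda
        self.step(last_epoch + 1)

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        momentum = self.bn_lambda(epoch)
        # apply the scheduled momentum to every BatchNorm layer in the model
        for m in self.model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.momentum = momentum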
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)  # yaml.load() without a Loader is deprecated/unsafe
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])
    test_dataset = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                split='test', normalize=True,
                                transforms=test_transforms)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers),
                                 pin_memory=True)

    model = RSCNN_MSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate with test-time voting: average softmax predictions over NUM_VOTE
    # randomly rescaled copies of each point cloud
    PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15)
    model.eval()
    global_Class_mIoU, global_Inst_mIoU = 0, 0
    seg_classes = test_dataset.seg_classes
    seg_label_to_cat = {}  # {0: Airplane, 1: Airplane, ..., 49: Table}
    for cat in seg_classes.keys():
        for label in seg_classes[cat]:
            seg_label_to_cat[label] = cat

    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True)
        for i in range(NUM_REPEAT):
            shape_ious = {cat: [] for cat in seg_classes.keys()}
            for _, data in enumerate(test_dataloader, 0):
                points, target, cls = data
                points, target = points.cuda(), target.cuda()
                batch_one_hot_cls = np.zeros((len(cls), 16))  # 16 object classes
                for b in range(len(cls)):
                    batch_one_hot_cls[b, int(cls[b])] = 1
                batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls).float().cuda()

                pred = 0
                for v in range(NUM_VOTE):
                    # vote 0: original points; later votes: randomly rescaled copies
                    new_points = PointcloudScale(points) if v > 0 else points
                    pred += F.softmax(model(new_points, batch_one_hot_cls), dim=2)
                pred /= NUM_VOTE
                pred = pred.cpu()
                target = target.cpu()

                # restrict predictions to the ground-truth object's own part labels
                # (selected by seg_classes[cat])
                pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor)
                for b in range(len(cls)):
                    cat = seg_label_to_cat[target[b, 0].item()]
                    logits = pred[b, :, :]  # (num_points, num_classes)
                    pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0]

                for b in range(len(cls)):
                    segp = pred_val[b, :]
                    segl = target[b, :]
                    cat = seg_label_to_cat[segl[0].item()]
                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                    for l in seg_classes[cat]:
                        if torch.sum((segl == l) | (segp == l)) == 0:
                            # part is not present in this shape
                            part_ious[l - seg_classes[cat][0]] = 1.0
                        else:
                            part_ious[l - seg_classes[cat][0]] = (
                                torch.sum((segl == l) & (segp == l)).item()
                                / float(torch.sum((segl == l) | (segp == l)).item()))
                    shape_ious[cat].append(np.mean(part_ious))

            instance_ious = []
            for cat in shape_ious.keys():
                for iou in shape_ious[cat]:
                    instance_ious.append(iou)
                shape_ious[cat] = np.mean(shape_ious[cat])
            mean_class_ious = np.mean(list(shape_ious.values()))

            print('\n------ Repeat %3d ------' % (i + 1))
            for cat in sorted(shape_ious.keys()):
                print('%s: %0.6f' % (cat, shape_ious[cat]))
            print('Class_mIoU: %0.6f' % (mean_class_ious))
            print('Instance_mIoU: %0.6f' % (np.mean(instance_ious)))

            if mean_class_ious > global_Class_mIoU:
                global_Class_mIoU = mean_class_ious
                global_Inst_mIoU = np.mean(instance_ious)

    print('\nBest voting Class_mIoU = %0.6f, Instance_mIoU = %0.6f' %
          (global_Class_mIoU, global_Inst_mIoU))
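# d_utils.PointcloudScale is not shown in this file. Based on its use above, it is
# assumed to rescale each cloud by a scalar drawn uniformly from
# [scale_low, scale_high]; a minimal sketch under that assumption:

import torch

class PointcloudScaleSketch(object):
    """Hypothetical stand-in for d_utils.PointcloudScale; details may differ."""

    def __init__(self, scale_low=0.8, scale_high=1.25):
        self.scale_low = scale_low
        self.scale_high = scale_high

    def __call__(self, points):
        # points: (B, N, 3) tensor; draw one random scale per cloud in the batch
        B = points.size(0)
        scales = torch.empty(B, 1, 1, device=points.device).uniform_(
            self.scale_low, self.scale_high)
        return points * scales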
def train(args, io):
    train_dataset = ShapeNetPart(partition='trainval', num_points=args.num_points,
                                 class_choice=args.class_choice)
    # drop the last incomplete batch unless the dataset is tiny
    drop_last = len(train_dataset) >= 100
    train_loader = DataLoader(train_dataset, num_workers=8,
                              batch_size=args.batch_size, shuffle=True,
                              drop_last=drop_last)
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             num_workers=8, batch_size=args.test_batch_size,
                             shuffle=False, drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")
    io.cprint("Let's use %d GPUs!" % torch.cuda.device_count())

    seg_num_all = train_loader.dataset.seg_num_all
    seg_start_index = train_loader.dataset.seg_start_index

    # create model
    model = CurveNet().to(device)
    model = nn.DataParallel(model)

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(), lr=args.lr * 100,
                        momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3)
    elif args.scheduler == 'step':
        scheduler = MultiStepLR(opt, [140, 180], gamma=0.1)

    criterion = cal_loss
    best_test_iou = 0

    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []
        for data, label, seg in train_loader:
            seg = seg - seg_start_index
            # one-hot encode the object class as an extra conditioning input
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = (data.to(device), label_one_hot.to(device),
                                        seg.to(device))
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            opt.zero_grad()
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all),
                             seg.view(-1, 1).squeeze())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            opt.step()
            pred = seg_pred.max(dim=2)[1]  # (batch_size, num_points)
            count += batch_size
            train_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()  # (batch_size, num_points)
            pred_np = pred.detach().cpu().numpy()  # (batch_size, num_points)
            train_true_cls.append(seg_np.reshape(-1))  # (batch_size * num_points)
            train_pred_cls.append(pred_np.reshape(-1))  # (batch_size * num_points)
            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            train_label_seg.append(label.reshape(-1))

        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5

        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(train_true_cls,
                                                            train_pred_cls)
        train_true_seg = np.concatenate(train_true_seg, axis=0)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        train_label_seg = np.concatenate(train_label_seg)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg,
                                         train_label_seg, args.class_choice)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, train_loss * 1.0 / count, train_acc, avg_per_class_acc,
            np.mean(train_ious))
        io.cprint(outstr)
        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_true_cls = []
        test_pred_cls = []
        test_true_seg = []
        test_pred_seg = []
        test_label_seg = []
        for data, label, seg in test_loader:
            seg = seg - seg_start_index
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = (data.to(device), label_one_hot.to(device),
                                        seg.to(device))
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all),
                             seg.view(-1, 1).squeeze())
            pred = seg_pred.max(dim=2)[1]
            count += batch_size
            test_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()
            pred_np = pred.detach().cpu().numpy()
            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))
            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1))

        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls,
                                                            test_pred_cls)
        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                        test_label_seg, args.class_choice)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f, best iou %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc,
            np.mean(test_ious), best_test_iou)
        io.cprint(outstr)
        if np.mean(test_ious) >= best_test_iou:
            best_test_iou = np.mean(test_ious)
            torch.save(model.state_dict(),
                       '../checkpoints/%s/models/model.t7' % args.exp_name)
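# `cal_loss`, used as the criterion above, is not defined in this file. In the
# DGCNN/CurveNet codebases it is typically cross-entropy with optional label
# smoothing; a hedged sketch of that common form (the actual helper may differ):

import torch
import torch.nn.functional as F

def cal_loss_sketch(pred, gold, smoothing=True, eps=0.2):
    """pred: (B*N, C) logits; gold: (B*N,) integer part labels."""
    gold = gold.contiguous().view(-1)
    if smoothing:
        n_class = pred.size(1)
        # spread eps of the probability mass uniformly over the wrong classes
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')
    return loss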
def test(args, io):
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             batch_size=args.test_batch_size, shuffle=True,
                             drop_last=False)
    device = torch.device("cuda" if args.cuda else "cpu")

    # Try to load models
    seg_start_index = test_loader.dataset.seg_start_index
    model = CurveNet().to(device)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))
    model = model.eval()

    test_acc = 0.0
    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    category = {}
    for data, label, seg in test_loader:
        seg = seg - seg_start_index
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        data, label_one_hot, seg = (data.to(device), label_one_hot.to(device),
                                    seg.to(device))
        data = data.permute(0, 2, 1)
        seg_pred = model(data, label_one_hot)
        seg_pred = seg_pred.permute(0, 2, 1).contiguous()
        pred = seg_pred.max(dim=2)[1]
        seg_np = seg.cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        test_true_cls.append(seg_np.reshape(-1))
        test_pred_cls.append(pred_np.reshape(-1))
        test_true_seg.append(seg_np)
        test_pred_seg.append(pred_np)
        test_label_seg.append(label.reshape(-1))

    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)
    test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls,
                                                        test_pred_cls)
    test_true_seg = np.concatenate(test_true_seg, axis=0)
    test_pred_seg = np.concatenate(test_pred_seg, axis=0)
    test_label_seg = np.concatenate(test_label_seg)
    test_ious, category = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                              test_label_seg, args.class_choice,
                                              eva=True)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
        test_acc, avg_per_class_acc, np.mean(test_ious))
    io.cprint(outstr)

    # per-category mIoU, sorted by class index
    results = []
    for key in category.keys():
        results.append((int(key), np.mean(category[key]), len(category[key])))
    results.sort(key=lambda x: x[0])
    for re in results:
        io.cprint('idx: %d mIoU: %.3f num: %d' % (re[0], re[1], re[2]))
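# `calculate_shape_IoU` is called throughout but not defined in this file. The
# standard ShapeNetPart convention is: for each shape, average IoU over only that
# object category's part labels, scoring an absent part as IoU 1. The part counts
# and label offsets below are the usual ShapeNetPart ones; the `eva` branch that also
# returns per-category IoUs is an assumption based on how test() consumes it.

import numpy as np

SEG_NUM = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]           # parts per class
INDEX_START = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]

def calculate_shape_IoU_sketch(pred_np, seg_np, label, class_choice, eva=False):
    """pred_np, seg_np: (num_shapes, num_points); label: (num_shapes,) class ids."""
    shape_ious, category = [], {}
    for shape_idx in range(seg_np.shape[0]):
        if not class_choice:
            start = INDEX_START[int(label[shape_idx])]
            parts = range(start, start + SEG_NUM[int(label[shape_idx])])
        else:
            parts = range(SEG_NUM[int(label[0])])
        part_ious = []
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part,
                                      seg_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part,
                                     seg_np[shape_idx] == part))
            part_ious.append(1.0 if U == 0 else I / float(U))  # absent part -> 1
        iou = np.mean(part_ious)
        shape_ious.append(iou)
        category.setdefault(int(label[shape_idx]), []).append(iou)
    return (shape_ious, category) if eva else shape_ious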
def train(args, configpath):
    io = init(args, configpath)

    train_dataset = ShapeNetPart(partition='trainval', num_points=args.num_points)
    # drop the last incomplete batch unless the dataset is tiny
    drop_last = len(train_dataset) >= 100
    train_loader = DataLoader(train_dataset, num_workers=8,
                              batch_size=args.batch_size, shuffle=True,
                              drop_last=drop_last)
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points),
                             num_workers=8, batch_size=args.test_batch_size,
                             shuffle=True, drop_last=False)
    seg_num_all = train_loader.dataset.seg_num_all
    seg_start_index = train_loader.dataset.seg_start_index

    device = torch.device("cuda" if args.cuda else "cpu")
    if args.model == 'consnet':
        model = ConsNet(args, seg_num_all).to(device)
    elif args.model == 'pretrain':
        model = ConsNet(args, seg_num_all).to(device)
        model.load_state_dict(torch.load(args.pretrain_path))
    else:
        raise Exception("Not implemented")
    if args.parallel:
        model = nn.DataParallel(model)
    print(str(model))

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(), lr=args.lr * 100,
                        momentum=args.momentum, weight_decay=1e-4)
        cur_lr = args.lr * 100
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        cur_lr = args.lr

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3)
        print('Use COS')
    elif args.scheduler == 'step':
        scheduler = StepLR(opt, step_size=20, gamma=0.7)
        print('Use Step')

    if args.loss == 'l1loss':
        print('Use L1 Loss')
    elif args.loss == 'chamfer':
        print('Use Chamfer Distance')
    else:
        print('Not implemented')

    io.cprint('Experiment: %s' % args.exp_name)

    # Train
    min_loss = 100
    io.cprint('Begin to train...')
    for epoch in range(args.epochs):
        io.cprint('=====================================Epoch %d========================================' % epoch)
        io.cprint('*****Train*****')
        # Train
        model.train()
        train_loss = 0
        for i, point in enumerate(train_loader):
            data, label, seg = point

            # linear learning-rate warm-up over the first 5 epochs
            if epoch < 5:
                lr = 0.18 * cur_lr * epoch + 0.1 * cur_lr
                adjust_learning_rate(opt, lr)

            # build the two source clouds according to the pretext task
            if args.task == '1obj_rotate':
                data1, data2, label1, label2 = obj_rotate_perm(data, label)  # (B, N, 3)
            elif args.task == '2obj':
                data1, data2, label1, label2 = obj_2_perm(data, label)  # (B, N, 3)
            elif args.task == 'alter':
                if epoch % 2 == 0:
                    data1, data2, label1, label2 = obj_rotate_perm(data, label)  # (B, N, 3)
                else:
                    data1, data2, label1, label2 = obj_2_perm(data, label)  # (B, N, 3)
            else:
                print('Task not implemented!')
                exit(0)

            if args.mixup == 'emd':
                mixup_data = emd_mixup(data1, data2)  # (B, N, 3)
            elif args.mixup == 'add':
                mixup_data = add_mixup(data1, data2)  # (B, N, 3)
            mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
            batch_size = mixup_data.size()[0]
            seg = seg - seg_start_index

            if args.use_one_hot:
                label_one_hot1 = np.zeros((batch_size, 16))
                label_one_hot2 = np.zeros((batch_size, 16))
                for idx in range(batch_size):
                    label_one_hot1[idx, label1[idx]] = 1
                    label_one_hot2[idx, label2[idx]] = 1
                label_one_hot1 = torch.from_numpy(label_one_hot1.astype(np.float32))
                label_one_hot2 = torch.from_numpy(label_one_hot2.astype(np.float32))
            else:
                label_one_hot1 = torch.rand(batch_size, 16)
                label_one_hot2 = torch.rand(batch_size, 16)
            data, label_one_hot1, label_one_hot2, seg = (
                data.to(device), label_one_hot1.to(device),
                label_one_hot2.to(device), seg.to(device))

            # Project
            proj1 = rand_proj(data1)
            proj2 = rand_proj(data2)

            # Train: reconstruct each source cloud from the mixture
            opt.zero_grad()
            pred1 = model(mixup_data, proj1, label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
            pred2 = model(mixup_data, proj2, label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

            if args.loss == 'l1loss':
                loss = L1_loss(pred1, data1) + L1_loss(pred2, data2)
            elif args.loss == 'chamfer':
                loss1 = chamfer_distance(pred1, data1) + chamfer_distance(data1, pred1)
                loss2 = chamfer_distance(pred2, data2) + chamfer_distance(data2, pred2)
                loss = loss1 + loss2
            elif args.loss == 'emd':
                loss = emd_loss(pred1, data1) + emd_loss(pred2, data2)
            elif args.loss == 'emd2':
                loss = emd_loss_2(pred1, data1) + emd_loss_2(pred2, data2)
            else:
                raise NotImplementedError
            if args.l2loss:
                l2_loss = nn.MSELoss()(pred1, data1) + nn.MSELoss()(pred2, data2)
                loss += args.l2_param * l2_loss

            loss.backward()
            train_loss = train_loss + loss.item()
            opt.step()

            if (i + 1) % 100 == 0:
                io.cprint('iters %d, train loss: %.6f' % (i + 1, train_loss / (i + 1)))

        io.cprint('Learning rate: %.6f' % (opt.param_groups[0]['lr']))
        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5

        # Test
        if args.valid:
            io.cprint('*****Test*****')
            test_loss = 0
            model.eval()
            for data, label, seg in test_loader:
                with torch.no_grad():
                    if args.task == '1obj_rotate':
                        data1, data2, label1, label2 = obj_rotate_perm(data, label)  # (B, N, 3)
                    elif args.task == '2obj':
                        data1, data2, label1, label2 = obj_2_perm(data, label)  # (B, N, 3)
                    elif args.task == 'alter':
                        if epoch % 2 == 0:
                            data1, data2, label1, label2 = obj_rotate_perm(data, label)  # (B, N, 3)
                        else:
                            data1, data2, label1, label2 = obj_2_perm(data, label)  # (B, N, 3)
                    else:
                        print('Task not implemented!')
                        exit(0)

                    if args.mixup == 'emd':
                        mixup_data = emd_mixup(data1, data2)  # (B, N, 3)
                    elif args.mixup == 'add':
                        mixup_data = add_mixup(data1, data2)  # (B, N, 3)
                    mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
                    batch_size = mixup_data.size()[0]
                    seg = seg - seg_start_index

                    label_one_hot1 = np.zeros((batch_size, 16))
                    label_one_hot2 = np.zeros((batch_size, 16))
                    for idx in range(batch_size):
                        label_one_hot1[idx, label1[idx]] = 1
                        label_one_hot2[idx, label2[idx]] = 1
                    label_one_hot1 = torch.from_numpy(label_one_hot1.astype(np.float32))
                    label_one_hot2 = torch.from_numpy(label_one_hot2.astype(np.float32))
                    data, label_one_hot1, label_one_hot2, seg = (
                        data.to(device), label_one_hot1.to(device),
                        label_one_hot2.to(device), seg.to(device))

                    proj1 = rand_proj(data1)
                    proj2 = rand_proj(data2)
                    pred1 = model(mixup_data, proj1, label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
                    pred2 = model(mixup_data, proj2, label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

                    if args.loss == 'l1loss':
                        loss = L1_loss(pred1, data1) + L1_loss(pred2, data2)
                    elif args.loss == 'chamfer':
                        loss1 = chamfer_distance(pred1, data1) + chamfer_distance(data1, pred1)
                        loss2 = chamfer_distance(pred2, data2) + chamfer_distance(data2, pred2)
                        loss = loss1 + loss2
                    elif args.loss == 'emd':
                        loss = emd_loss(pred1, data1) + emd_loss(pred2, data2)
                    elif args.loss == 'emd2':
                        loss = emd_loss_2(pred1, data1) + emd_loss_2(pred2, data2)
                    else:
                        raise NotImplementedError
                    test_loss = test_loss + loss.item()

            io.cprint('Train loss: %.6f, Test loss: %.6f' %
                      (train_loss / len(train_loader), test_loss / len(test_loader)))
            cur_loss = test_loss / len(test_loader)
            if cur_loss < min_loss:
                min_loss = cur_loss
                torch.save(model.state_dict(),
                           'checkpoints/%s/best_%s.pkl' % (args.exp_name, args.exp_name))

        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(),
                       'checkpoints/%s/%s_epoch_%s.pkl' %
                       (args.exp_name, args.exp_name, str(epoch)))

    torch.save(model.state_dict(),
               'checkpoints/%s/%s.pkl' % (args.exp_name, args.exp_name))
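# The mixup helpers used above (emd_mixup, add_mixup) are not defined in this file.
# Both are assumed to combine two (B, N, 3) clouds into one (B, N, 3) cloud. The
# sketch below is a guess at the 'add' variant: pool the two clouds and randomly
# resample N points, so the mixture keeps geometry from both shapes. The 'emd'
# variant presumably interpolates points matched by an earth-mover assignment instead.

import torch

def add_mixup_sketch(data1, data2):
    """data1, data2: (B, N, 3) -> (B, N, 3) mixture (hypothetical)."""
    B, N, _ = data1.shape
    pooled = torch.cat([data1, data2], dim=1)            # (B, 2N, 3)
    idx = torch.randperm(2 * N, device=data1.device)[:N]
    return pooled[:, idx, :]                             # keep N of the 2N points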
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.safe_load(f)  # yaml.load() without a Loader is deprecated/unsafe
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:' % (k), v)
    print("\n**************************\n")

    try:
        os.makedirs(args.save_path)
    except OSError:
        pass

    train_dataset = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                 split='trainval', normalize=True)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=int(args.workers),
                                  pin_memory=True)

    # two copies of the test set: one evaluated under z-axis rotations,
    # one under arbitrary SO(3) rotations
    global test_dataset_z
    test_dataset_z = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                  split='test', normalize=True)
    test_dataloader_z = DataLoader(test_dataset_z, batch_size=args.batch_size,
                                   shuffle=False, num_workers=int(args.workers),
                                   pin_memory=True)

    global test_dataset_so3
    test_dataset_so3 = ShapeNetPart(root=args.data_root, num_points=args.num_points,
                                    split='test', normalize=True)
    test_dataloader_so3 = DataLoader(test_dataset_so3, batch_size=args.batch_size,
                                     shuffle=False, num_workers=int(args.workers),
                                     pin_memory=True)

    if args.model == "pointnet2_ssn":
        model = PointNet2_SSN(num_classes=args.num_classes)
        model.cuda()
    elif args.model == "rscnn_msn":
        model = RSCNN_MSN(num_classes=args.num_classes)
        model.cuda()
        model = torch.nn.DataParallel(model)
    else:
        print("Doesn't support this model")
        return 0

    optimizer = optim.Adam(model.parameters(), lr=args.base_lr,
                           weight_decay=args.weight_decay)

    # step-decay schedules for the learning rate and BN momentum, clipped from below
    lr_lbmd = lambda e: max(args.lr_decay ** (e // args.decay_step),
                            args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay ** (e // args.decay_step),
                             args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    if args.checkpoint != '':  # 'is not' compares identity, not string equality
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) / args.batch_size

    # training
    # train(train_dataloader, test_dataloader_z, test_dataloader_so3, model,
    #       criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
    validate(test_dataloader_so3, model, criterion, args, 1, 'so3')
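# The 'z' / 'so3' tags above suggest the two test loaders are evaluated under random
# z-axis rotations and arbitrary SO(3) rotations respectively. How validate() applies
# them is not shown here; the samplers below are a minimal sketch of that assumption.

import numpy as np

def random_z_rotation():
    """Rotation about the z-axis by an angle drawn uniformly from [0, 2*pi)."""
    t = np.random.uniform(0, 2 * np.pi)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0],
                     [s,  c, 0.0],
                     [0.0, 0.0, 1.0]])

def random_so3_rotation():
    """Uniform random rotation via QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(np.random.randn(3, 3))
    q *= np.sign(np.diag(r))   # make the factorization unique
    if np.linalg.det(q) < 0:   # ensure a proper rotation (det = +1)
        q[:, 0] = -q[:, 0]
    return q

# usage: rotated = points @ random_so3_rotation().T  # points: (N, 3)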
def test(args, io):
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             batch_size=args.test_batch_size, shuffle=True,
                             drop_last=False)
    device = torch.device("cuda" if args.cuda else "cpu")

    # Try to load models
    seg_num_all = test_loader.dataset.seg_num_all
    seg_start_index = test_loader.dataset.seg_start_index
    if args.model == 'dgcnn':
        model = DGCNN_partseg(args, seg_num_all).to(device)
    else:
        raise Exception("Not implemented")
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))
    model = model.eval()

    test_acc = 0.0
    count = 0.0
    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    for data, label, seg in test_loader:
        seg = seg - seg_start_index
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        data, label_one_hot, seg = (data.to(device), label_one_hot.to(device),
                                    seg.to(device))
        data = data.permute(0, 2, 1)
        batch_size = data.size()[0]
        seg_pred = model(data, label_one_hot)
        seg_pred = seg_pred.permute(0, 2, 1).contiguous()
        pred = seg_pred.max(dim=2)[1]
        seg_np = seg.cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        test_true_cls.append(seg_np.reshape(-1))
        test_pred_cls.append(pred_np.reshape(-1))
        test_true_seg.append(seg_np)
        test_pred_seg.append(pred_np)
        test_label_seg.append(label.reshape(-1))

    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)
    test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls,
                                                        test_pred_cls)
    test_true_seg = np.concatenate(test_true_seg, axis=0)
    test_pred_seg = np.concatenate(test_pred_seg, axis=0)
    test_label_seg = np.concatenate(test_label_seg)
    test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                    test_label_seg, args.class_choice)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
        test_acc, avg_per_class_acc, np.mean(test_ious))
    io.cprint(outstr)
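# The Python loop that one-hot encodes the 16 object classes recurs in every function
# above. torch.nn.functional.one_hot builds the same tensor in a single vectorized
# call; a drop-in sketch (the helper name is ours, not the repo's):

import torch
import torch.nn.functional as F

def one_hot_labels(label, num_classes=16):
    """label: (B,) or (B, 1) integer class ids -> (B, num_classes) float one-hot."""
    return F.one_hot(label.view(-1).long(), num_classes).float()

# usage: label_one_hot = one_hot_labels(label).to(device)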