shuffle=True) #, num_workers=int(opt.workers)) print(len(dataset)) num_classes = len(dataset.classes) print('classes', num_classes) try: os.makedirs(opt.outf) except OSError: pass classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform) if opt.model != '': classifier.load_state_dict(torch.load(opt.model)) optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999)) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) classifier.cuda() num_batch = len(dataset) / opt.batchSize best_val = 0 start_time = time.time() for epoch in range(opt.nepoch): scheduler.step() train_correct = 0 total_trainset = 0 for i, data in enumerate(dataloader, 0): points, target = data
def main(args):
    """Train a PointNet classifier and optionally evaluate it.

    The function seeds the run, builds the datasets and data loaders,
    constructs the classifier, optimizer, LR scheduler and AMP grad scaler,
    runs ``args.epochs`` training epochs and, when ``args.eval`` is set,
    runs a validation pass over the test data loader.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``,
            ``hfta``, ``lr``, ``beta1``, ``beta2``, ``weight_decay``,
            ``gamma``, ``step_size``, ``device``, ``outf``,
            ``feature_transform``, ``model``, ``amp``, ``epochs``,
            ``iters_per_epoch``, ``warmup_data_loading`` and ``eval``.
            It is also forwarded to ``build_dataset`` and
            ``build_dataloader``. When ``args.hfta`` is false, the six
            hyper-parameter attributes are overwritten in place with their
            first element.

    Returns:
        When ``args.eval`` is truthy, a list with ``max(B, 1)`` accuracy
        values (correct predictions divided by the number of test samples
        seen). Otherwise ``None``.
    """
    seeding(args.seed)

    if args.hfta:
        B = consolidate_hyperparams_and_determine_B(
            args,
            ['lr', 'beta1', 'beta2', 'weight_decay', 'gamma', 'step_size'],
        )
    else:
        # B == 0 selects the non-fused code paths below; the hyper-parameters
        # are indexable in this case, so collapse each to its first entry.
        B = 0
        (args.lr, args.beta1, args.beta2, args.weight_decay, args.gamma,
         args.step_size) = (args.lr[0], args.beta1[0], args.beta2[0],
                            args.weight_decay[0], args.gamma[0],
                            args.step_size[0])

    if args.device == 'cuda':
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        print('Enable cuDNN heuristics!')
    device = (xm.xla_device()
              if args.device == 'xla' else torch.device(args.device))

    dataset, test_dataset = build_dataset(args)
    dataloader, testdataloader = build_dataloader(args, dataset, test_dataset)

    print('len(dataset)={}'.format(len(dataset)),
          'len(test_dataset)={}'.format(len(test_dataset)))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    if args.outf is not None:
        # An already-existing directory is fine; any other failure (e.g.
        # permissions) is raised now instead of surfacing only when the
        # CSV files are written after training.
        os.makedirs(args.outf, exist_ok=True)

    classifier = PointNetCls(
        k=num_classes,
        feature_transform=args.feature_transform,
        B=B,
        track_running_stats=(args.device != 'xla'),
    )

    if args.model:
        # Deserialize onto CPU so a checkpoint saved from another device can
        # still be loaded; the model is moved to `device` further below.
        # NOTE(review): torch.load unpickles the file -- only point
        # args.model at checkpoints from a trusted source.
        classifier.load_state_dict(torch.load(args.model, map_location='cpu'))

    optimizer = get_hfta_optim_for(optim.Adam, B=B)(
        classifier.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
    )
    scheduler = get_hfta_lr_scheduler_for(optim.lr_scheduler.StepLR, B=B)(
        optimizer,
        step_size=args.step_size,
        gamma=args.gamma,
    )

    # The scaler is a pass-through unless running on CUDA with AMP requested.
    scaler = amp.GradScaler(enabled=(args.device == 'cuda' and args.amp))

    classifier.to(device)

    num_batch = len(dataloader)

    def loss_fn(output, label, batch_size, trans_feat):
        """Return the NLL loss, plus the feature-transform penalty if enabled."""
        if B > 0:
            # Flatten the leading B dimension into the batch dimension and
            # scale the (mean-reduced) loss by B.
            loss = B * F.nll_loss(output.view(B * batch_size, -1), label)
        else:
            loss = F.nll_loss(output, label)
        if args.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        return loss

    classifier = classifier.train()
    epoch_timer = EpochTimer()

    # Training loop
    for epoch in range(args.epochs):
        num_samples_per_epoch = 0
        epoch_timer.epoch_start(epoch)
        for i, data in enumerate(dataloader, 0):
            # NOTE(review): with `>` batches 0..iters_per_epoch are processed,
            # i.e. iters_per_epoch + 1 batches -- confirm this is intended.
            if i > args.iters_per_epoch:
                break
            if args.warmup_data_loading:
                # Only iterate the loader; skip all compute.
                continue

            points, target = data
            target = target[:, 0]
            points, target = points.to(device), target.to(device)
            N = points.size(0)
            if B > 0:
                # Replicate the batch along a new leading dimension of size B
                # and repeat the labels to match.
                points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
                target = target.repeat(B)

            optimizer.zero_grad(set_to_none=True)
            if args.device == 'cuda':
                with amp.autocast(enabled=args.amp):
                    pred, trans, trans_feat = classifier(points)
                    loss = loss_fn(pred, target, N, trans_feat)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
            else:
                pred, trans, trans_feat = classifier(points)
                loss = loss_fn(pred, target, N, trans_feat)
                loss.backward()
                if args.device == 'xla':
                    xm.optimizer_step(optimizer, barrier=True)
                else:
                    optimizer.step()

            print('[{}: {}/{}] train loss: {}'.format(epoch, i, num_batch,
                                                      loss.item()))
            num_samples_per_epoch += N * max(B, 1)
            # Refresh the loss scale after every optimizer step.
            scaler.update()

        scheduler.step()
        epoch_timer.epoch_stop(num_samples_per_epoch)
        print('Epoch {} took {} s!'.format(epoch,
                                           epoch_timer.epoch_latency(epoch)))

    if args.device == 'xla' and not args.eval:
        print(met.metrics_report())
    if args.outf is not None:
        epoch_timer.to_csv(args.outf)

    if args.eval:
        # Run validation loop.
        print("Running validation loop ...")
        classifier = classifier.eval()
        with torch.no_grad():
            # One running count of correct predictions per leading-B slice
            # (a single slot when B == 0).
            total_correct = torch.zeros(max(B, 1), device=device)
            total_testset = 0
            for data in testdataloader:
                if args.warmup_data_loading:
                    continue
                points, target = data
                target = target[:, 0]
                points, target = points.to(device), target.to(device)
                N = points.size(0)
                if B > 0:
                    points = points.unsqueeze(0).expand(B, -1, -1,
                                                        -1).contiguous()
                    target = target.repeat(B)
                pred, _, _ = classifier(points)
                pred_choice = pred.argmax(-1)
                correct = pred_choice.eq(
                    target.view(B, N) if B > 0 else target).sum(-1)
                total_correct.add_(correct)
                total_testset += N

            final_accuracy = total_correct / total_testset
            final_accuracy = final_accuracy.cpu().tolist()
            if args.outf is not None:
                pd.DataFrame({
                    'acc': final_accuracy
                }).to_csv(os.path.join(args.outf, 'eval.csv'))

        # Return test_accuracy
        return final_accuracy