def train_net(cfg):
    """Train the completion model (Jittor backend), validating every epoch.

    Optimizes a combined Chamfer loss (three decoding resolutions) plus a
    point-moving-distance (PMD) regularizer, logs epoch losses to
    tensorboard, and checkpoints periodically and on validation improvement.
    """
    # Data loaders for the train and validation splits.
    train_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = train_dataset_loader.get_dataset(
        dataloader_jt.DatasetSubset.TRAIN,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        shuffle=True)
    val_data_loader = test_dataset_loader.get_dataset(
        dataloader_jt.DatasetSubset.VAL,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        shuffle=False)

    # Timestamped output folders for checkpoints and tensorboard logs.
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Tensorboard writers (SummaryWriter creates the log dirs itself).
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    model = Model(dataset=cfg.DATASET.TRAIN_DATASET)

    init_epoch = 0
    best_metrics = float('inf')

    optimizer = nn.Adam(model.parameters(),
                        lr=cfg.TRAIN.LEARNING_RATE,
                        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                        betas=cfg.TRAIN.BETAS)
    lr_scheduler = jittor.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        last_epoch=init_epoch)

    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        model.train()
        loss_metric = AverageMeter()
        n_batches = len(train_data_loader)
        print('epoch: ', epoch_idx, 'optimizer: ', lr_scheduler.get_lr())

        with tqdm(train_data_loader) as t:
            for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t):
                partial = jittor.array(data['partial_cloud'])
                gt = jittor.array(data['gtcloud'])

                pcds, deltas = model(partial)

                # Chamfer distance at each of the three output resolutions.
                cd1 = chamfer(pcds[0], gt)
                cd2 = chamfer(pcds[1], gt)
                cd3 = chamfer(pcds[2], gt)
                loss_cd = cd1 + cd2 + cd3

                # PMD regularizer: mean squared magnitude of the deltas.
                delta_losses = [jittor.sum(d ** 2) for d in deltas]
                loss_pmd = jittor.sum(jittor.stack(delta_losses)) / 3

                loss = loss_cd * cfg.TRAIN.LAMBDA_CD \
                    + loss_pmd * cfg.TRAIN.LAMBDA_PMD
                # Jittor's optimizer.step takes the loss and performs
                # zero_grad/backward/step internally.
                optimizer.step(loss)

                loss_item = loss.item()
                loss_metric.update(loss_item)
                jittor.sync_all()

                t.set_description(
                    '[Epoch %d/%d][Batch %d/%d]' %
                    (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches))
                t.set_postfix(loss='%s' % ['%.4f' % l for l in [loss_item]])

        lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/loss', loss_metric.avg(),
                                epoch_idx)
        logging.info(
            '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
            (epoch_idx, cfg.TRAIN.N_EPOCHS,
             epoch_end_time - epoch_start_time,
             ['%.4f' % l for l in [loss_metric.avg()]]))

        # Validate the current model.
        cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model)

        # Checkpoint: periodically, and whenever validation improves.
        if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics:
            file_name = 'ckpt-best.pkl' if cd_eval < best_metrics \
                else 'ckpt-epoch-%03d.pkl' % epoch_idx
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            model.save(output_path)
            logging.info('Saved checkpoint to %s ...' % output_path)
            if cd_eval < best_metrics:
                best_metrics = cd_eval

    train_writer.close()
    val_writer.close()
def train_net(cfg):
    """Train the completion model (PyTorch backend), validating every epoch.

    Optimizes a combined Chamfer loss (three decoding resolutions) plus a
    point-moving-distance (PMD) regularizer, logs per-batch and per-epoch
    scalars to tensorboard, and checkpoints periodically and on validation
    improvement. Supports resuming from cfg.CONST.WEIGHTS.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TRAIN),
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TEST),
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS // 2,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=False)

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    model = Model(dataset=cfg.DATASET.TRAIN_DATASET)
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    # Create the optimizers
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=cfg.TRAIN.LEARNING_RATE,
                                 weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=cfg.TRAIN.BETAS)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lr_lambda)

    init_epoch = 0
    best_metrics = float('inf')

    if 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        best_metrics = checkpoint['best_metrics']
        # BUGFIX: resume from the checkpoint's epoch instead of always 0.
        # The checkpoint dict saved below stores 'epoch_index', but it was
        # never read back, so the log line always reported epoch #0 and
        # training restarted from epoch 1. .get() keeps backward
        # compatibility with checkpoints lacking the key.
        init_epoch = checkpoint.get('epoch_index', init_epoch)
        model.load_state_dict(checkpoint['model'])
        # NOTE(review): the LambdaLR scheduler state is not restored here,
        # so after a resume it restarts at step 0 — confirm whether
        # lr_lambda depends on the step/epoch index.
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    # Training/Testing the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        batch_time = AverageMeter()
        data_time = AverageMeter()

        model.train()

        total_cd1 = 0
        total_cd2 = 0
        total_cd3 = 0
        total_pmd = 0

        batch_end_time = time()
        n_batches = len(train_data_loader)
        with tqdm(train_data_loader) as t:
            for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t):
                data_time.update(time() - batch_end_time)
                for k, v in data.items():
                    data[k] = utils.helpers.var_or_cuda(v)

                partial = random_subsample(data['partial_cloud'])
                gt = random_subsample(data['gtcloud'])

                pcds, deltas = model(partial)

                # Chamfer distance at each of the three output resolutions.
                cd1 = chamfer(pcds[0], gt)
                cd2 = chamfer(pcds[1], gt)
                cd3 = chamfer(pcds[2], gt)
                loss_cd = cd1 + cd2 + cd3

                # PMD regularizer: mean squared magnitude of the deltas.
                delta_losses = [torch.sum(delta ** 2) for delta in deltas]
                loss_pmd = torch.sum(torch.stack(delta_losses)) / 3

                loss = loss_cd * cfg.TRAIN.LAMBDA_CD \
                    + loss_pmd * cfg.TRAIN.LAMBDA_PMD
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Track per-batch scalars (Chamfer values scaled by 1e3
                # for readability).
                cd1_item = cd1.item() * 1e3
                total_cd1 += cd1_item
                cd2_item = cd2.item() * 1e3
                total_cd2 += cd2_item
                cd3_item = cd3.item() * 1e3
                total_cd3 += cd3_item
                pmd_item = loss_pmd.item()
                total_pmd += pmd_item
                n_itr = (epoch_idx - 1) * n_batches + batch_idx
                train_writer.add_scalar('Loss/Batch/cd1', cd1_item, n_itr)
                train_writer.add_scalar('Loss/Batch/cd2', cd2_item, n_itr)
                train_writer.add_scalar('Loss/Batch/cd3', cd3_item, n_itr)
                train_writer.add_scalar('Loss/Batch/pmd', pmd_item, n_itr)
                batch_time.update(time() - batch_end_time)
                batch_end_time = time()
                t.set_description(
                    '[Epoch %d/%d][Batch %d/%d]' %
                    (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches))
                t.set_postfix(loss='%s' % [
                    '%.4f' % l
                    for l in [cd1_item, cd2_item, cd3_item, pmd_item]
                ])

        avg_cd1 = total_cd1 / n_batches
        avg_cd2 = total_cd2 / n_batches
        avg_cd3 = total_cd3 / n_batches
        avg_pmd = total_pmd / n_batches

        lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/cd1', avg_cd1, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/cd2', avg_cd2, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/cd3', avg_cd3, epoch_idx)
        train_writer.add_scalar('Loss/Epoch/pmd', avg_pmd, epoch_idx)
        logging.info(
            '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
            (epoch_idx, cfg.TRAIN.N_EPOCHS,
             epoch_end_time - epoch_start_time,
             ['%.4f' % l for l in [avg_cd1, avg_cd2, avg_cd3, avg_pmd]]))

        # Validate the current model
        cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model)

        # Save checkpoints
        if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics:
            is_best = cd_eval < best_metrics
            if is_best:
                # BUGFIX: update best_metrics BEFORE writing the checkpoint
                # so 'ckpt-best.pth' records the metric it was actually
                # selected for. Previously the stale previous best was
                # stored, so a resume from the best checkpoint carried the
                # wrong threshold.
                best_metrics = cd_eval
            file_name = 'ckpt-best.pth' if is_best \
                else 'ckpt-epoch-%03d.pth' % epoch_idx
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save(
                {
                    'epoch_index': epoch_idx,
                    'best_metrics': best_metrics,
                    'model': model.state_dict()
                }, output_path)
            logging.info('Saved checkpoint to %s ...' % output_path)

    train_writer.close()
    val_writer.close()
def test_net(cfg,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             model=None):
    """Evaluate the completion model (Jittor backend).

    When called without a loader/model (standalone evaluation), builds a
    batch-size-1 validation loader and loads weights from cfg.CONST.WEIGHTS.
    Prints a per-category results table and, if a writer is given, logs the
    losses and metrics to tensorboard. Returns the average cd3 loss
    (scaled by 1e3), which the training loop uses for model selection.
    """
    if test_data_loader is None:
        # Standalone call: build a single-sample validation loader.
        dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = dataset_loader.get_dataset(
            dataloader_jt.DatasetSubset.VAL, batch_size=1, shuffle=False)

    if model is None:
        # Standalone call: instantiate the network and load its weights.
        model = Model(dataset=cfg.DATASET.TEST_DATASET)
        assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS
        print('loading: ', cfg.CONST.WEIGHTS)
        model.load(cfg.CONST.WEIGHTS)

    model.eval()

    n_samples = len(test_data_loader)
    test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd'])
    test_metrics = AverageMeter(Metrics.names())
    category_metrics = dict()

    # Evaluation loop over the test set.
    with tqdm(test_data_loader) as t:
        for model_idx, (taxonomy_id, model_id, data) in enumerate(t):
            # Taxonomy ids may arrive as strings or as 0-d tensors.
            taxonomy_id = taxonomy_id[0] if isinstance(
                taxonomy_id[0], str) else taxonomy_id[0].item()
            model_id = model_id[0]

            partial = jittor.array(data['partial_cloud'])
            gt = jittor.array(data['gtcloud'])
            b, n, _ = partial.shape

            pcds, deltas = model(partial)

            # Chamfer distance per resolution, scaled by 1e3 for reporting.
            cd1 = chamfer(pcds[0], gt).item() * 1e3
            cd2 = chamfer(pcds[1], gt).item() * 1e3
            cd3 = chamfer(pcds[2], gt).item() * 1e3

            # PMD loss over the per-step deltas.
            pmd_losses = [jittor.sum(d ** 2) for d in deltas]
            pmd = jittor.sum(jittor.stack(pmd_losses)) / 3
            pmd_item = pmd.item()

            _metrics = [pmd_item, cd3]
            test_losses.update([cd1, cd2, cd3, pmd_item])
            test_metrics.update(_metrics)
            if taxonomy_id not in category_metrics:
                category_metrics[taxonomy_id] = AverageMeter(Metrics.names())
            category_metrics[taxonomy_id].update(_metrics)

            t.set_description(
                'Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s'
                % (model_idx + 1, n_samples, taxonomy_id, model_id,
                   ['%.4f' % l for l in test_losses.val()],
                   ['%.4f' % m for m in _metrics]))

    # Per-category results table.
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    for metric in test_metrics.items:
        print(metric, end='\t')
    print()

    for taxonomy_id in category_metrics:
        print(taxonomy_id, end='\t')
        print(category_metrics[taxonomy_id].count(0), end='\t')
        for value in category_metrics[taxonomy_id].avg():
            print('%.4f' % value, end='\t')
        print()

    print('Overall', end='\t\t\t')
    for value in test_metrics.avg():
        print('%.4f' % value, end='\t')
    print('\n')

    # Add testing results to TensorBoard.
    if test_writer is not None:
        test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0),
                               epoch_idx)
        test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1),
                               epoch_idx)
        test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2),
                               epoch_idx)
        test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3),
                               epoch_idx)
        for i, metric in enumerate(test_metrics.items):
            test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i),
                                   epoch_idx)

    # Restore training mode for the caller's training loop.
    model.train()
    return test_losses.avg(2)