def main(config):
    """Evaluate a scene-segmentation model on the validation set.

    Builds the model and criterion, wraps the model in DistributedDataParallel,
    optionally restores weights from ``config.load_path``, then runs
    voting-based validation (20 votes).

    Args:
        config: experiment configuration; this block reads ``local_rank`` and
            ``load_path`` directly, the rest is consumed by the project helpers
            (``get_loader``, ``build_scene_segmentation``, ``validate``).
    """
    val_loader = get_loader(config)
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")

    model, criterion = build_scene_segmentation(config)
    model.cuda()
    criterion.cuda()

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # Optionally resume from a checkpoint.
    if config.load_path:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if not os.path.isfile(config.load_path):
            raise FileNotFoundError(f"checkpoint not found: {config.load_path}")
        load_checkpoint(config, model)
        logger.info("==> checking loaded ckpt")
        validate('resume', val_loader, model, criterion, config, num_votes=20)

    # NOTE(review): placed at function level, mirroring the training `main`
    # where the final 'Last' validation follows the conditional resume block.
    validate('Last', val_loader, model, criterion, config, num_votes=20)
def main(config):
    """Train a scene-segmentation model under DistributedDataParallel.

    Sets up the train/val loaders, model, criterion, optimizer and LR
    scheduler, optionally resumes from a checkpoint, then runs the epoch
    train/validate loop — checkpointing and tensorboard logging happen on
    rank 0 only — and finishes with a heavier 20-vote validation pass.

    Args:
        config: experiment configuration; reads ``optimizer``, ``batch_size``,
            ``base_learning_rate``, ``momentum``, ``weight_decay``,
            ``local_rank``, ``num_classes``, ``load_path``, ``log_dir``,
            ``start_epoch``, ``epochs`` and ``val_freq`` directly.

    Raises:
        NotImplementedError: if ``config.optimizer`` is not one of
            'sgd' / 'adam' / 'adamW'.
        FileNotFoundError: if ``config.load_path`` is set but missing.
    """
    train_loader, val_loader = get_loader(config)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")
    n_data = len(val_loader.dataset)
    logger.info(f"length of validation dataset: {n_data}")

    model, criterion = build_scene_segmentation(config)
    model.cuda()
    criterion.cuda()

    if config.optimizer == 'sgd':
        # Linear LR scaling rule: base LR is scaled by global batch size / 8.
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.batch_size * dist.get_world_size() / 8 * config.base_learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.base_learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.base_learning_rate,
                                      weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(
            f"Optimizer {config.optimizer} not supported")

    scheduler = get_scheduler(optimizer, len(train_loader), config)

    model = DistributedDataParallel(model,
                                    device_ids=[config.local_rank],
                                    broadcast_buffers=False)

    # Accumulated per-point class logits for multi-vote evaluation: one
    # (num_classes, num_points) buffer per validation sub-cloud.
    # Fixed local-variable typo: was `runing_vote_logits`.
    running_vote_logits = [
        np.zeros((config.num_classes, l.shape[0]), dtype=np.float32)
        for l in val_loader.dataset.sub_clouds_points_labels
    ]

    # Optionally resume from a checkpoint.
    if config.load_path:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if not os.path.isfile(config.load_path):
            raise FileNotFoundError(f"checkpoint not found: {config.load_path}")
        load_checkpoint(config, model, optimizer, scheduler)
        logger.info("==> checking loaded ckpt")
        validate('resume', val_loader, model, criterion, running_vote_logits,
                 config, num_votes=2)

    # Tensorboard writer exists only on rank 0; other ranks keep None.
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=config.log_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(config.start_epoch, config.epochs + 1):
        # Re-seed the distributed samplers so shuffling differs per epoch.
        train_loader.sampler.set_epoch(epoch)
        val_loader.sampler.set_epoch(epoch)
        train_loader.dataset.epoch = epoch - 1
        tic = time.time()

        loss = train(epoch, train_loader, model, criterion, optimizer,
                     scheduler, config)

        logger.info('epoch {}, total time {:.2f}, lr {:.5f}'.format(
            epoch, (time.time() - tic), optimizer.param_groups[0]['lr']))

        if epoch % config.val_freq == 0:
            validate(epoch, val_loader, model, criterion, running_vote_logits,
                     config, num_votes=2)

        if dist.get_rank() == 0:
            # save model
            save_checkpoint(config, epoch, model, optimizer, scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

    # Final, heavier evaluation with 20 voting rounds.
    validate('Last', val_loader, model, criterion, running_vote_logits,
             config, num_votes=20)