def train(rank, world_size, cfg):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Init distributed compute
    master_port = int(os.environ.get("MASTER_PORT", 8738))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    tcp_store = torch.distributed.TCPStore(master_addr, master_port, world_size, rank == 0)
    torch.distributed.init_process_group('nccl', store=tcp_store, rank=rank, world_size=world_size)

    # Setup device
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
        torch.cuda.set_device(device)
    else:
        assert world_size == 1
        device = torch.device("cpu")

    if rank == 0:
        writer = SummaryWriter(logdir=cfg["logdir"])
        logger = get_logger(cfg["logdir"])
        logger.info("Let SMNet training begin !!")

    # Setup dataloaders
    t_loader = SMNetLoader(cfg["data"], split=cfg['data']['train_split'])
    v_loader = SMNetLoader(cfg['data'], split=cfg["data"]["val_split"])

    t_sampler = DistributedSampler(t_loader)
    v_sampler = DistributedSampler(v_loader, shuffle=False)

    if rank == 0:
        print('#Envs in train: %d' % (len(t_loader.files)))
        print('#Envs in val: %d' % (len(v_loader.files)))

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        drop_last=True,
        pin_memory=True,
        sampler=t_sampler,
        multiprocessing_context='fork',
    )

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        pin_memory=True,
        sampler=v_sampler,
        multiprocessing_context='fork',
    )

    # Setup model
    model = SMNet(cfg['model'], device)
    model.apply(model.weights_init)
    model = model.to(device)
    if device.type == 'cuda':
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    if rank == 0:
        print('# trainable parameters = ', params)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_params)

    if rank == 0:
        logger.info("Using optimizer {}".format(optimizer))

    lr_decay_lambda = lambda epoch: cfg['training']['scheduler']['lr_decay_rate'] ** (
        epoch // cfg['training']['scheduler']['lr_epoch_per_decay'])
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)

    # Setup metrics
    obj_running_metrics = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics_val = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics.reset()
    obj_running_metrics_val.reset()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # Setup loss
    loss_fn = SemmapLoss()
    loss_fn = loss_fn.to(device=device)
    if rank == 0:
        logger.info("Using loss {}".format(loss_fn))

    # Init training
    start_iter = 0
    start_epoch = 0
    best_iou = -100.0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            if rank == 0:
                logger.info("Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
                print("Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"], map_location="cpu")
            model_state = checkpoint["model_state"]
            model.load_state_dict(model_state)
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_epoch = checkpoint["epoch"]
            start_iter = checkpoint["iter"]
            best_iou = checkpoint['best_iou']
            if rank == 0:
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            if rank == 0:
                logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))
                print("No checkpoint found at '{}'".format(cfg["training"]["resume"]))
    elif cfg['training']['load_model'] is not None:
        checkpoint = torch.load(cfg["training"]["load_model"], map_location="cpu")
        model_state = checkpoint['model_state']
        model.load_state_dict(model_state)
        if rank == 0:
            logger.info("Loading model and optimizer from checkpoint '{}'".format(
                cfg["training"]["load_model"]))
            print("Loading model and optimizer from checkpoint '{}'".format(
                cfg["training"]["load_model"]))

    # Start training
    iter = start_iter
    for epoch in range(start_epoch, cfg["training"]["train_epoch"], 1):
        t_sampler.set_epoch(epoch)
        for batch in trainloader:
            iter += 1
            start_ts = time.time()

            features, masks_inliers, proj_indices, semmap_gt, _ = batch

            model.train()
            optimizer.zero_grad()

            semmap_pred, observed_masks = model(features, proj_indices, masks_inliers)

            if observed_masks.any():
                loss = loss_fn(semmap_gt.to(device), semmap_pred, observed_masks)
                loss.backward()
                optimizer.step()

                semmap_pred = semmap_pred.permute(0, 2, 3, 1)
                masked_semmap_gt = semmap_gt[observed_masks]
                masked_semmap_pred = semmap_pred[observed_masks]

                obj_gt = masked_semmap_gt.detach()
                obj_pred = masked_semmap_pred.data.max(-1)[1].detach()
                obj_running_metrics.add(obj_pred, obj_gt)

            time_meter.update(time.time() - start_ts)

            if iter % cfg["training"]["print_interval"] == 0:
                # NOTE: assumes at least one batch since the last logging step had
                # observed pixels, so `loss` is defined here.
                conf_metric = obj_running_metrics.conf_metric.conf
                conf_metric = torch.FloatTensor(conf_metric)
                conf_metric = conf_metric.to(device)
                distrib.all_reduce(conf_metric)
                distrib.all_reduce(loss)
                loss /= world_size

                if rank == 0:
                    conf_metric = conf_metric.cpu().numpy()
                    conf_metric = conf_metric.astype(np.int32)
                    tmp_metrics = IoU(cfg['model']['n_obj_classes'])
                    tmp_metrics.reset()
                    tmp_metrics.conf_metric.conf = conf_metric
                    _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()

                    writer.add_scalar("train_metrics/mIoU", mIoU, iter)
                    writer.add_scalar("train_metrics/mRecall", mRecall, iter)
                    writer.add_scalar("train_metrics/mPrecision", mPrecision, iter)
                    writer.add_scalar("train_metrics/Overall_Acc", acc, iter)

                    fmt_str = ("Iter: {:d} == Epoch [{:d}/{:d}] == Loss: {:.4f}"
                               " == mIoU: {:.4f} == mRecall:{:.4f} == mPrecision:{:.4f}"
                               " == Overall_Acc:{:.4f} == Time/Image: {:.4f}")
                    print_str = fmt_str.format(
                        iter,
                        epoch,
                        cfg["training"]["train_epoch"],
                        loss.item(),
                        mIoU,
                        mRecall,
                        mPrecision,
                        acc,
                        time_meter.avg / cfg["training"]["batch_size"],
                    )
                    print(print_str)
                    writer.add_scalar("loss/train_loss", loss.item(), iter)
                    time_meter.reset()

        # Run validation at the end of every epoch
        model.eval()
        with torch.no_grad():
            for batch_val in valloader:
                features, masks_inliers, proj_indices, semmap_gt, _ = batch_val

                semmap_pred, observed_masks = model(features, proj_indices, masks_inliers)

                if observed_masks.any():
                    loss_val = loss_fn(semmap_gt.to(device), semmap_pred, observed_masks)

                    semmap_pred = semmap_pred.permute(0, 2, 3, 1)
                    masked_semmap_gt = semmap_gt[observed_masks]
                    masked_semmap_pred = semmap_pred[observed_masks]

                    obj_gt_val = masked_semmap_gt
                    obj_pred_val = masked_semmap_pred.data.max(-1)[1]
                    obj_running_metrics_val.add(obj_pred_val, obj_gt_val)

                    val_loss_meter.update(loss_val.item())

        conf_metric = obj_running_metrics_val.conf_metric.conf
        conf_metric = torch.FloatTensor(conf_metric)
        conf_metric = conf_metric.to(device)
        distrib.all_reduce(conf_metric)

        val_loss_avg = val_loss_meter.avg
        val_loss_avg = torch.FloatTensor([val_loss_avg])
        val_loss_avg = val_loss_avg.to(device)
        distrib.all_reduce(val_loss_avg)
        val_loss_avg /= world_size

        if rank == 0:
            val_loss_avg = val_loss_avg.cpu().numpy()
            val_loss_avg = val_loss_avg[0]
            writer.add_scalar("loss/val_loss", val_loss_avg, iter)
            logger.info("Iter %d Loss: %.4f" % (iter, val_loss_avg))

            conf_metric = conf_metric.cpu().numpy()
            conf_metric = conf_metric.astype(np.int32)
            tmp_metrics = IoU(cfg['model']['n_obj_classes'])
            tmp_metrics.reset()
            tmp_metrics.conf_metric.conf = conf_metric
            _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()

            writer.add_scalar("val_metrics/mIoU", mIoU, iter)
            writer.add_scalar("val_metrics/mRecall", mRecall, iter)
            writer.add_scalar("val_metrics/mPrecision", mPrecision, iter)
            writer.add_scalar("val_metrics/Overall_Acc", acc, iter)

            logger.info("val -- mIoU: {}".format(mIoU))
            logger.info("val -- mRecall: {}".format(mRecall))
            logger.info("val -- mPrecision: {}".format(mPrecision))
            logger.info("val -- Overall_Acc: {}".format(acc))
            print("val -- mIoU: {}".format(mIoU))
            print("val -- mRecall: {}".format(mRecall))
            print("val -- mPrecision: {}".format(mPrecision))
            print("val -- Overall_Acc: {}".format(acc))

            # -- save best model (by validation mIoU)
            if mIoU >= best_iou:
                best_iou = mIoU
                state = {
                    "epoch": epoch,
                    "iter": iter,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_mp3d_best_model.pkl".format(cfg["model"]["arch"]),
                )
                torch.save(state, save_path)

            # -- save checkpoint after every epoch
            state = {
                "epoch": epoch,
                "iter": iter,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_iou": best_iou,
            }
            save_path = os.path.join(cfg['checkpoint_dir'], "ckpt_model.pkl")
            torch.save(state, save_path)

        val_loss_meter.reset()
        obj_running_metrics_val.reset()
        obj_running_metrics.reset()

        scheduler.step(epoch)
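
# --- Launcher sketch (not part of the original script) ---
# A minimal example of how `train` could be spawned, assuming one process per GPU via
# torch.multiprocessing and a YAML config. The config path "configs/smnet.yml" and the
# use of mp.spawn here are assumptions, not necessarily the repo's actual entry point;
# MASTER_ADDR / MASTER_PORT fall back to the defaults read inside train() if unset.
if __name__ == "__main__":
    import yaml
    import torch.multiprocessing as mp

    with open("configs/smnet.yml") as f:
        cfg = yaml.safe_load(f)

    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    if world_size > 1:
        # one training process per GPU; each receives its rank as the first argument
        mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True)
    else:
        train(0, 1, cfg)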
    else:
        pred_semmap = np.array(pred_h5_file['semmap'])
        pred_h5_file.close()

        h5file = h5py.File(os.path.join(obsmaps_dir, file), 'r')
        observed_map = np.array(h5file['observed_map'])
        observed_map = observed_map.astype(bool)  # np.bool is removed in recent NumPy
        h5file.close()

        obj_gt = gt_semmap[observed_map]
        obj_pred = pred_semmap[observed_map]

        f.create_dataset('{}_pred'.format(env), data=obj_pred, dtype=np.int16)
        f.create_dataset('{}_gt'.format(env), data=obj_gt, dtype=np.int16)

        metrics.add(obj_pred, obj_gt)

print('total #envs= ', total, '\n')

classes_iou, mIoU, acc, recalls, mRecall, precisions, mPrecision = metrics.value()

print('Mean IoU: ', "%.2f" % round(mIoU * 100, 2))
print('Overall Acc: ', "%.2f" % round(acc * 100, 2))
print('Mean Recall: ', "%.2f" % round(mRecall * 100, 2))
print('Mean Precision: ', "%.2f" % round(mPrecision * 100, 2))

print('\n per class IoU:')
for i in range(13):
    print(' ', "%.2f" % round(classes_iou[i] * 100, 2), object_whitelist[i])
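
# --- Read-back helper sketch (not part of the original script) ---
# Shows how the '{}_pred' / '{}_gt' datasets written above could be loaded again.
# `results_path` and `env_name` are hypothetical arguments; the key pattern matches
# the create_dataset calls in the loop above.
def load_env_results(results_path, env_name):
    with h5py.File(results_path, 'r') as results:
        obj_pred = np.array(results['{}_pred'.format(env_name)])
        obj_gt = np.array(results['{}_gt'.format(env_name)])
    return obj_pred, obj_gt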
def compute_accuracy(outputs, labels, num_classes):
    """Return the mean IoU between predictions and ground-truth labels for one batch."""
    metric = IoU(num_classes, ignore_index=None)
    metric.reset()
    metric.add(outputs.detach(), labels.detach())
    (iou, miou) = metric.value()
    return miou
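
# --- Usage sketch for compute_accuracy (not part of the original code) ---
# Assumes this IoU implementation accepts integer per-pixel predictions of shape
# (N, H, W) together with integer labels of the same shape; the 13-class count and
# the random tensors below are placeholders, not real data.
def _compute_accuracy_demo():
    preds = torch.randint(0, 13, (2, 32, 32))   # fake per-pixel class predictions
    labels = torch.randint(0, 13, (2, 32, 32))  # fake per-pixel ground truth
    return compute_accuracy(preds, labels, num_classes=13)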