def train(rank, world_size, cfg):
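    """Distributed training worker: one process per rank.

    Args:
        rank: index of this process within the job.
        world_size: total number of processes.
        cfg: configuration dict (data, model, training, logdir, checkpoint_dir).
    """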

    # Setup seeds
    seed = cfg.get("seed", 1337)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # init distributed compute
    master_port = int(os.environ.get("MASTER_PORT", 8738))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
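    # Rank 0 hosts the TCPStore used for rendezvous; the other ranks connect to it.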
    tcp_store = torch.distributed.TCPStore(master_addr, master_port,
                                           world_size, rank == 0)
    # NCCL requires GPUs; fall back to gloo for the CPU-only path below.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    torch.distributed.init_process_group(backend,
                                         store=tcp_store,
                                         rank=rank,
                                         world_size=world_size)

    # Setup device
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
        torch.cuda.set_device(device)
    else:
        assert world_size == 1
        device = torch.device("cpu")

    if rank == 0:
        writer = SummaryWriter(logdir=cfg["logdir"])
        logger = get_logger(cfg["logdir"])
        logger.info("Let SMNet training begin !!")

    # Setup Dataloader
    t_loader = SMNetLoader(cfg["data"], split=cfg['data']['train_split'])
    v_loader = SMNetLoader(cfg['data'], split=cfg["data"]["val_split"])
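    # Distributed samplers shard each dataset across ranks; validation keeps a fixed order.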
    t_sampler = DistributedSampler(t_loader)
    v_sampler = DistributedSampler(v_loader, shuffle=False)

    if rank == 0:
        print('#Envs in train: %d' % (len(t_loader.files)))
        print('#Envs in val: %d' % (len(v_loader.files)))

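    # The configured batch size is the global batch size; each rank gets an equal share.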
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        drop_last=True,
        pin_memory=True,
        sampler=t_sampler,
        multiprocessing_context='fork',
    )

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        pin_memory=True,
        sampler=v_sampler,
        multiprocessing_context='fork',
    )

    # Setup Model
    model = SMNet(cfg['model'], device)
    model.apply(model.weights_init)
    model = model.to(device)

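    # Wrap the model in DDP so gradients are averaged across ranks during backward.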
    if device.type == 'cuda':
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[rank])

    n_trainable_params = sum(p.numel() for p in model.parameters()
                             if p.requires_grad)
    if rank == 0:
        print('# trainable parameters = ', n_trainable_params)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_params)

    if rank == 0:
        logger.info("Using optimizer {}".format(optimizer))

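    # Step decay: multiply the base LR by lr_decay_rate every lr_epoch_per_decay epochs.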
    lr_decay_rate = cfg['training']['scheduler']['lr_decay_rate']
    lr_epoch_per_decay = cfg['training']['scheduler']['lr_epoch_per_decay']

    def lr_decay_lambda(epoch):
        return lr_decay_rate ** (epoch // lr_epoch_per_decay)

    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)

    # Setup Metrics
    obj_running_metrics = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics_val = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics.reset()
    obj_running_metrics_val.reset()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # setup Loss
    loss_fn = SemmapLoss()
    loss_fn = loss_fn.to(device=device)

    if rank == 0:
        logger.info("Using loss {}".format(loss_fn))

    # init training
    start_iter = 0
    start_epoch = 0
    best_iou = -100.0
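    # 'resume' restores model, optimizer, scheduler and progress counters;
    # 'load_model' (below) only initializes the model weights.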
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            if rank == 0:
                logger.info(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["training"]["resume"]))
                print(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"],
                                    map_location="cpu")
            model_state = checkpoint["model_state"]
            model.load_state_dict(model_state)
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_epoch = checkpoint["epoch"]
            start_iter = checkpoint["iter"]
            best_iou = checkpoint['best_iou']
            if rank == 0:
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            if rank == 0:
                logger.info("No checkpoint found at '{}'".format(
                    cfg["training"]["resume"]))
                print("No checkpoint found at '{}'".format(
                    cfg["training"]["resume"]))

    elif cfg['training']['load_model'] is not None:
        checkpoint = torch.load(cfg["training"]["load_model"],
                                map_location="cpu")
        model_state = checkpoint['model_state']
        model.load_state_dict(model_state)
        if rank == 0:
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["load_model"]))
            print("Loading model and optimizer from checkpoint '{}'".format(
                cfg["training"]["load_model"]))

    # start training
    iter = start_iter
    for epoch in range(start_epoch, cfg["training"]["train_epoch"], 1):

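        # Re-seed the sampler so each epoch draws a different shuffling of the shards.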
        t_sampler.set_epoch(epoch)

        for batch in trainloader:

            iter += 1
            start_ts = time.time()

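            # Unpack the batch; the final element is unused during training.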
            features, masks_inliers, proj_indices, semmap_gt, _ = batch

            model.train()

            optimizer.zero_grad()
            semmap_pred, observed_masks = model(features, proj_indices,
                                                masks_inliers)

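            # Skip the update when no map cells were observed in this batch.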
            if observed_masks.any():

                loss = loss_fn(semmap_gt.to(device), semmap_pred,
                               observed_masks)

                loss.backward()

                optimizer.step()

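                # Move the class channel last so the (B, H, W) observation mask can index predictions.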
                semmap_pred = semmap_pred.permute(0, 2, 3, 1)

                masked_semmap_gt = semmap_gt[observed_masks]
                masked_semmap_pred = semmap_pred[observed_masks]

                obj_gt = masked_semmap_gt.detach()
                obj_pred = masked_semmap_pred.data.max(-1)[1].detach()
                obj_running_metrics.add(obj_pred, obj_gt)

            time_meter.update(time.time() - start_ts)

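            # Every print_interval iterations: sum the confusion matrices and losses
            # across ranks, then log training metrics from rank 0.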
            if (iter % cfg["training"]["print_interval"] == 0):
                conf_metric = obj_running_metrics.conf_metric.conf
                conf_metric = torch.FloatTensor(conf_metric)
                conf_metric = conf_metric.to(device)
                distrib.all_reduce(conf_metric)
                distrib.all_reduce(loss)
                loss /= world_size

                if (rank == 0):
                    conf_metric = conf_metric.cpu().numpy()
                    conf_metric = conf_metric.astype(np.int32)
                    tmp_metrics = IoU(cfg['model']['n_obj_classes'])
                    tmp_metrics.reset()
                    tmp_metrics.conf_metric.conf = conf_metric
                    _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()
                    writer.add_scalar("train_metrics/mIoU", mIoU, iter)
                    writer.add_scalar("train_metrics/mRecall", mRecall, iter)
                    writer.add_scalar("train_metrics/mPrecision", mPrecision,
                                      iter)
                    writer.add_scalar("train_metrics/Overall_Acc", acc, iter)

                    fmt_str = "Iter: {:d} == Epoch [{:d}/{:d}] == Loss: {:.4f} == mIoU: {:.4f} == mRecall:{:.4f} == mPrecision:{:.4f} == Overall_Acc:{:.4f} == Time/Image: {:.4f}"

                    print_str = fmt_str.format(
                        iter,
                        epoch,
                        cfg["training"]["train_epoch"],
                        loss.item(),
                        mIoU,
                        mRecall,
                        mPrecision,
                        acc,
                        time_meter.avg / cfg["training"]["batch_size"],
                    )

                    print(print_str)
                    writer.add_scalar("loss/train_loss", loss.item(), iter)
                    time_meter.reset()

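        # End-of-epoch validation pass (no gradient tracking).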
        model.eval()
        with torch.no_grad():
            for batch_val in valloader:

                features, masks_inliers, proj_indices, semmap_gt, _ = batch_val

                semmap_pred, observed_masks = model(features, proj_indices,
                                                    masks_inliers)

                if observed_masks.any():

                    loss_val = loss_fn(semmap_gt.to(device), semmap_pred,
                                       observed_masks)

                    semmap_pred = semmap_pred.permute(0, 2, 3, 1)

                    masked_semmap_gt = semmap_gt[observed_masks]
                    masked_semmap_pred = semmap_pred[observed_masks]

                    obj_gt_val = masked_semmap_gt
                    obj_pred_val = masked_semmap_pred.data.max(-1)[1]
                    obj_running_metrics_val.add(obj_pred_val, obj_gt_val)

                    val_loss_meter.update(loss_val.item())

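        # Aggregate the validation confusion matrix and mean loss over all ranks.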
        conf_metric = obj_running_metrics_val.conf_metric.conf
        conf_metric = torch.FloatTensor(conf_metric)
        conf_metric = conf_metric.to(device)
        distrib.all_reduce(conf_metric)

        val_loss_avg = val_loss_meter.avg
        val_loss_avg = torch.FloatTensor([val_loss_avg])
        val_loss_avg = val_loss_avg.to(device)
        distrib.all_reduce(val_loss_avg)
        val_loss_avg /= world_size

        if rank == 0:
            val_loss_avg = val_loss_avg.cpu().numpy()
            val_loss_avg = val_loss_avg[0]
            writer.add_scalar("loss/val_loss", val_loss_avg, iter)

            logger.info("Iter %d Loss: %.4f" % (iter, val_loss_avg))

            conf_metric = conf_metric.cpu().numpy()
            conf_metric = conf_metric.astype(np.int32)
            tmp_metrics = IoU(cfg['model']['n_obj_classes'])
            tmp_metrics.reset()
            tmp_metrics.conf_metric.conf = conf_metric
            _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()
            writer.add_scalar("val_metrics/mIoU", mIoU, iter)
            writer.add_scalar("val_metrics/mRecall", mRecall, iter)
            writer.add_scalar("val_metrics/mPrecision", mPrecision, iter)
            writer.add_scalar("val_metrics/Overall_Acc", acc, iter)

            logger.info("val -- mIoU: {}".format(mIoU))
            logger.info("val -- mRecall: {}".format(mRecall))
            logger.info("val -- mPrecision: {}".format(mPrecision))
            logger.info("val -- Overall_Acc: {}".format(acc))

            print("val -- mIoU: {}".format(mIoU))
            print("val -- mRecall: {}".format(mRecall))
            print("val -- mPrecision: {}".format(mPrecision))
            print("val -- Overall_Acc: {}".format(acc))

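            # -- save best model whenever the validation mIoU improves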
            if mIoU >= best_iou:
                best_iou = mIoU
                state = {
                    "epoch": epoch,
                    "iter": iter,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_mp3d_best_model.pkl".format(cfg["model"]["arch"]),
                )
                torch.save(state, save_path)

            # -- save checkpoint after every epoch
            state = {
                "epoch": epoch,
                "iter": iter,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_iou": best_iou,
            }
            save_path = os.path.join(cfg['checkpoint_dir'], "ckpt_model.pkl")
            torch.save(state, save_path)

        val_loss_meter.reset()
        obj_running_metrics_val.reset()
        obj_running_metrics.reset()

        scheduler.step(epoch)
        else:
            pred_semmap = np.array(pred_h5_file['semmap'])
        pred_h5_file.close()

        h5file = h5py.File(os.path.join(obsmaps_dir, file), 'r')
        observed_map = np.array(h5file['observed_map'])
        observed_map = observed_map.astype(bool)
        h5file.close()

        obj_gt = gt_semmap[observed_map]
        obj_pred = pred_semmap[observed_map]

        f.create_dataset('{}_pred'.format(env), data=obj_pred, dtype=np.int16)
        f.create_dataset('{}_gt'.format(env), data=obj_gt, dtype=np.int16)

        metrics.add(obj_pred, obj_gt)

print('total #envs= ', total, '\n')

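# Aggregate metrics over all evaluated environments and report them as percentages.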
classes_iou, mIoU, acc, recalls, mRecall, precisions, mPrecision = \
    metrics.value()

print('Mean IoU: ', "%.2f" % round(mIoU * 100, 2))
print('Overall Acc: ', "%.2f" % round(acc * 100, 2))
print('Mean Recall: ', "%.2f" % round(mRecall * 100, 2))
print('Mean Precision: ', "%.2f" % round(mPrecision * 100, 2))

print('\n per class IoU:')
for i in range(13):
    print('      ', "%.2f" % round(classes_iou[i] * 100, 2),
          object_whitelist[i])


def compute_accuracy(outputs, labels, num_classes):
    """Compute the mean IoU between predicted and ground-truth labels."""
    metric = IoU(num_classes, ignore_index=None)
    metric.reset()
    metric.add(outputs.detach(), labels.detach())
    (iou, miou) = metric.value()
    return miou