Example #1
import json
import os

import h5py
import numpy as np
from tqdm import tqdm

# `dataset`, `split`, `pred_dir`, `GT_dir`, and the IoU metric class are
# assumed to be defined by the surrounding script. The snippet starts
# mid-branch; the opening `if` below is implied by the `elif` that follows.
if dataset == 'mp3d':
    paths = json.load(open('data/paths.json', 'r'))
    envs_splits = json.load(open('data/envs_splits.json', 'r'))
    envs = envs_splits['{}_envs'.format(split)]
    envs = [x for x in envs if x in paths]
    envs.sort()
elif dataset == 'replica':
    paths = json.load(open('../replica/paths.json', 'r'))
    envs = list(paths.keys())
    envs.sort()
    envs.remove('room_2')

if dataset == 'mp3d':
    metrics = IoU(13)
elif dataset == 'replica':
    metrics = IoU(13, ignore_index=5)
metrics.reset()

total = 0

filename = os.path.join(pred_dir, 'evaluation_metrics.h5')
with h5py.File(filename, 'w') as f:
    for env in tqdm(envs):

        file = env + '.h5'

        if not os.path.isfile(os.path.join(pred_dir, 'semmap', file)): continue

        total += 1

        gt_h5_file = h5py.File(os.path.join(GT_dir, file), 'r')
        gt_semmap = np.array(gt_h5_file['map_semantic'])
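        # Hypothetical continuation (not part of the original snippet): load
        # the predicted map for this env and accumulate it into the metric.
        # The 'map_semantic' key for predictions mirrors the GT file layout
        # and is an assumption.
        with h5py.File(os.path.join(pred_dir, 'semmap', file), 'r') as pred_h5_file:
            pred_semmap = np.array(pred_h5_file['map_semantic'])
        gt_h5_file.close()
        metrics.add(pred_semmap, gt_semmap)
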
Example #2

import itertools
import os
import random
import time

import numpy as np
import torch
import torch.distributed as distrib
import torch.optim as optim
from tensorboardX import SummaryWriter  # assumed: tensorboardX (accepts logdir=)
from torch.optim.lr_scheduler import LambdaLR
from torch.utils import data
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

# Project-specific modules (Encoder, Decoder, CrossEntropy2d, IoU, SMNet,
# SMNetLoader, SemmapLoss, averageMeter, get_logger, utils, ext_transforms,
# Test, args, device) are assumed importable from the surrounding codebase.

def main():
    assert os.path.isdir(
        args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
            args.dataset_dir)

    # Fail fast if the saving directory doesn't exist
    assert os.path.isdir(
        args.save_dir), "The directory \"{0}\" doesn't exist.".format(
            args.save_dir)

    # Import the requested dataset
    if args.dataset.lower() == 'cityscapes':
        from data import Cityscapes as dataset
    else:
        # Should never happen...but just in case it does
        raise RuntimeError("\"{0}\" is not a supported dataset.".format(
            args.dataset))
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width)),
        ext_transforms.PILToLongTensor()
    ])

    # Get selected dataset
    # Load the training set as tensors
    train_set = dataset(args.dataset_dir,
                        mode='train',
                        max_iters=args.max_iters,
                        transform=image_transform,
                        label_transform=label_transform)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers)

    trainloader_iter = enumerate(train_loader)

    # Load the validation set as tensors
    val_set = dataset(args.dataset_dir,
                      mode='val',
                      max_iters=args.max_iters,
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(args.dataset_dir,
                       mode='test',
                       max_iters=args.max_iters,
                       transform=image_transform,
                       label_transform=label_transform)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding
    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Fetch a sample batch to inspect tensor sizes
    if args.mode.lower() == 'test':
        images, labels = next(iter(test_loader))
    else:
        images, labels = next(iter(train_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Class weighting is not used here; the loss below is applied unweighted

    print("\nTraining...\n")

    # Define the model with the DeepLabV2-style encoder and decoder
    input_encoder = Encoder().to(device)
    decoder_t = Decoder(num_classes).to(device)

    # Define the entropy loss for the segmentation task
    criterion = CrossEntropy2d()

    # Set the optimizer function for model
    optimizer_g = optim.SGD(itertools.chain(input_encoder.parameters(),
                                            decoder_t.parameters()),
                            lr=args.learning_rate,
                            momentum=0.9,
                            weight_decay=1e-4)

    optimizer_g.zero_grad()

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:

        input_encoder, decoder_t, optimizer_g, start_epoch, best_miou = utils.load_checkpoint(
            input_encoder, decoder_t, optimizer_g, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    print()

    metric.reset()

    val = Test(input_encoder, decoder_t, val_loader, criterion, metric, device)

    for i_iter in range(args.max_iters):

        optimizer_g.zero_grad()
        adjust_learning_rate(optimizer_g, i_iter)

        _, batch_data = next(trainloader_iter)
        inputs = batch_data[0].to(device)
        labels = batch_data[1].to(device)

        f_i = input_encoder(inputs)

        outputs_i = decoder_t(f_i)
        loss_seg = criterion(outputs_i, labels)

        loss_g = loss_seg
        loss_g.backward()
        optimizer_g.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
                i_iter, args.max_iters, loss_g.item()))
            print(">>>> [iter: {0:d}] Validation".format(i_iter))

            # Validate the model at this checkpoint interval
            loss, (iou, miou) = val.run_epoch(args.print_step)

            print(">>>> [iter: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(i_iter, loss, miou))

            # Print per-class IoU and save the model if it's the best thus far
            if miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(input_encoder, decoder_t, optimizer_g,
                                      i_iter + 1, best_miou, args)
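

# `adjust_learning_rate` is called in the training loop above but not shown in
# this snippet. A minimal sketch assuming the usual DeepLab-style polynomial
# decay; the decay power and the use of args.learning_rate / args.max_iters
# are assumptions, not the original implementation.
def adjust_learning_rate(optimizer, i_iter, power=0.9):
    # Polynomially decay the LR of every param group toward zero at max_iters
    lr = args.learning_rate * ((1 - float(i_iter) / args.max_iters) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
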
def train(rank, world_size, cfg):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # init distributed compute
    master_port = int(os.environ.get("MASTER_PORT", 8738))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    tcp_store = torch.distributed.TCPStore(master_addr, master_port,
                                           world_size, rank == 0)
    torch.distributed.init_process_group('nccl',
                                         store=tcp_store,
                                         rank=rank,
                                         world_size=world_size)
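    # Rank 0 hosts the TCPStore (is_master=True) and the other ranks connect
    # to it, so the store acts as the rendezvous point for process-group init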

    # Setup device
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
        torch.cuda.set_device(device)
    else:
        assert world_size == 1
        device = torch.device("cpu")

    if rank == 0:
        writer = SummaryWriter(logdir=cfg["logdir"])
        logger = get_logger(cfg["logdir"])
        logger.info("Let SMNet training begin !!")

    # Setup Dataloader
    t_loader = SMNetLoader(cfg["data"], split=cfg['data']['train_split'])
    v_loader = SMNetLoader(cfg['data'], split=cfg["data"]["val_split"])
    t_sampler = DistributedSampler(t_loader)
    v_sampler = DistributedSampler(v_loader, shuffle=False)

    if rank == 0:
        print('#Envs in train: %d' % (len(t_loader.files)))
        print('#Envs in val: %d' % (len(v_loader.files)))

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        drop_last=True,
        pin_memory=True,
        sampler=t_sampler,
        multiprocessing_context='fork',
    )

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"] // world_size,
        num_workers=cfg["training"]["n_workers"],
        pin_memory=True,
        sampler=v_sampler,
        multiprocessing_context='fork',
    )

    # Setup Model
    model = SMNet(cfg['model'], device)
    model.apply(model.weights_init)
    model = model.to(device)

    if device.type == 'cuda':
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[rank])

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    if rank == 0:
        print('# trainable parameters = ', params)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_params)

    if rank == 0:
        logger.info("Using optimizer {}".format(optimizer))

    def lr_decay_lambda(epoch):
        return cfg['training']['scheduler']['lr_decay_rate'] ** (
            epoch // cfg['training']['scheduler']['lr_epoch_per_decay'])

    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)
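    # Illustrative example (values are not from the config): with
    # lr_decay_rate=0.7 and lr_epoch_per_decay=20, the LR multiplier is
    # 0.7 ** (epoch // 20): 1.0 for epochs 0-19, 0.7 for 20-39, 0.49 for 40-59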

    # Setup Metrics
    obj_running_metrics = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics_val = IoU(cfg['model']['n_obj_classes'])
    obj_running_metrics.reset()
    obj_running_metrics_val.reset()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # setup Loss
    loss_fn = SemmapLoss()
    loss_fn = loss_fn.to(device=device)

    if rank == 0:
        logger.info("Using loss {}".format(loss_fn))

    # init training
    start_iter = 0
    start_epoch = 0
    best_iou = -100.0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            if rank == 0:
                logger.info(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["training"]["resume"]))
                print(
                    "Loading model and optimizer from checkpoint '{}'".format(
                        cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"],
                                    map_location="cpu")
            model_state = checkpoint["model_state"]
            model.load_state_dict(model_state)
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_epoch = checkpoint["epoch"]
            start_iter = checkpoint["iter"]
            best_iou = checkpoint['best_iou']
            if rank == 0:
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            if rank == 0:
                logger.info("No checkpoint found at '{}'".format(
                    cfg["training"]["resume"]))
                print("No checkpoint found at '{}'".format(
                    cfg["training"]["resume"]))

    elif cfg['training']['load_model'] is not None:
        checkpoint = torch.load(cfg["training"]["load_model"],
                                map_location="cpu")
        model_state = checkpoint['model_state']
        model.load_state_dict(model_state)
        if rank == 0:
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["load_model"]))
            print("Loading model and optimizer from checkpoint '{}'".format(
                cfg["training"]["load_model"]))

    # start training (`iter` is used as a counter here, shadowing the builtin)
    iter = start_iter
    for epoch in range(start_epoch, cfg["training"]["train_epoch"], 1):

        t_sampler.set_epoch(epoch)

        for batch in trainloader:

            iter += 1
            start_ts = time.time()

            features, masks_inliers, proj_indices, semmap_gt, _ = batch

            model.train()

            optimizer.zero_grad()
            semmap_pred, observed_masks = model(features, proj_indices,
                                                masks_inliers)

            if observed_masks.any():

                loss = loss_fn(semmap_gt.to(device), semmap_pred,
                               observed_masks)

                loss.backward()

                optimizer.step()

                semmap_pred = semmap_pred.permute(0, 2, 3, 1)

                masked_semmap_gt = semmap_gt[observed_masks]
                masked_semmap_pred = semmap_pred[observed_masks]

                obj_gt = masked_semmap_gt.detach()
                obj_pred = masked_semmap_pred.argmax(-1).detach()
                obj_running_metrics.add(obj_pred, obj_gt)

            time_meter.update(time.time() - start_ts)

            if (iter % cfg["training"]["print_interval"] == 0):
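                # `loss` below is the tensor from the most recent batch; this
                # assumes the current interval saw at least one batch with
                # observed pixels, so `loss` is defined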
                conf_metric = obj_running_metrics.conf_metric.conf
                conf_metric = torch.FloatTensor(conf_metric)
                conf_metric = conf_metric.to(device)
                distrib.all_reduce(conf_metric)
                distrib.all_reduce(loss)
                loss /= world_size

                if (rank == 0):
                    conf_metric = conf_metric.cpu().numpy()
                    conf_metric = conf_metric.astype(np.int32)
                    tmp_metrics = IoU(cfg['model']['n_obj_classes'])
                    tmp_metrics.reset()
                    tmp_metrics.conf_metric.conf = conf_metric
                    _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()
                    writer.add_scalar("train_metrics/mIoU", mIoU, iter)
                    writer.add_scalar("train_metrics/mRecall", mRecall, iter)
                    writer.add_scalar("train_metrics/mPrecision", mPrecision,
                                      iter)
                    writer.add_scalar("train_metrics/Overall_Acc", acc, iter)

                    fmt_str = "Iter: {:d} == Epoch [{:d}/{:d}] == Loss: {:.4f} == mIoU: {:.4f} == mRecall:{:.4f} == mPrecision:{:.4f} == Overall_Acc:{:.4f} == Time/Image: {:.4f}"

                    print_str = fmt_str.format(
                        iter,
                        epoch,
                        cfg["training"]["train_epoch"],
                        loss.item(),
                        mIoU,
                        mRecall,
                        mPrecision,
                        acc,
                        time_meter.avg / cfg["training"]["batch_size"],
                    )

                    print(print_str)
                    writer.add_scalar("loss/train_loss", loss.item(), iter)
                    time_meter.reset()

        model.eval()
        with torch.no_grad():
            for batch_val in valloader:

                features, masks_inliers, proj_indices, semmap_gt, _ = batch_val

                semmap_pred, observed_masks = model(features, proj_indices,
                                                    masks_inliers)

                if observed_masks.any():

                    loss_val = loss_fn(semmap_gt.to(device), semmap_pred,
                                       observed_masks)

                    semmap_pred = semmap_pred.permute(0, 2, 3, 1)

                    masked_semmap_gt = semmap_gt[observed_masks]
                    masked_semmap_pred = semmap_pred[observed_masks]

                    obj_gt_val = masked_semmap_gt
                    obj_pred_val = masked_semmap_pred.argmax(-1)
                    obj_running_metrics_val.add(obj_pred_val, obj_gt_val)

                    val_loss_meter.update(loss_val.item())

        conf_metric = obj_running_metrics_val.conf_metric.conf
        conf_metric = torch.FloatTensor(conf_metric)
        conf_metric = conf_metric.to(device)
        distrib.all_reduce(conf_metric)

        val_loss_avg = val_loss_meter.avg
        val_loss_avg = torch.FloatTensor([val_loss_avg])
        val_loss_avg = val_loss_avg.to(device)
        distrib.all_reduce(val_loss_avg)
        val_loss_avg /= world_size

        if rank == 0:
            val_loss_avg = val_loss_avg.cpu().numpy()
            val_loss_avg = val_loss_avg[0]
            writer.add_scalar("loss/val_loss", val_loss_avg, iter)

            logger.info("Iter %d Loss: %.4f" % (iter, val_loss_avg))

            conf_metric = conf_metric.cpu().numpy()
            conf_metric = conf_metric.astype(np.int32)
            tmp_metrics = IoU(cfg['model']['n_obj_classes'])
            tmp_metrics.reset()
            tmp_metrics.conf_metric.conf = conf_metric
            _, mIoU, acc, _, mRecall, _, mPrecision = tmp_metrics.value()
            writer.add_scalar("val_metrics/mIoU", mIoU, iter)
            writer.add_scalar("val_metrics/mRecall", mRecall, iter)
            writer.add_scalar("val_metrics/mPrecision", mPrecision, iter)
            writer.add_scalar("val_metrics/Overall_Acc", acc, iter)

            logger.info("val -- mIoU: {}".format(mIoU))
            logger.info("val -- mRecall: {}".format(mRecall))
            logger.info("val -- mPrecision: {}".format(mPrecision))
            logger.info("val -- Overall_Acc: {}".format(acc))

            print("val -- mIoU: {}".format(mIoU))
            print("val -- mRecall: {}".format(mRecall))
            print("val -- mPrecision: {}".format(mPrecision))
            print("val -- Overall_Acc: {}".format(acc))

            # Build the checkpoint state once and reuse it for both saves
            is_best = mIoU >= best_iou
            if is_best:
                best_iou = mIoU
            state = {
                "epoch": epoch,
                "iter": iter,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_iou": best_iou,
            }

            # -- save the model if it's the best thus far
            if is_best:
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_mp3d_best_model.pkl".format(cfg["model"]["arch"]),
                )
                torch.save(state, save_path)

            # -- save checkpoint after every epoch
            save_path = os.path.join(cfg['checkpoint_dir'], "ckpt_model.pkl")
            torch.save(state, save_path)

        val_loss_meter.reset()
        obj_running_metrics_val.reset()
        obj_running_metrics.reset()

        scheduler.step()
def compute_accuracy(outputs, labels, num_classes):
    metric = IoU(num_classes, ignore_index=None)
    metric.reset()
    metric.add(outputs.detach(), labels.detach())
    (iou, miou) = metric.value()
    return miou
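

# Hypothetical usage example (shapes, class count, and random data are
# illustrative assumptions): `outputs` are logits of shape (N, C, H, W) and
# `labels` are integer class maps of shape (N, H, W); the IoU metric used here
# is assumed to argmax 4-D score tensors internally.
if __name__ == '__main__':
    fake_logits = torch.randn(2, 13, 64, 64)
    fake_labels = torch.randint(0, 13, (2, 64, 64))
    print('mIoU on random data: {:.4f}'.format(
        compute_accuracy(fake_logits, fake_labels, num_classes=13)))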