Exemplo n.º 1
0
    def loadDatasets(self, CONFIG, bs):
        """Build the training dataset and its distributed DataLoader.

        Args:
            CONFIG: addict-style configuration tree (dataset name/root,
                split, image sizes, BGR mean, worker count, ...).
            bs: per-process batch size.

        Returns:
            (dataset, loader) tuple; both are also stored on ``self``.
        """
        # Sampler: partitions self.seenset / self.novelset across the
        # distributed ranks of the current process group.
        sampler = MyDistributedSampler(
            self.seenset,
            self.novelset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank())

        self.dataset = get_dataset(CONFIG.DATASET)(
            train=self.train,
            test=None,
            root=CONFIG.ROOT,
            transform=None,
            split=CONFIG.SPLIT.TRAIN,
            base_size=513,
            crop_size=CONFIG.IMAGE.SIZE.TRAIN,
            mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
            warp=CONFIG.WARP_IMAGE,
            scale=(0.5, 1.5),
            flip=True,
            visibility_mask=self.visibility_mask,
        )
        # NOTE(review): reseeds the *global* random module on every call —
        # presumably for reproducible augmentation across ranks; confirm
        # this is intentional, as it affects all other users of `random`.
        random.seed(42)
        # DataLoader
        self.loader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=bs,
            num_workers=CONFIG.NUM_WORKERS,
            sampler=sampler,
            pin_memory=True)
        return self.dataset, self.loader
Exemplo n.º 2
0
def evaluate(model, writer, iteration, CONFIG):
    """
    Evaluate ``model`` on the validation split and log the mean IoU.

    Args:
        model: segmentation network; moved to GPU 0 and switched to eval
            mode (the caller must restore train mode afterwards).
        writer: TensorBoard ``SummaryWriter`` receiving the "meanIoU" scalar.
        iteration: global step recorded with the scalar.
        CONFIG: addict-style configuration tree.
    """
    device = 0  # evaluation always runs on the first GPU
    model.eval()
    model.to(device)

    # "h16" is evaluated on the "vocaug" validation split. Use a local
    # name instead of mutating the shared CONFIG object in place (the
    # original clobbered CONFIG.DATASET.NAME for every later caller).
    dataset_name = CONFIG.DATASET.NAME
    if dataset_name == "h16":
        dataset_name = "vocaug"

    # Dataset
    dataset = get_dataset(dataset_name)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=False,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    preds, gts = [], []
    # Scoped no_grad instead of torch.set_grad_enabled(False): the latter
    # silently disabled autograd for the rest of the process and would
    # break any training resumed after this evaluation.
    with torch.no_grad():
        for image_ids, images, gt_labels in loader:
            # Image
            images = images.to(device)

            # Forward propagation
            logits = model(images)

            # Pixel-wise labeling: upsample logits to ground-truth size
            _, H, W = gt_labels.shape
            logits = F.interpolate(logits,
                                   size=(H, W),
                                   mode="bilinear",
                                   align_corners=False)
            probs = F.softmax(logits, dim=1)
            labels = torch.argmax(probs, dim=1)

            preds += list(labels.cpu().numpy())
            gts += list(gt_labels.numpy())

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print("MeanIoU: {:2.2f}".format(score["Mean IoU"] * 100))
    writer.add_scalar("meanIoU",
                      score["Mean IoU"] * 100,
                      global_step=iteration)
Exemplo n.º 3
0
def crf(config_path, n_jobs):
    """
    DenseCRF post-processing on pre-computed logits.

    Loads per-image logits saved by the ``test`` step, refines them with
    a dense CRF in parallel, and writes the resulting scores to JSON.

    Args:
        config_path: YAML configuration (path/stream accepted by yaml.load).
        n_jobs: number of joblib worker processes.
    """

    # Configuration — explicit Loader: yaml.load() without one is
    # deprecated since PyYAML 5.1 and unsafe on untrusted input.
    CONFIG = Dict(yaml.load(config_path, Loader=yaml.SafeLoader))
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        quit()

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf.json")
    print("Score dst:", save_path)

    # Refine one sample; returns (predicted label map, ground truth).
    def process(i):
        # Idiomatic indexing instead of calling __getitem__ directly.
        image_id, image, gt_label = dataset[i]

        filename = os.path.join(logit_dir, image_id + ".npy")
        logit = np.load(filename)

        # Upsample logits to the original image resolution
        _, H, W = image.shape
        logit = torch.FloatTensor(logit)[None, ...]
        logit = F.interpolate(logit,
                              size=(H, W),
                              mode="bilinear",
                              align_corners=False)
        prob = F.softmax(logit, dim=1)[0].numpy()

        # DenseCRF expects an HWC uint8 image
        image = image.astype(np.uint8).transpose(1, 2, 0)
        prob = postprocessor(image, prob)
        label = np.argmax(prob, axis=0)

        return label, gt_label

    # CRF in multi-process
    results = joblib.Parallel(n_jobs=n_jobs, verbose=10, pre_dispatch="all")(
        [joblib.delayed(process)(i) for i in range(len(dataset))])

    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 4
0
def test(config_path, model_path, cuda):
    """
    Evaluation on validation set.

    Runs the model over the validation split, dumps per-image logits for
    later CRF post-processing, and writes overall scores to JSON.

    Args:
        config_path: YAML configuration (path/stream accepted by yaml.load).
        model_path: checkpoint (.pth) to evaluate.
        cuda: request GPU execution.
    """

    # Configuration — explicit Loader: yaml.load() without one is
    # deprecated since PyYAML 5.1 and unsafe on untrusted input.
    CONFIG = Dict(yaml.load(config_path, Loader=yaml.SafeLoader))
    device = get_device(cuda)
    torch.set_grad_enabled(False)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    # Model
    # NOTE(review): eval() on a config-supplied name executes arbitrary
    # code if the config is untrusted; prefer an explicit name->class map.
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    # map_location keeps CPU-only machines able to load GPU checkpoints
    state_dict = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.eval()
    model.to(device)

    # Path to save logits
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    makedirs(logit_dir)
    print("Logit dst:", logit_dir)

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores.json")
    print("Score dst:", save_path)

    preds, gts = [], []
    for image_ids, images, gt_labels in tqdm(loader,
                                             total=len(loader),
                                             dynamic_ncols=True):
        # Image
        images = images.to(device)

        # Forward propagation
        logits = model(images)

        # Save on disk for CRF post-processing
        for image_id, logit in zip(image_ids, logits):
            filename = os.path.join(logit_dir, image_id + ".npy")
            np.save(filename, logit.cpu().numpy())

        # Pixel-wise labeling: upsample logits to ground-truth size
        _, H, W = gt_labels.shape
        logits = F.interpolate(logits,
                               size=(H, W),
                               mode="bilinear",
                               align_corners=False)
        probs = F.softmax(logits, dim=1)
        labels = torch.argmax(probs, dim=1)

        preds += list(labels.cpu().numpy())
        gts += list(gt_labels.numpy())

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 5
0
def main(config, cuda):
    """
    Train DeepLab v2 (ResNet-101, MSC) with gradient accumulation.

    Args:
        config: path to a YAML configuration file.
        cuda: request GPU execution (falls back to CPU when unavailable).
    """
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration: close the file handle deterministically (the original
    # leaked the open() result); explicit Loader avoids the PyYAML >= 5.1
    # deprecation and unsafe default.
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.SafeLoader))

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET)(
        root=CONFIG.ROOT,
        split=CONFIG.SPLIT.TRAIN,
        base_size=513,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.WARP_IMAGE,
        scale=(0.5, 1.5),
        flip=True,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.N_CLASSES)
    state_dict = torch.load(CONFIG.INIT_MODEL)
    model.load_state_dict(state_dict, strict=False)  # Skip "aspp" layer
    model = nn.DataParallel(model)
    model.to(device)

    # Optimizer
    optimizer = {
        "sgd": torch.optim.SGD(
            # cf lr_mult and decay_mult in train.prototxt
            params=[
                {
                    "params": get_params(model.module, key="1x"),
                    "lr": CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="10x"),
                    "lr": 10 * CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_params(model.module, key="20x"),
                    "lr": 20 * CONFIG.LR,
                    "weight_decay": 0.0,
                },
            ],
            momentum=CONFIG.MOMENTUM,
        )
        # Add any other optimizer
    }.get(CONFIG.OPTIMIZER)

    # Loss definition
    criterion = CrossEntropyLoss2d(ignore_index=CONFIG.IGNORE_LABEL)
    criterion.to(device)

    # TensorBoard Logger
    writer = SummaryWriter(CONFIG.LOG_DIR)
    loss_meter = MovingAverageValueMeter(20)

    model.train()
    model.module.scale.freeze_bn()

    for iteration in tqdm(
        range(1, CONFIG.ITER_MAX + 1),
        total=CONFIG.ITER_MAX,
        leave=False,
        dynamic_ncols=True,
    ):

        # Set a learning rate
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            try:
                data, target = next(loader_iter)
            except StopIteration:
                # Epoch exhausted: restart the loader. Catching only
                # StopIteration (not a bare except) keeps real errors visible.
                loader_iter = iter(loader)
                data, target = next(loader_iter)

            # Image
            data = data.to(device)

            # Propagate forward
            outputs = model(data)

            # Loss
            loss = 0
            for output in outputs:
                # Resize target for {100%, 75%, 50%, Max} outputs
                target_ = resize_target(target, output.size(2))
                target_ = target_.to(device)
                # Compute crossentropy loss
                loss += criterion(output, target_)

            # Backpropagate (just compute gradients wrt the loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        loss_meter.add(iter_loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # TensorBoard
        if iteration % CONFIG.ITER_TB == 0:
            writer.add_scalar("train_loss", loss_meter.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("train_lr_group{}".format(i), o["lr"], iteration)
            if False:  # This produces a large log file
                for name, param in model.named_parameters():
                    name = name.replace(".", "/")
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(
                            name + "/grad", param.grad, iteration, bins="auto"
                        )

        # Save a model
        if iteration % CONFIG.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR, "checkpoint_{}.pth".format(iteration)),
            )

        # Save a model (short term)
        if iteration % 100 == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR, "checkpoint_current.pth"),
            )

    torch.save(
        model.module.state_dict(), osp.join(CONFIG.SAVE_DIR, "checkpoint_final.pth")
    )
Exemplo n.º 6
0
def train(config, cuda):
    """
    Train DeepLab v2 from COCO-pretrained weights (10k or 164k split).

    Args:
        config: path to a YAML configuration file.
        cuda: request GPU execution.
    """
    # Auto-tune cuDNN
    torch.backends.cudnn.benchmark = True

    # Configuration: close the file handle deterministically (the original
    # leaked the open() result); explicit Loader avoids the PyYAML >= 5.1
    # deprecation and unsafe default.
    device = get_device(cuda)
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.SafeLoader))

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        base_size=CONFIG.IMAGE.SIZE.TRAIN.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN.CROP,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.DATASET.WARP_IMAGE,
        scale=CONFIG.DATASET.SCALES,
        flip=True,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model
    model = setup_model(CONFIG.MODEL.INIT_MODEL,
                        CONFIG.DATASET.N_CLASSES,
                        train=True)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # TensorBoard logger
    writer = SummaryWriter(CONFIG.SOLVER.LOG_DIR)
    average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS)

    # Freeze the batch norm pre-trained on COCO
    model.train()
    model.module.base.freeze_bn()

    for iteration in tqdm(
            range(1, CONFIG.SOLVER.ITER_MAX + 1),
            total=CONFIG.SOLVER.ITER_MAX,
            leave=False,
            dynamic_ncols=True,
    ):

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                images, labels = next(loader_iter)
            except StopIteration:
                # Epoch exhausted: restart the loader. Catching only
                # StopIteration (not a bare except) keeps real errors visible.
                loader_iter = iter(loader)
                images, labels = next(loader_iter)

            images = images.to(device)
            labels = labels.to(device)

            # Propagate forward
            logits = model(images)

            # Loss
            iter_loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                _, _, H, W = logit.shape
                labels_ = resize_labels(labels, shape=(H, W))
                iter_loss += criterion(logit, labels_)

            # Backpropagate (just compute gradients wrt the loss)
            iter_loss /= CONFIG.SOLVER.ITER_SIZE
            iter_loss.backward()

            loss += float(iter_loss)

        average_loss.add(loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # Update learning rate
        scheduler.step(epoch=iteration)

        # TensorBoard
        if iteration % CONFIG.SOLVER.ITER_TB == 0:
            writer.add_scalar("loss/train", average_loss.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("lr/group{}".format(i), o["lr"], iteration)
            if False:  # This produces a large log file
                for name, param in model.named_parameters():
                    name = name.replace(".", "/")
                    # Weight/gradient distribution
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(name + "/grad",
                                             param.grad,
                                             iteration,
                                             bins="auto")

        # Save a model
        if iteration % CONFIG.SOLVER.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.MODEL.SAVE_DIR,
                         "checkpoint_{}.pth".format(iteration)),
            )

        # To verify progress separately
        # NOTE(review): this writes a checkpoint on *every* iteration,
        # which is heavy disk I/O — confirm it should not be throttled
        # (e.g. every 100 iterations as the sibling script does).
        torch.save(
            model.module.state_dict(),
            osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_current.pth"),
        )

    torch.save(
        model.module.state_dict(),
        osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_final.pth"),
    )
Exemplo n.º 7
0
def test(config, model_path, cuda, crf):
    """
    Evaluate a checkpoint on the validation split, optionally with CRF.

    Args:
        config: path to a YAML configuration file.
        model_path: checkpoint (.pth) to evaluate; scores are written next
            to it as a .json file.
        cuda: request GPU execution.
        crf: apply DenseCRF post-processing to the softmax probabilities.
    """
    # Disable autograd globally
    torch.set_grad_enabled(False)

    # Setup: close the config handle deterministically (the original
    # leaked the open() result); explicit Loader avoids the PyYAML >= 5.1
    # deprecation and unsafe default.
    device = get_device(cuda)
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.SafeLoader))

    # If the image size never change,
    if CONFIG.DATASET.WARP_IMAGE:
        # Auto-tune cuDNN
        torch.backends.cudnn.benchmark = True

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        crop_size=None,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.DATASET.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    # Model
    model = setup_model(model_path, CONFIG.DATASET.N_CLASSES, train=False)
    model.to(device)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    preds, gts = [], []
    for images, labels in tqdm(loader,
                               total=len(loader),
                               leave=False,
                               dynamic_ncols=True):
        # Image
        images = images.to(device)
        _, H, W = labels.shape

        # Forward propagation; upsample logits to label resolution
        logits = model(images)
        logits = F.interpolate(logits,
                               size=(H, W),
                               mode="bilinear",
                               align_corners=True)
        probs = F.softmax(logits, dim=1)
        probs = probs.data.cpu().numpy()

        # Postprocessing
        if crf:
            # images: (B,C,H,W) -> (B,H,W,C)
            images = images.data.cpu().numpy().astype(np.uint8).transpose(
                0, 2, 3, 1)
            probs = joblib.Parallel(n_jobs=-1)([
                joblib.delayed(postprocessor)(*pair)
                for pair in zip(images, probs)
            ])

        labelmaps = np.argmax(probs, axis=1)

        preds += list(labelmaps)
        gts += list(labels.numpy())

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(model_path.replace(".pth", ".json"), "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 8
0
def main():
    """
    Train DeepLab by the v2 protocol with AMP and a pseudo-label loss,
    validating periodically and keeping the best checkpoint.
    """
    opts = get_argparser().parse_args()
    print(opts)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Configuration — explicit Loader: yaml.load() without one is
    # deprecated since PyYAML 5.1 and unsafe on untrusted input.
    with open(opts.config_path) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.SafeLoader))

    device = get_device(opts.cuda)
    torch.backends.cudnn.benchmark = True

    # Dataset
    train_dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=True,
        base_size=CONFIG.IMAGE.SIZE.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN,
        scales=CONFIG.DATASET.SCALES,
        flip=True,
        gt_path=opts.gt_path,
    )
    print(train_dataset)
    print()

    valid_dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=False,
        gt_path="SegmentationClassAug",
    )
    print(valid_dataset)

    # DataLoader
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
        pin_memory=True,
    )

    # Model check
    print("Model:", CONFIG.MODEL.NAME)
    assert (CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC"
            ), 'Currently support only "DeepLabV2_ResNet101_MSC"'

    # Model setup
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL, map_location='cpu')

    for m in model.base.state_dict().keys():
        if m not in state_dict.keys():
            print("    Skip init:", m)

    model.base.load_state_dict(state_dict, strict=False)  # to skip ASPP
    model = nn.DataParallel(model)
    model.to(device)

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Path to save models
    checkpoint_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "models",
        opts.log_dir,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.TRAIN,
    )
    makedirs(checkpoint_dir)
    print("Checkpoint dst:", checkpoint_dir)

    # Switch to training mode while keeping pretrained BN frozen.
    def set_train(model):
        model.train()
        model.module.base.freeze_bn()

    metrics = StreamSegMetrics(CONFIG.DATASET.N_CLASSES)

    scaler = torch.cuda.amp.GradScaler(enabled=opts.amp)
    avg_loss = AverageMeter()
    avg_time = AverageMeter()

    set_train(model)
    best_score = 0
    end_time = time.time()

    # Initialize the loader iterator explicitly. The original left it
    # undefined and relied on the resulting NameError being swallowed by
    # a bare ``except:`` on the first iteration — which also hid any
    # genuine data-loading error.
    train_loader_iter = iter(train_loader)

    for iteration in range(1, CONFIG.SOLVER.ITER_MAX + 1):
        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                _, images, labels, cls_labels = next(train_loader_iter)
            except StopIteration:
                # Epoch boundary: restart the loader and reset the meters
                train_loader_iter = iter(train_loader)
                _, images, labels, cls_labels = next(train_loader_iter)
                avg_loss.reset()
                avg_time.reset()

            with torch.cuda.amp.autocast(enabled=opts.amp):
                # Propagate forward
                logits = model(images.to(device, non_blocking=True))

                # Loss
                iter_loss = 0
                for logit in logits:
                    # Resize labels for {100%, 75%, 50%, Max} logits
                    _, _, H, W = logit.shape
                    labels_ = resize_labels(labels, size=(H, W))

                    # Pseudo labels: argmax of the logits masked by the
                    # image-level class labels (gradients detached).
                    pseudo_labels = logit.detach(
                    ) * cls_labels[:, :, None, None].to(device)
                    pseudo_labels = pseudo_labels.argmax(dim=1)

                    _loss = criterion(logit, labels_.to(device, )) + criterion(
                        logit, pseudo_labels)

                    iter_loss += _loss

                # Propagate backward (just compute gradients wrt the loss)
                iter_loss /= CONFIG.SOLVER.ITER_SIZE

            scaler.scale(iter_loss).backward()
            loss += iter_loss.item()

        # Update weights with accumulated gradients
        scaler.step(optimizer)
        scaler.update()

        # Update learning rate
        scheduler.step(epoch=iteration)

        avg_loss.update(loss)
        avg_time.update(time.time() - end_time)
        end_time = time.time()

        # Progress report
        if iteration % 100 == 0:
            print(" Itrs %d/%d, Loss=%6f, Time=%.2f , LR=%.8f" %
                  (iteration, CONFIG.SOLVER.ITER_MAX, avg_loss.avg,
                   avg_time.avg * 1000, optimizer.param_groups[0]['lr']))

        # validation
        if iteration % opts.val_interval == 0:
            print("... validation")
            model.eval()
            metrics.reset()
            with torch.no_grad():
                for _, images, labels, _ in valid_loader:
                    images = images.to(device, non_blocking=True)

                    # Forward propagation
                    logits = model(images)

                    # Pixel-wise labeling at ground-truth resolution
                    _, H, W = labels.shape
                    logits = F.interpolate(logits,
                                           size=(H, W),
                                           mode="bilinear",
                                           align_corners=False)
                    preds = torch.argmax(logits, dim=1).cpu().numpy()
                    targets = labels.cpu().numpy()
                    metrics.update(targets, preds)

            set_train(model)
            score = metrics.get_results()
            print(metrics.to_str(score))

            if score['Mean IoU'] > best_score:  # save best model
                best_score = score['Mean IoU']
                torch.save(model.module.state_dict(),
                           os.path.join(checkpoint_dir, "checkpoint_best.pth"))
            end_time = time.time()
Exemplo n.º 9
0
def main(config, cuda):
    """
    Train DeepLab v2 (ResNet-101, MSC) from a plain-dict YAML config.

    NOTE(review): this block targets the pre-0.4 PyTorch API
    (``Variable`` wrappers, ``loss.data[0]`` indexing); it will raise on
    modern PyTorch — confirm the intended framework version before reuse.

    Args:
        config: path to a YAML configuration file.
        cuda: request GPU execution (falls back to CPU when unavailable).
    """
    # Configuration
    with open(config) as f:
        CONFIG = yaml.load(f)

    cuda = cuda and torch.cuda.is_available()

    # Dataset
    dataset = get_dataset(CONFIG['DATASET'])(
        root=CONFIG['ROOT'],
        split='train',
        image_size=(CONFIG['IMAGE']['SIZE']['TRAIN'],
                    CONFIG['IMAGE']['SIZE']['TRAIN']),
        scale=True,
        flip=True,
        # preload=True
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=CONFIG['BATCH_SIZE'],
                                         num_workers=CONFIG['NUM_WORKERS'],
                                         shuffle=True)
    loader_iter = iter(loader)

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG['N_CLASSES'])
    state_dict = torch.load(CONFIG['INIT_MODEL'])
    model.load_state_dict(state_dict, strict=False)  # Skip "aspp" layer
    if cuda:
        model.cuda()

    # Optimizer: two parameter groups (backbone vs. 10x-lr head),
    # selected by name from a one-entry dispatch dict.
    optimizer = {
        'sgd':
        torch.optim.SGD(
            params=[
                {
                    'params': get_1x_lr_params(model),
                    'lr': float(CONFIG['LR'])
                },
                {
                    'params': get_10x_lr_params(model),
                    'lr': 10 * float(CONFIG['LR'])
                }  # NOQA
            ],
            lr=float(CONFIG['LR']),
            momentum=float(CONFIG['MOMENTUM']),
            weight_decay=float(CONFIG['WEIGHT_DECAY'])),
    }.get(CONFIG['OPTIMIZER'])

    # Loss definition
    criterion = CrossEntropyLoss2d(ignore_index=CONFIG['IGNORE_LABEL'])
    if cuda:
        criterion.cuda()

    # TensorBoard Logger
    writer = SummaryWriter(CONFIG['LOG_DIR'])
    loss_meter = MovingAverageValueMeter(20)

    model.train()
    for iteration in tqdm(range(1, CONFIG['ITER_MAX'] + 1),
                          total=CONFIG['ITER_MAX'],
                          leave=False,
                          dynamic_ncols=True):

        # Polynomial lr decay
        poly_lr_scheduler(optimizer=optimizer,
                          init_lr=float(CONFIG['LR']),
                          iter=iteration - 1,
                          lr_decay_iter=CONFIG['LR_DECAY'],
                          max_iter=CONFIG['ITER_MAX'],
                          power=CONFIG['POLY_POWER'])

        optimizer.zero_grad()

        # Accumulate gradients over ITER_SIZE mini-batches before stepping
        iter_loss = 0
        for i in range(1, CONFIG['ITER_SIZE'] + 1):
            data, target = next(loader_iter)

            # Image
            data = data.cuda() if cuda else data
            data = Variable(data)

            # Forward propagation
            outputs = model(data)

            # Label: resized to match the first (full-scale) output
            target = resize_target(target, outputs[0].size(2))
            target = target.cuda() if cuda else target
            target = Variable(target)

            # Aggregate losses for [100%, 75%, 50%, Max]
            loss = 0
            for output in outputs:
                loss += criterion(output, target)

            loss /= CONFIG['ITER_SIZE']
            # NOTE(review): loss.data[0] is the pre-0.4 scalar accessor
            # (loss.item() in modern PyTorch).
            iter_loss += loss.data[0]
            loss.backward()

            # Reload dataloader
            if ((iteration - 1) * CONFIG['ITER_SIZE'] + i) % len(loader) == 0:
                loader_iter = iter(loader)

        loss_meter.add(iter_loss)

        # Back propagation
        optimizer.step()

        # TensorBoard
        if iteration % CONFIG['ITER_TF'] == 0:
            writer.add_scalar('train_loss', loss_meter.value()[0], iteration)

        # Save a model
        if iteration % CONFIG['ITER_SNAP'] == 0:
            torch.save(
                model.state_dict(),
                osp.join(CONFIG['SAVE_DIR'],
                         'checkpoint_{}.pth.tar'.format(iteration)))  # NOQA
            writer.add_text('log', 'Saved a model', iteration)

    torch.save(model.state_dict(),
               osp.join(CONFIG['SAVE_DIR'], 'checkpoint_final.pth.tar'))
Exemplo n.º 10
0
def test(config_path, model_path, cuda):
    """
    Evaluation on validation set.

    Loads the checkpoint at ``model_path``, runs inference over the
    validation split, and saves per-image logits as ``.npy`` files
    (for later CRF post-processing) under the experiment's output dir.

    Args:
        config_path: Path to an OmegaConf YAML configuration file.
        model_path: Path to a saved ``state_dict`` checkpoint.
        cuda: Whether to run on GPU (resolved by ``get_device``).
    """

    # Configuration
    CONFIG = OmegaConf.load(config_path)
    device = get_device(cuda)
    torch.set_grad_enabled(False)  # inference only; skip autograd bookkeeping

    # Dataset (validation split, no augmentation)
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # DataLoader (deterministic order so logits line up with image ids)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    # Model
    # HACK: eval() instantiates the class named in the config -- only safe
    # because the config file is trusted; never feed it untrusted input.
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    # map_location keeps the load on CPU regardless of where it was saved.
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.eval()
    model.to(device)

    # Path to save logits
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    makedirs(logit_dir)
    print("Logit dst:", logit_dir)

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores.json")
    print("Score dst:", save_path)

    # NOTE(review): preds/gts are never filled below -- looks vestigial
    # (possibly a truncated scoring step); confirm before removing.
    preds, gts = [], []
    for image_ids, images, gt_labels in tqdm(
        loader, total=len(loader), dynamic_ncols=True
    ):
        # Image
        images = images.to(device)

        # Forward propagation
        logits = model(images)

        # Save on disk for CRF post-processing
        for image_id, logit in zip(image_ids, logits):
            filename = os.path.join(logit_dir, image_id + ".npy")
            np.save(filename, logit.cpu().numpy())
Exemplo n.º 11
0
def main(config, excludeval, embedding, model_path, run, cuda, crf, redo,
         imagedataset, threshold):
    """
    Evaluate a (generalized) zero-/few-shot segmentation checkpoint.

    Depending on ``run`` ('zlss'/'flss' vs 'gzlss'/'gflss') evaluation is
    restricted to novel classes or spans seen+novel classes.  Optionally
    subtracts a per-class score ``threshold`` from seen classes and/or
    applies dense-CRF post-processing.  Scores are written as JSON next
    to ``model_path``.
    """
    pth_extn = '.pth.tar'
    # Short-circuit: if the score file already exists, just print it
    # (unless a threshold sweep or --redo was requested).
    if osp.isfile(model_path.replace(
            pth_extn, "_" + run + ".json")) and not threshold and not redo:
        print("Already Done!")
        with open(model_path.replace(pth_extn,
                                     "_" + run + ".json")) as json_file:
            data = json.load(json_file)
            for key, value in data.items():
                if not key == "Class IoU":
                    print(key, value)
        sys.exit()

    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration
    # FIX: close the config handle via a context manager and pass an
    # explicit Loader (yaml.load() without one is unsafe on old PyYAML
    # and a TypeError on PyYAML >= 6).
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.FullLoader))

    datadir = os.path.join('data/datasets', imagedataset)
    print("Split dir: ", datadir)
    savedir = osp.dirname(model_path)
    # FIX: raw string -- "\." in a plain string literal is an invalid
    # escape sequence (DeprecationWarning / SyntaxWarning).
    epoch = re.findall(r"checkpoint_(.*)\." + pth_extn[1:],
                       osp.basename(model_path))[-1]
    val = None
    visible_classes = None

    if run == 'zlss' or run == 'flss':
        val = np.load(datadir + '/split/test_list.npy')
        visible_classes = np.load(datadir + '/split/novel_cls.npy')
    elif run == 'gzlss' or run == 'gflss':
        val = np.load(datadir + '/split/test_list.npy')
        if excludeval:
            vals_cls = np.asarray(np.load(datadir + '/split/seen_cls.npy'),
                                  dtype=int)
        else:
            vals_cls = np.asarray(np.concatenate([
                np.load(datadir + '/split/seen_cls.npy'),
                np.load(datadir + '/split/val_cls.npy')
            ]),
                                  dtype=int)
        valu_cls = np.load(datadir + '/split/novel_cls.npy')
        visible_classes = np.concatenate([vals_cls, valu_cls])
    else:
        print("invalid run ", run)
        sys.exit()

    # A seen-class score threshold only makes sense for generalized ZSL.
    if threshold is not None and run != 'gzlss':
        print("invalid run for threshold", run)
        sys.exit()

    # Map raw label ids (0..255) onto contiguous indices; 255 = ignore.
    cls_map = np.array([255] * 256)
    for i, n in enumerate(visible_classes):
        cls_map[n] = i

    if threshold is not None:
        savedir = osp.join(savedir, str(threshold))

    # FIX: test truthiness, consistent with the `if crf:` check in the
    # loop below -- `crf is not None` tagged the dir '-crf' even when
    # crf was False.
    if crf:
        savedir = savedir + '-crf'

    if run == 'gzlss' or run == 'gflss':

        # NOTE(review): novel_cls_map / seen_cls_map are unused below --
        # kept for parity with the original; confirm before removing.
        novel_cls_map = np.array([255] * 256)
        for i, n in enumerate(list(valu_cls)):
            novel_cls_map[cls_map[n]] = i

        seen_cls_map = np.array([255] * 256)
        for i, n in enumerate(list(vals_cls)):
            seen_cls_map[cls_map[n]] = i

        if threshold is not None:
            # Per-class bias subtracted from seen-class scores only.
            # FIX: np.float was removed in NumPy 1.20+; use the builtin.
            thresholdv = np.asarray(np.zeros((visible_classes.shape[0], 1)),
                                    dtype=float)
            thresholdv[np.in1d(visible_classes, vals_cls), 0] = threshold
            thresholdv = torch.tensor(thresholdv).float().cuda()

    visible_classesp = np.concatenate([visible_classes, [255]])

    all_labels = np.genfromtxt(datadir + '/labels_2.txt',
                               delimiter='\t',
                               usecols=1,
                               dtype='str')

    print("Visible Classes: ", visible_classes)

    # Dataset (test split; no scaling/flipping at evaluation time)
    dataset = get_dataset(CONFIG.DATASET)(
        train=None,
        test=val,
        root=CONFIG.ROOT,
        split=CONFIG.SPLIT.TEST,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    # Class embeddings (word2vec / fasttext / their concatenation)
    if embedding == 'word2vec':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/word2vec.pkl', "rb"))
    elif embedding == 'fasttext':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/fasttext.pkl', "rb"))
    elif embedding == 'fastnvec':
        class_emb = np.concatenate([
            pickle.load(open(datadir + '/word_vectors/fasttext.pkl', "rb")),
            pickle.load(open(datadir + '/word_vectors/word2vec.pkl', "rb"))
        ],
                                   axis=1)
    else:
        print("invalid emb ", embedding)
        sys.exit()

    # Keep only visible classes; L2-normalize each embedding row.
    class_emb = class_emb[visible_classes]
    class_emb = F.normalize(torch.tensor(class_emb), p=2, dim=1).cuda()

    print("Embedding dim: ", class_emb.shape[1])
    print("# Visible Classes: ", class_emb.shape[0])

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TEST,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=False,
    )

    torch.set_grad_enabled(False)  # inference only

    # Model
    model = DeepLabV2_ResNet101_MSC(class_emb.shape[1], class_emb)

    # NOTE(review): sdir is never used below -- kept for parity with the
    # original; presumably an output dir for per-image results.
    sdir = osp.join(savedir, model_path.replace(pth_extn, ""), str(epoch), run)

    state_dict = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model = nn.DataParallel(model)
    model.load_state_dict(state_dict['state_dict'])
    model.eval()
    model.to(device)
    targets, outputs = [], []
    for data, target, img_id in tqdm(loader,
                                     total=len(loader),
                                     leave=False,
                                     dynamic_ncols=True):
        # Image
        data = data.to(device)
        # Forward propagation
        output = model(data)
        output = F.interpolate(output,
                               size=data.shape[2:],
                               mode="bilinear",
                               align_corners=False)

        output = F.softmax(output, dim=1)
        if threshold is not None:
            # Penalize seen classes to rebalance seen/unseen predictions.
            output = output - thresholdv.view(1, -1, 1, 1)

        # Remap ground-truth ids to contiguous indices.
        target = cls_map[target.numpy()]

        # Postprocessing
        if crf:
            output = output.data.cpu().numpy()
            crf_output = np.zeros(output.shape)
            images = data.data.cpu().numpy().astype(np.uint8)
            for i, (image, prob_map) in enumerate(zip(images, output)):
                image = image.transpose(1, 2, 0)
                crf_output[i] = dense_crf(image, prob_map)
            output = crf_output
            output = np.argmax(output, axis=1)
        else:
            output = torch.argmax(output, dim=1).cpu().numpy()

        for o, t in zip(output, target):
            outputs.append(o)
            targets.append(t)

    if run == 'gzlss' or run == 'gflss':
        score, class_iou = scores_gzsl(targets,
                                       outputs,
                                       n_class=len(visible_classes),
                                       seen_cls=cls_map[vals_cls],
                                       unseen_cls=cls_map[valu_cls])
    else:
        score, class_iou = scores(targets,
                                  outputs,
                                  n_class=len(visible_classes))

    for k, v in score.items():
        print(k, v)

    # Attach a human-readable per-class IoU table before dumping.
    score["Class IoU"] = {}
    for i in range(len(visible_classes)):
        score["Class IoU"][all_labels[visible_classes[i]]] = class_iou[i]

    if threshold is not None:
        with open(
                model_path.replace(pth_extn, "_" + run + '_T' +
                                   str(threshold) + ".json"), "w") as f:
            json.dump(score, f, indent=4, sort_keys=True)
    else:
        with open(model_path.replace(pth_extn, "_" + run + ".json"), "w") as f:
            json.dump(score, f, indent=4, sort_keys=True)

    print(score["Class IoU"])
Exemplo n.º 12
0
def main(config, model_path, cuda):
    """
    Evaluate a DeepLab v2 checkpoint on the test split with dense-CRF
    post-processing and dump mean/class scores to ``results.json``.
    """
    # Configuration
    # FIX: pass an explicit Loader (required by PyYAML >= 6).
    with open(config) as f:
        CONFIG = yaml.load(f, Loader=yaml.FullLoader)

    cuda = cuda and torch.cuda.is_available()

    image_size = (CONFIG['IMAGE']['SIZE']['TEST'],
                  CONFIG['IMAGE']['SIZE']['TEST'])
    n_classes = CONFIG['N_CLASSES']

    # Dataset (test split, no augmentation)
    dataset = get_dataset(CONFIG['DATASET'])(root=CONFIG['ROOT'],
                                             split='test',
                                             image_size=image_size,
                                             scale=False,
                                             flip=False,
                                             preload=False)

    # DataLoader
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=CONFIG['BATCH_SIZE'],
                                         num_workers=CONFIG['NUM_WORKERS'],
                                         shuffle=False)

    # map_location keeps the load on CPU regardless of where it was saved.
    state_dict = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=n_classes)
    model.load_state_dict(state_dict)
    model.eval()
    if cuda:
        model.cuda()

    # FIX: the removed `Variable(..., volatile=True)` idiom is replaced by
    # globally disabling autograd for this evaluation-only script.
    torch.set_grad_enabled(False)

    targets, outputs = [], []
    for data, target in tqdm(loader,
                             total=len(loader),
                             leave=False,
                             dynamic_ncols=True):
        # Image
        data = data.cuda() if cuda else data

        # Forward propagation
        output = model(data)
        # FIX: F.upsample is deprecated; F.interpolate with
        # align_corners=True reproduces the legacy bilinear behaviour.
        output = F.interpolate(output, size=image_size, mode='bilinear',
                               align_corners=True)
        output = F.softmax(output, dim=1)
        output = output.data.cpu().numpy()

        # Dense CRF refinement for every image in the batch.
        crf_output = np.zeros(output.shape)
        images = data.data.cpu().numpy().astype(np.uint8)
        for i, (image, prob_map) in enumerate(zip(images, output)):
            image = image.transpose(1, 2, 0)
            crf_output[i] = dense_crf(image, prob_map)
        output = crf_output

        output = np.argmax(output, axis=1)
        target = target.numpy()

        for o, t in zip(output, target):
            outputs.append(o)
            targets.append(t)

    score, class_iou = scores(targets, outputs, n_class=n_classes)

    for k, v in score.items():
        print(k, v)  # FIX: Python-2 print statement -> print() call

    score['Class IoU'] = {}
    for i in range(n_classes):
        score['Class IoU'][i] = class_iou[i]

    with open('results.json', 'w') as f:
        json.dump(score, f)
Exemplo n.º 13
0
def main(config, model_path, cuda, crf):
    """
    Evaluate a DeepLab v2 checkpoint on the validation split, optionally
    refining the softmax maps with a dense CRF, and write the scores as
    JSON next to the checkpoint.
    """
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration
    # FIX: close the config file handle and pass an explicit Loader
    # (required by PyYAML >= 6).
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.FullLoader))

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET)(
        root=CONFIG.ROOT,
        split=CONFIG.SPLIT.VAL,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TEST,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=False,
    )

    torch.set_grad_enabled(False)  # inference only

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.N_CLASSES)
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.eval()
    model.to(device)

    preds, gts = [], []
    for images, labels in tqdm(
        loader, total=len(loader), leave=False, dynamic_ncols=True
    ):
        # Image
        images = images.to(device)

        # Forward propagation
        logits = model(images)
        logits = F.interpolate(
            logits, size=images.shape[2:], mode="bilinear", align_corners=True
        )
        probs = F.softmax(logits, dim=1)
        probs = probs.data.cpu().numpy()

        # Postprocessing
        if crf:
            # FIX: use the pool as a context manager so worker processes
            # are always reaped (the original close()d but never joined).
            # NOTE(review): creating a pool per batch is wasteful; kept to
            # preserve behaviour -- consider hoisting it out of the loop.
            with mp.Pool(mp.cpu_count()) as pool:
                images = images.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1)
                probs = pool.map(dense_crf_wrapper, zip(images, probs))

        preds += list(np.argmax(probs, axis=1))
        gts += list(labels.numpy())

    score = scores(gts, preds, n_class=CONFIG.N_CLASSES)

    with open(model_path.replace(".pth", ".json"), "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 14
0
def main(config, cuda, excludeval, embedding, continue_from, nolog, inputmix,
         imagedataset, experimentid, nshot, ishot):
    """
    Train a zero-/few-shot DeepLab v2 segmentation model.

    Builds seen/novel class splits, samples training images according to
    ``inputmix`` ('seen', 'novel' or 'both'), remaps labels through a
    visibility mask, and optimizes with SGD/Adam under a polynomial LR
    schedule.  Checkpoints and TensorBoard logs go under
    ``logs/<imagedataset>/<savedir>`` unless ``nolog`` is set.
    """
    # Capture the call arguments so they can be dumped to args.json.
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)

    # In case you want to save to the location of the script you're running.
    datadir = os.path.join(
        '/home/SharedData/omkar/zscoseg/yash_manas/data/datasets',
        imagedataset)
    if not nolog:
        # Name the savedir; might add logs/ before the datetime for clarity.
        if experimentid is None:
            savedir = time.strftime('%Y%m%d%H%M%S')
        else:
            savedir = experimentid
        # The full savepath is then:
        savepath = os.path.join('logs', imagedataset, savedir)
        # FIX: only swallow "already exists" -- the original bare
        # `except: pass` also hid permission errors and KeyboardInterrupt.
        try:
            os.makedirs(savepath)
            print("Log dir:", savepath)
        except FileExistsError:
            pass
        if continue_from is None:
            # Snapshot the code and config next to the logs.
            shutil.copytree('./libs/', savepath + '/libs')
            shutil.copy2(osp.abspath(inspect.stack()[0][1]), savepath)
            shutil.copy2(config, savepath)
            args_dict = {}
            for a in args:
                args_dict[a] = values[a]
            with open(savepath + '/args.json', 'w') as fp:
                json.dump(args_dict, fp)

    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration
    # FIX: close the config file handle via a context manager.
    with open(config) as f:
        CONFIG = Dict(yaml.load(f, Loader=yaml.FullLoader))
    visibility_mask = {}
    if excludeval:
        seen_classes = np.load(datadir + '/split/seen_cls.npy')
    else:
        seen_classes = np.asarray(np.concatenate([
            np.load(datadir + '/split/seen_cls.npy'),
            np.load(datadir + '/split/val_cls.npy')
        ]),
                                  dtype=int)

    novel_classes = np.load(datadir + '/split/novel_cls.npy')
    seen_novel_classes = np.concatenate([seen_classes, novel_classes])

    # Map raw label ids to contiguous seen-class indices (-1 = ignored).
    seen_map = np.array([-1] * 256)
    for i, n in enumerate(list(seen_classes)):
        seen_map[n] = i

    # visibility_mask[0] exposes only seen classes; mask i+1 additionally
    # reveals the i-th novel class.
    visibility_mask[0] = seen_map.copy()
    for i, n in enumerate(list(novel_classes)):
        visibility_mask[i + 1] = seen_map.copy()
        visibility_mask[i + 1][n] = seen_classes.shape[0] + i
    if excludeval:
        train = np.load(datadir + '/split/train_list.npy')[:-CONFIG.VAL_SIZE]
    else:
        train = np.load(datadir + '/split/train_list.npy')

    novelset = []
    seenset = []

    # Few-shot sampling: take up to `nshot` images per class starting at
    # offset ishot * 20 in the per-class inverse index.
    if inputmix == 'novel' or inputmix == 'both':
        inverse_dict = pickle.load(
            open(datadir + '/split/inverse_dict_train.pkl', 'rb'))
        for icls, key in enumerate(novel_classes):
            if (inverse_dict[key].size > 0):
                for v in inverse_dict[key][ishot * 20:ishot * 20 + nshot]:
                    novelset.append((v, icls))

    if inputmix == 'both':
        seenset = []
        inverse_dict = pickle.load(
            open(datadir + '/split/inverse_dict_train.pkl', 'rb'))
        for icls, key in enumerate(seen_classes):
            if (inverse_dict[key].size > 0):
                for v in inverse_dict[key][ishot * 20:ishot * 20 + nshot]:
                    seenset.append(v)

    if inputmix == 'seen':
        seenset = range(train.shape[0])

    sampler = RandomImageSampler(seenset, novelset)

    # Tag nshot with the input mix for checkpoint naming.
    if inputmix == 'novel':
        visible_classes = seen_novel_classes
        if nshot is not None:
            nshot = str(nshot) + 'n'
    elif inputmix == 'seen':
        visible_classes = seen_classes
        if nshot is not None:
            nshot = str(nshot) + 's'
    elif inputmix == 'both':
        visible_classes = seen_novel_classes
        if nshot is not None:
            nshot = str(nshot) + 'b'

    print("Visible classes:", visible_classes.size, " \nClasses are: ",
          visible_classes, "\nTrain Images:", train.shape[0])

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET)(train=train,
                                          test=None,
                                          root=CONFIG.ROOT,
                                          split=CONFIG.SPLIT.TRAIN,
                                          base_size=513,
                                          crop_size=CONFIG.IMAGE.SIZE.TRAIN,
                                          mean=(CONFIG.IMAGE.MEAN.B,
                                                CONFIG.IMAGE.MEAN.G,
                                                CONFIG.IMAGE.MEAN.R),
                                          warp=CONFIG.WARP_IMAGE,
                                          scale=(0.5, 1.5),
                                          flip=True,
                                          visibility_mask=visibility_mask)

    # DataLoader
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=CONFIG.BATCH_SIZE.TRAIN,
                                         num_workers=CONFIG.NUM_WORKERS,
                                         sampler=sampler)

    # Class embeddings (word2vec / fasttext / their concatenation)
    if embedding == 'word2vec':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/word2vec.pkl', "rb"))
    elif embedding == 'fasttext':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/fasttext.pkl', "rb"))
    elif embedding == 'fastnvec':
        class_emb = np.concatenate([
            pickle.load(open(datadir + '/word_vectors/fasttext.pkl', "rb")),
            pickle.load(open(datadir + '/word_vectors/word2vec.pkl', "rb"))
        ],
                                   axis=1)
    else:
        print("invalid emb ", embedding)
        sys.exit()

    print((class_emb.shape))
    class_emb = F.normalize(torch.tensor(class_emb), p=2, dim=1).cuda()

    loader_iter = iter(loader)
    DeepLab = DeepLabV2_ResNet101_MSC
    state_dict = torch.load(CONFIG.INIT_MODEL)

    # Model load
    model = DeepLab(class_emb.shape[1], class_emb[visible_classes])
    if continue_from is not None and continue_from > 0:
        print("Loading checkpoint: {}".format(continue_from))
        model = nn.DataParallel(model)
        state_file = osp.join(savepath,
                              "checkpoint_{}.pth".format(continue_from))
        if osp.isfile(state_file + '.tar'):
            # Full checkpoint: a dict wrapping the state_dict.
            state_dict = torch.load(state_file + '.tar')
            model.load_state_dict(state_dict['state_dict'], strict=True)
        elif osp.isfile(state_file):
            state_dict = torch.load(state_file)
            model.load_state_dict(state_dict, strict=True)
        else:
            print("Checkpoint {} not found".format(continue_from))
            sys.exit()

    else:
        model.load_state_dict(
            state_dict, strict=False
        )  # make strict=True to debug if checkpoint is loaded correctly or not if performance is low
        model = nn.DataParallel(model)
    model.to(device)

    # Optimizer -- per-group LR multipliers mirror lr_mult/decay_mult in
    # the original Caffe train.prototxt.
    optimizer = {
        "sgd":
        torch.optim.SGD(
            # cf lr_mult and decay_mult in train.prototxt
            params=[{
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.LR,
                "weight_decay": CONFIG.WEIGHT_DECAY,
            }, {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.LR,
                "weight_decay": CONFIG.WEIGHT_DECAY,
            }, {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.LR,
                "weight_decay": 0.0,
            }],
            momentum=CONFIG.MOMENTUM,
        ),
        "adam":
        torch.optim.Adam(
            # cf lr_mult and decay_mult in train.prototxt
            params=[{
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.LR,
                "weight_decay": CONFIG.WEIGHT_DECAY,
            }, {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.LR,
                "weight_decay": CONFIG.WEIGHT_DECAY,
            }, {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.LR,
                "weight_decay": 0.0,
            }])
        # Add any other optimizer
    }.get(CONFIG.OPTIMIZER)

    # Restore optimizer state when resuming from a full checkpoint.
    if 'optimizer' in state_dict:
        optimizer.load_state_dict(state_dict['optimizer'])
    print("Learning rate:", CONFIG.LR)
    # Loss definition; -1 marks pixels hidden by the visibility mask.
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    criterion.to(device)

    if not nolog:
        # TensorBoard Logger
        if continue_from is not None:
            writer = SummaryWriter(
                savepath +
                '/runs/fs_{}_{}_{}'.format(continue_from, nshot, ishot))
        else:
            writer = SummaryWriter(savepath + '/runs')
        loss_meter = MovingAverageValueMeter(20)

    model.train()
    # BatchNorm statistics stay frozen while fine-tuning.
    model.module.scale.freeze_bn()

    pbar = tqdm(
        range(1, CONFIG.ITER_MAX + 1),
        total=CONFIG.ITER_MAX,
        leave=False,
        dynamic_ncols=True,
    )
    for iteration in pbar:

        # Set a learning rate
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            # FIX: catch only StopIteration -- the original bare `except:`
            # also masked real data-loading errors and KeyboardInterrupt.
            try:
                data, target = next(loader_iter)
            except StopIteration:
                loader_iter = iter(loader)
                data, target = next(loader_iter)

            # Image
            data = data.to(device)

            # Propagate forward
            outputs = model(data)
            # Loss
            loss = 0
            for output in outputs:
                # Resize target for {100%, 75%, 50%, Max} outputs
                target_ = resize_target(target, output.size(2))
                target_ = torch.tensor(target_).to(device)
                # FIX: call the module, not .forward(), so hooks run.
                loss += criterion(output, target_)

            # Backpropagate (just compute gradients wrt the loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)
            del data, target, outputs

        pbar.set_postfix(loss="%.3f" % iter_loss)

        # Update weights with accumulated gradients
        optimizer.step()
        if not nolog:
            loss_meter.add(iter_loss)
            # TensorBoard
            if iteration % CONFIG.ITER_TB == 0:
                writer.add_scalar("train_loss",
                                  loss_meter.value()[0], iteration)
                for i, o in enumerate(optimizer.param_groups):
                    writer.add_scalar("train_lr_group{}".format(i), o["lr"],
                                      iteration)
                if False:  # This produces a large log file
                    for name, param in model.named_parameters():
                        name = name.replace(".", "/")
                        writer.add_histogram(name,
                                             param,
                                             iteration,
                                             bins="auto")
                        if param.requires_grad:
                            writer.add_histogram(name + "/grad",
                                                 param.grad,
                                                 iteration,
                                                 bins="auto")

            # Save a model
            if continue_from is not None:
                if iteration in CONFIG.ITER_SAVE:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                        },
                        osp.join(
                            savepath, "checkpoint_{}_{}_{}_{}.pth.tar".format(
                                continue_from, nshot, ishot, iteration)),
                    )

                # Save a model (short term) [unnecessary for fewshot]
                if False and iteration % 100 == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                        },
                        osp.join(
                            savepath,
                            "checkpoint_{}_{}_{}_current.pth.tar".format(
                                continue_from, nshot, ishot)),
                    )
                    print(
                        osp.join(
                            savepath,
                            "checkpoint_{}_{}_{}_current.pth.tar".format(
                                continue_from, nshot, ishot)))
            else:
                if iteration % CONFIG.ITER_SAVE == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        },
                        osp.join(savepath,
                                 "checkpoint_{}.pth.tar".format(iteration)),
                    )

                # Save a model (short term)
                if iteration % 100 == 0:
                    torch.save(
                        {
                            'iteration': iteration,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        },
                        osp.join(savepath, "checkpoint_current.pth.tar"),
                    )

        torch.cuda.empty_cache()

    # Final checkpoint after ITER_MAX iterations.
    if not nolog:
        if continue_from is not None:
            torch.save(
                {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                },
                osp.join(
                    savepath, "checkpoint_{}_{}_{}_{}.pth.tar".format(
                        continue_from, nshot, ishot, CONFIG.ITER_MAX)))
        else:
            torch.save(
                {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                osp.join(savepath,
                         "checkpoint_{}.pth.tar".format(CONFIG.ITER_MAX)))
Exemplo n.º 15
0
def main():
    """Entry point: parse CLI args and YAML config, then run either a
    validation-only pass or the full training loop (with periodic
    snapshotting and evaluation on the test split).

    Supported schedules:
        step1    -- plain step-1 GAN training.
        mixed    -- alternate step1 / step2 (transfer) iterations.
        st       -- self-training.
        st_mixed -- self-training with alternating transfer iterations.
    """
    args = parse_args()
    assert os.path.exists(args.config)
    assert args.schedule in ['step1', 'mixed', 'st', 'st_mixed']
    # Single-GPU mode needs a valid device id; multi-GPU needs at least two.
    assert ((args.multigpus == False and args.ngpu >= 0)
            or (args.multigpus == True and args.ngpu > 1))
    # Validation-only runs cannot be resumed.
    assert not (args.val and args.resume_from > 0)
    config = get_config(args.config)
    # Validation needs an initial model from either the config or the CLI.
    assert not (args.val and config['init_model'] == 'none'
                and args.init_model == 'none')
    if args.init_model != 'none':
        assert os.path.exists(args.init_model)
        config['init_model'] = args.init_model

    # ------------------------------------------------------------------
    # Paths to save results.
    # ------------------------------------------------------------------
    dataset_path = os.path.join(config['save_path'], config['dataset'])
    os.makedirs(dataset_path, exist_ok=True)

    save_path = os.path.join(dataset_path, args.experimentid)
    if not args.val:
        os.makedirs(save_path, exist_ok=True)

    # One filename suffix per schedule keeps the model dir, log file and
    # config snapshot names consistent with each other.
    suffix = {
        'step1': '',
        'mixed': '_transfer',
        'st': '_st',
        'st_mixed': '_st_transfer',
    }[args.schedule]

    model_path = os.path.join(save_path, 'models' + suffix)
    if args.resume_from > 0:
        assert os.path.exists(model_path)
    if not args.val:
        os.makedirs(model_path, exist_ok=True)

    log_file = os.path.join(save_path, 'logs{}.txt'.format(suffix))
    if args.val:
        log_file = os.path.join(dataset_path, 'logs_test.txt')
    logger = logWritter(log_file)

    config_path = os.path.join(save_path, 'configs{}.yaml'.format(suffix))

    # ------------------------------------------------------------------
    # Start
    # ------------------------------------------------------------------
    if args.val:
        print("\n***Testing of model {0}***\n".format(config['init_model']))
        logger.write("\n***Testing of model {0}***\n".format(
            config['init_model']))
    else:
        print("\n***Training of model {0}***\n".format(args.experimentid))
        logger.write("\n***Training of model {0}***\n".format(
            args.experimentid))

    # ------------------------------------------------------------------
    # Continue training or train from scratch.
    # ------------------------------------------------------------------
    if args.resume_from >= 1:
        assert args.val == False
        if not os.path.exists(config_path):
            assert 0, "Old config not found."
        config_old = get_config(config_path)
        # A resumed run must target the same save location and dataset.
        if config['save_path'] != config_old['save_path'] or config[
                'dataset'] != config_old['dataset']:
            assert 0, "New config does not coordinate with old config."
        config = config_old
        start_iter = args.resume_from
        print("Continue training from Iter - [{0:0>6d}] ...".format(
            start_iter + 1))
        logger.write("Continue training from Iter - [{0:0>6d}] ...".format(
            start_iter + 1))
    else:
        start_iter = 0
        if not args.val:
            # Snapshot the config next to the models so the run can resume.
            shutil.copy(args.config, config_path)
            print("Train from scratch ...")
            logger.write("Train from scratch ...")

    # ------------------------------------------------------------------
    # Modify config: propagate the schedule-specific learning rate and
    # max-iteration settings into the backbone scheduler.
    # ------------------------------------------------------------------
    if args.schedule == 'step1':
        config['back_scheduler']['init_lr'] = config['back_opt']['lr']
        config['back_scheduler']['max_iter'] = config['ITER_MAX']
    elif args.schedule == 'mixed':
        config['back_scheduler']['init_lr_transfer'] = config['back_opt'][
            'lr_transfer']
        config['back_scheduler']['max_iter_transfer'] = config[
            'ITER_MAX_TRANSFER']
    elif args.schedule == 'st':
        config['back_scheduler']['init_lr_st'] = config['back_opt']['lr_st']
        config['back_scheduler']['max_iter_st'] = config['ITER_MAX_ST']
    else:
        config['back_scheduler']['init_lr_st_transfer'] = config['back_opt'][
            'lr_st_transfer']
        config['back_scheduler']['max_iter_st_transfer'] = config[
            'ITER_MAX_ST_TRANSFER']

    # ------------------------------------------------------------------
    # Schedule method
    # ------------------------------------------------------------------
    s = "Schedule method: {0}".format(args.schedule)
    if args.schedule == 'mixed' or args.schedule == 'st_mixed':
        s += ", interval_step1={0}, interval_step2={1}".format(
            config['interval_step1'], config['interval_step2'])
    s += '\n'
    print(s)
    logger.write(s)

    # ------------------------------------------------------------------
    # Use GPU
    # ------------------------------------------------------------------
    device = torch.device("cuda")
    if not args.multigpus:
        torch.cuda.set_device(args.ngpu)
    torch.backends.cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Get dataloaders
    # ------------------------------------------------------------------
    vals_cls, valu_cls, all_labels, visible_classes, visible_classes_test, train, val, sampler, visibility_mask, cls_map, cls_map_test = get_split(
        config)
    # Classifier output dim = number of test-visible classes + 1 (fake class).
    assert visible_classes_test.shape[0] == config['dis']['out_dim_cls'] - 1

    dataset = get_dataset(config['DATAMODE'])(
        train=train,
        test=None,
        root=config['ROOT'],
        split=config['SPLIT']['TRAIN'],
        base_size=513,
        crop_size=config['IMAGE']['SIZE']['TRAIN'],
        mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
              config['IMAGE']['MEAN']['R']),
        warp=config['WARP_IMAGE'],
        scale=(0.5, 1.5),
        flip=True,
        visibility_mask=visibility_mask)

    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['BATCH_SIZE']['TRAIN'],
        num_workers=config['NUM_WORKERS'],
        sampler=sampler)

    dataset_test = get_dataset(config['DATAMODE'])(
        train=None,
        test=val,
        root=config['ROOT'],
        split=config['SPLIT']['TEST'],
        base_size=513,
        crop_size=config['IMAGE']['SIZE']['TEST'],
        mean=(config['IMAGE']['MEAN']['B'], config['IMAGE']['MEAN']['G'],
              config['IMAGE']['MEAN']['R']),
        warp=config['WARP_IMAGE'],
        scale=None,
        flip=False)

    loader_test = torch.utils.data.DataLoader(
        dataset=dataset_test,
        batch_size=config['BATCH_SIZE']['TEST'],
        num_workers=config['NUM_WORKERS'],
        shuffle=False)

    # ------------------------------------------------------------------
    # Load class embeddings
    # ------------------------------------------------------------------
    class_emb = get_embedding(config)
    class_emb_vis = class_emb[visible_classes]
    # Pad the seen-class embeddings with zero rows up to ignore_index + 1
    # so every label id (including padding) indexes a valid embedding row.
    class_emb_vis_ = torch.zeros(
        (config['ignore_index'] + 1 - class_emb_vis.shape[0],
         class_emb_vis.shape[1]),
        dtype=torch.float32)
    class_emb_vis_aug = torch.cat((class_emb_vis, class_emb_vis_), dim=0)
    class_emb_all = class_emb[visible_classes_test]

    # ------------------------------------------------------------------
    # Get trainer
    # ------------------------------------------------------------------
    trainer = Trainer(
        cfg=config,
        class_emb_vis=class_emb_vis_aug,
        class_emb_all=class_emb_all,
        schedule=args.schedule,
        checkpoint_dir=model_path,  # for model loading in continued train
        resume_from=start_iter  # for model loading in continued train
    ).to(device)
    if args.multigpus:
        trainer.model = torch.nn.DataParallel(trainer.model,
                                              device_ids=range(args.ngpu))

    def run_validation():
        """Evaluate the trainer on the full test loader and report GZSL
        scores plus per-class IoU via both stdout and the logger."""
        loader_iter_test = iter(loader_test)
        targets, outputs = [], []

        while True:
            try:
                # gt_test: torch.LongTensor (N,H,W); elements 0-19,255 in voc12.
                data_test, gt_test, image_id = next(loader_iter_test)
            except StopIteration:
                break  # test set exhausted

            data_test = torch.Tensor(data_test).to(device)

            with torch.no_grad():
                try:
                    test_res = trainer.test(data_test,
                                            gt_test,
                                            multigpus=args.multigpus)
                except MeaninglessError:
                    continue  # skip meaningless batch

                # pred: LongTensor (N,H',W'), elements 0-20 in voc12.
                pred_cls_test = test_res['pred_cls_real'].cpu()
                # gt resized to prediction resolution; elements 0-19,255.
                resized_gt_test = test_res['resized_gt'].cpu()

                ##### gt mapping to target #####
                resized_target = cls_map_test[resized_gt_test]

            for o, t in zip(pred_cls_test.numpy(), resized_target):
                outputs.append(o)
                targets.append(t)

        score, class_iou = scores_gzsl(targets,
                                       outputs,
                                       n_class=len(visible_classes_test),
                                       seen_cls=cls_map_test[vals_cls],
                                       unseen_cls=cls_map_test[valu_cls])

        print("Test results:")
        logger.write("Test results:")

        for k, v in score.items():
            print(k + ': ' + json.dumps(v))
            logger.write(k + ': ' + json.dumps(v))

        score["Class IoU"] = {}
        for i in range(len(visible_classes_test)):
            score["Class IoU"][all_labels[
                visible_classes_test[i]]] = class_iou[i]
        print("Class IoU: " + json.dumps(score["Class IoU"]))
        logger.write("Class IoU: " + json.dumps(score["Class IoU"]))

    # ------------------------------------------------------------------
    # Train / Val
    # ------------------------------------------------------------------
    if args.val:
        # Only do validation.
        run_validation()
        print("Test finished.\n\n")
        logger.write("Test finished.\n\n")

    else:
        # Training loop.
        ITER_MAX = {
            'step1': config['ITER_MAX'],
            'mixed': config['ITER_MAX_TRANSFER'],
            'st': config['ITER_MAX_ST'],
            'st_mixed': config['ITER_MAX_ST_TRANSFER'],
        }[args.schedule]
        assert start_iter < ITER_MAX

        # Dealing with 'st_mixed' is the same as dealing with 'mixed'.
        if args.schedule == 'st_mixed':
            args.schedule = 'mixed'
        assert args.schedule in ['step1', 'mixed', 'st']

        if args.schedule == 'step1':
            step_scheduler = Const_Scheduler(step_n='step1')
        elif args.schedule == 'mixed':
            step_scheduler = Step_Scheduler(config['interval_step1'],
                                            config['interval_step2'],
                                            config['first'])
        else:
            step_scheduler = Const_Scheduler(step_n='self_training')

        iteration = start_iter
        loader_iter = iter(loader)
        while True:
            # Every 1000 iterations: report learning rates and reset the
            # running loss/accuracy accumulators.
            if iteration == start_iter or iteration % 1000 == 0:
                now_lr = trainer.get_lr()
                print("Now lr of dis: {0:.10f}".format(now_lr['dis_lr']))
                print("Now lr of gen: {0:.10f}".format(now_lr['gen_lr']))
                print("Now lr of back: {0:.10f}".format(now_lr['back_lr']))
                logger.write("Now lr of dis: {0:.10f}".format(
                    now_lr['dis_lr']))
                logger.write("Now lr of gen: {0:.10f}".format(
                    now_lr['gen_lr']))
                logger.write("Now lr of back: {0:.10f}".format(
                    now_lr['back_lr']))

                sum_loss_train = np.zeros(config['loss_count'],
                                          dtype=np.float64)
                sum_acc_real_train, sum_acc_fake_train = 0, 0
                temp_iter = 0

                sum_loss_train_transfer = 0
                sum_acc_fake_train_transfer = 0
                temp_iter_transfer = 0

            # mode should be constant 'step1' in non-zero-shot-learning;
            # switched between 'step1' and 'step2' in zero-shot-learning.
            mode = step_scheduler.now()
            assert mode in ['step1', 'step2', 'self_training']

            if mode == 'step1' or mode == 'self_training':
                try:
                    data, gt = next(loader_iter)
                except StopIteration:
                    # Epoch finished: restart the (shuffled) loader.
                    loader_iter = iter(loader)
                    data, gt = next(loader_iter)

                data = torch.Tensor(data).to(device)

            if mode == 'step1' or mode == 'step2':
                try:
                    loss = trainer.train(data,
                                         gt,
                                         mode=mode,
                                         multigpus=args.multigpus)
                except MeaninglessError:
                    print("Skipping meaningless batch...")
                    continue
            else:  # self training mode
                try:
                    # Build pseudo-labels from the current model, then train
                    # on them as if they were step1 ground truth.
                    with torch.no_grad():
                        test_res = trainer.test(data,
                                                gt,
                                                multigpus=args.multigpus)
                        # LongTensor (N,H',W'); elements 0-14,255 in voc12.
                        resized_gt_for_st = test_res['resized_gt'].cpu()
                        # LongTensor (N,H',W',C).
                        sorted_indices = test_res['sorted_indices'].cpu()
                        gt_new = construct_gt_st(resized_gt_for_st,
                                                 sorted_indices, config)
                    loss = trainer.train(data,
                                         gt_new,
                                         mode='step1',
                                         multigpus=args.multigpus)
                except MeaninglessError:
                    print("Skipping meaningless batch...")
                    continue

            if mode == 'step1' or mode == 'self_training':
                loss_G_GAN = loss['loss_G_GAN']
                loss_G_Content = loss['loss_G_Content']
                loss_B_KLD = loss['loss_B_KLD']
                loss_D_real = loss['loss_D_real']
                loss_D_fake = loss['loss_D_fake']
                loss_D_gp = loss['loss_D_gp']
                loss_cls_real = loss['loss_cls_real']
                loss_cls_fake = loss['loss_cls_fake']
                acc_cls_real = loss['acc_cls_real']
                acc_cls_fake = loss['acc_cls_fake']

                sum_loss_train += np.array([
                    loss_G_GAN, loss_G_Content, loss_B_KLD, loss_D_real,
                    loss_D_fake, loss_D_gp, loss_cls_real, loss_cls_fake
                ]).astype(np.float64)
                sum_acc_real_train += acc_cls_real
                sum_acc_fake_train += acc_cls_fake
                temp_iter += 1

                # Accumulated averages since the last 1000-iteration reset.
                tal = sum_loss_train / temp_iter
                tsar = sum_acc_real_train / temp_iter
                tsaf = sum_acc_fake_train / temp_iter

                # display accumulated average loss and accuracy in step1
                if (iteration + 1) % config['display_interval'] == 0:
                    print("Iter - [{0:0>6d}] AAL: G_G-[{1:.4f}] G_C-[{2:.4f}] B_K-[{3:.4f}] D_r-[{4:.4f}] D_f-[{5:.4f}] D_gp-[{6:.4f}] cls_r-[{7:.4f}] cls_f-[{8:.4f}] Acc: cls_r-[{9:.4f}] cls_f-[{10:.4f}]".format(\
                            iteration + 1, tal[0], tal[1], tal[2], tal[3], tal[4], tal[5], tal[6], tal[7], tsar, tsaf))
                if (iteration + 1) % config['log_interval'] == 0:
                    logger.write("Iter - [{0:0>6d}] AAL: G_G-[{1:.4f}] G_C-[{2:.4f}] B_K-[{3:.4f}] D_r-[{4:.4f}] D_f-[{5:.4f}] D_gp-[{6:.4f}] cls_r-[{7:.4f}] cls_f-[{8:.4f}] Acc: cls_r-[{9:.4f}] cls_f-[{10:.4f}]".format(\
                                iteration + 1, tal[0], tal[1], tal[2], tal[3], tal[4], tal[5], tal[6], tal[7], tsar, tsaf))

            elif mode == 'step2':
                loss_cls_fake_transfer = loss['loss_cls_fake']
                acc_cls_fake_transfer = loss['acc_cls_fake']

                sum_loss_train_transfer += loss_cls_fake_transfer
                sum_acc_fake_train_transfer += acc_cls_fake_transfer
                temp_iter_transfer += 1

                talt = sum_loss_train_transfer / temp_iter_transfer
                tsaft = sum_acc_fake_train_transfer / temp_iter_transfer

                # display accumulated average loss and accuracy in step2 (transfer learning)
                if (iteration + 1) % config['display_interval'] == 0:
                    print("Iter - [{0:0>6d}] Transfer Learning: aal_cls_f-[{1:.4f}] acc_cls_f-[{2:.4f}]".format(\
                            iteration + 1, talt, tsaft))
                if (iteration + 1) % config['log_interval'] == 0:
                    logger.write("Iter - [{0:0>6d}] Transfer Learning: aal_cls_f-[{1:.4f}] acc_cls_f-[{2:.4f}]".format(\
                            iteration + 1, talt, tsaft))

            else:
                # BUGFIX: the original used '%' on a '{}'-style template,
                # which raised TypeError instead of the intended message.
                raise NotImplementedError(
                    'Mode {} not implemented'.format(mode))

            # Save the temporary model
            if (iteration + 1) % config['snapshot'] == 0:
                trainer.save(model_path, iteration, args.multigpus)
                print(
                    "Temporary model of Iter - [{0:0>6d}] successfully stored.\n"
                    .format(iteration + 1))
                logger.write(
                    "Temporary model of Iter - [{0:0>6d}] successfully stored.\n"
                    .format(iteration + 1))

            # Test the saved model
            if (iteration + 1) % config['snapshot'] == 0:
                print(
                    "Testing model of Iter - [{0:0>6d}] ...".format(iteration +
                                                                    1))
                logger.write(
                    "Testing model of Iter - [{0:0>6d}] ...".format(iteration +
                                                                    1))

                run_validation()

                print("Test finished.\n")
                logger.write("Test finished.\n")

            step_scheduler.step()

            iteration += 1
            if iteration == ITER_MAX:
                break

        print("Train finished.\n\n")
        logger.write("Train finished.\n\n")
Exemplo n.º 16
0
def train(config_path, cuda):
    """
    Training DeepLab by v2 protocol: SGD with momentum, polynomial LR decay,
    gradient accumulation over ITER_SIZE mini-batches, frozen batch norm.
    """

    # Configuration
    # safe_load: the config is pure data; yaml.load without a Loader is
    # deprecated and errors on modern PyYAML.
    CONFIG = Dict(yaml.safe_load(config_path))
    device = get_device(cuda)
    torch.backends.cudnn.benchmark = True

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=True,
        base_size=CONFIG.IMAGE.SIZE.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN,
        scales=CONFIG.DATASET.SCALES,
        flip=True,
    )
    print(dataset)

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model check
    print("Model:", CONFIG.MODEL.NAME)
    assert (
        CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC"
    ), 'Currently support only "DeepLabV2_ResNet101_MSC"'

    # Model setup
    # NOTE(review): eval() on a config-supplied name executes arbitrary code;
    # acceptable only because the name is asserted above.
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)
    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    for m in model.base.state_dict().keys():
        if m not in state_dict.keys():
            print("    Skip init:", m)
    model.base.load_state_dict(state_dict, strict=False)  # to skip ASPP
    model = nn.DataParallel(model)
    model.to(device)

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Setup loss logger
    writer = SummaryWriter(os.path.join(CONFIG.EXP.OUTPUT_DIR, "logs", CONFIG.EXP.ID))
    average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS)

    # Path to save models
    checkpoint_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "models",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.TRAIN,
    )
    makedirs(checkpoint_dir)
    print("Checkpoint dst:", checkpoint_dir)

    # Freeze the batch norm pre-trained on COCO
    model.train()
    model.module.base.freeze_bn()

    # memory_cached was renamed to memory_reserved in newer PyTorch.
    mem_fn = getattr(torch.cuda, "memory_reserved", None) \
        or torch.cuda.memory_cached

    for iteration in tqdm(
        range(1, CONFIG.SOLVER.ITER_MAX + 1),
        total=CONFIG.SOLVER.ITER_MAX,
        dynamic_ncols=True,
    ):

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                _, images, labels = next(loader_iter)
            except StopIteration:
                # Epoch finished: restart the shuffled loader.
                loader_iter = iter(loader)
                _, images, labels = next(loader_iter)

            # Propagate forward
            logits = model(images.to(device))

            # Loss
            iter_loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                _, _, H, W = logit.shape
                labels_ = resize_labels(labels, size=(H, W))
                iter_loss += criterion(logit, labels_.to(device))

            # Propagate backward (just compute gradients wrt the loss);
            # divide first so the accumulated gradient averages ITER_SIZE.
            iter_loss /= CONFIG.SOLVER.ITER_SIZE
            iter_loss.backward()

            loss += float(iter_loss)

        average_loss.add(loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # Update learning rate
        scheduler.step(epoch=iteration)

        # TensorBoard
        if iteration % CONFIG.SOLVER.ITER_TB == 0:
            writer.add_scalar("loss/train", average_loss.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("lr/group_{}".format(i), o["lr"], iteration)
            for i in range(torch.cuda.device_count()):
                writer.add_scalar(
                    "gpu/device_{}/memory_cached".format(i),
                    mem_fn(i) / 1024 ** 3,
                    iteration,
                )

            if False:  # flip to True to log weight/grad histograms (expensive)
                for name, param in model.module.base.named_parameters():
                    name = name.replace(".", "/")
                    # Weight/gradient distribution
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(
                            name + "/grad", param.grad, iteration, bins="auto"
                        )

        # Save a model
        if iteration % CONFIG.SOLVER.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(iteration)),
            )

    torch.save(
        model.module.state_dict(), os.path.join(checkpoint_dir, "checkpoint_final.pth")
    )
Exemplo n.º 17
0
def test(config_path, model_path, cuda):
    """
    Evaluation on validation set: dumps per-image logits for later CRF
    post-processing and writes aggregate segmentation scores to JSON.
    """

    # Configuration
    # safe_load: the config is pure data; yaml.load without a Loader is
    # deprecated and errors on modern PyYAML.
    CONFIG = Dict(yaml.safe_load(config_path))
    device = get_device(cuda)
    torch.set_grad_enabled(False)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
    )

    # Model
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.eval()
    model.to(device)

    # Path to save logits
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    makedirs(logit_dir)
    print("Logit dst:", logit_dir)

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores.json")
    print("Score dst:", save_path)

    preds, gts = [], []
    for image_ids, images, gt_labels in tqdm(
        loader, total=len(loader), dynamic_ncols=True
    ):
        # Image
        images = images.to(device)

        # Forward propagation
        logits = model(images)

        # Save on disk for CRF post-processing.
        # BUGFIX: in the original this block was opened with an unterminated
        # triple-quoted string, which swallowed the rest of the loop body.
        for image_id, logit in zip(image_ids, logits):
            filename = os.path.join(logit_dir, image_id + ".npy")
            np.save(filename, logit.cpu().numpy())

        # Pixel-wise labeling: upsample logits to ground-truth resolution
        # and take the argmax class per pixel.
        _, H, W = gt_labels.shape
        logits = F.interpolate(
            logits, size=(H, W), mode="bilinear", align_corners=False
        )
        probs = F.softmax(logits, dim=1)
        labels = torch.argmax(probs, dim=1)

        preds += list(labels.cpu().numpy())
        gts += list(gt_labels.numpy())

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)


@main.command()
@click.option(
    "-c",
    "--config-path",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-j",
    "--n-jobs",
    type=int,
    default=multiprocessing.cpu_count(),
    show_default=True,
    help="Number of parallel jobs",
)
def crf(config_path, n_jobs):
    """
    CRF post-processing on pre-computed logits
    """

    # Configuration
    # safe_load: the config is pure data; yaml.load without a Loader is
    # deprecated and errors on modern PyYAML.
    CONFIG = Dict(yaml.safe_load(config_path))
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files (produced by the `test` command)
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        quit()

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf.json")
    print("Score dst:", save_path)

    # Process per sample
    def process(i):
        """Load sample i's stored logits, upsample, CRF-refine, and return
        (predicted label map, ground-truth label map)."""
        image_id, image, gt_label = dataset[i]

        filename = os.path.join(logit_dir, image_id + ".npy")
        logit = np.load(filename)

        _, H, W = image.shape
        logit = torch.FloatTensor(logit)[None, ...]
        logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False)
        prob = F.softmax(logit, dim=1)[0].numpy()

        # CRF expects an HWC uint8 image alongside the class probabilities.
        image = image.astype(np.uint8).transpose(1, 2, 0)
        prob = postprocessor(image, prob)
        label = np.argmax(prob, axis=0)

        return label, gt_label

    # CRF in multi-process
    results = joblib.Parallel(n_jobs=n_jobs, verbose=10, pre_dispatch="all")(
        [joblib.delayed(process)(i) for i in range(len(dataset))]
    )

    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)

    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 18
0
def main(config, cuda):
    """Train DeepLabV3+ (ResNet-101, MSC) with a max-pooling loss.

    Args:
        config: path to a YAML configuration file.
        cuda: if True and CUDA is available, train on GPU via DataParallel.

    Side effects: writes TensorBoard logs to CONFIG.LOG_DIR and checkpoints
    to CONFIG.SAVE_DIR.
    """
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration.  safe_load avoids arbitrary object construction and
    # PyYAML >= 6 requires an explicit loader for yaml.load(); the context
    # manager also closes the file deterministically.
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # Dataset with train-time augmentation (multi-scale, rotation, flip).
    dataset = get_dataset(CONFIG.DATASET)(
        data_path=CONFIG.ROOT,
        crop_size=256,
        scale=(0.6, 0.8, 1., 1.2, 1.4),
        rotation=15,
        flip=True,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model: load pre-trained weights; strict=False skips the new head.
    model = DeepLabV3Plus_ResNet101_MSC(n_classes=CONFIG.N_CLASSES)
    # map_location="cpu" keeps the load working on CPU-only hosts; tensors
    # are moved to the target device below.
    state_dict = torch.load(CONFIG.INIT_MODEL, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)  # Skip "aspp" layer
    model = nn.DataParallel(model)
    model.to(device)

    # Log which parameters will be trained.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    # Optimizer
    optimizer = torch.optim.Adam(
        params=get_params(model.module),
        lr=CONFIG.LR,
        weight_decay=CONFIG.WEIGHT_DECAY,
    )

    # Loss definition
    criterion = CrossEntropyLoss2d(ignore_index=CONFIG.IGNORE_LABEL)
    criterion.to(device)
    max_pooling_loss = MaxPoolingLoss(ratio=0.3, p=1.7, reduce=True)

    # TensorBoard Logger
    writer = SummaryWriter(CONFIG.LOG_DIR)
    loss_meter = MovingAverageValueMeter(20)

    model.train()
    # Keep the pre-trained BatchNorm statistics frozen during fine-tuning.
    model.module.scale.freeze_bn()

    for iteration in tqdm(
            range(1, CONFIG.ITER_MAX + 1),
            total=CONFIG.ITER_MAX,
            leave=False,
            dynamic_ncols=True,
    ):
        # Clear gradients (ready to accumulate over ITER_SIZE sub-batches)
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            try:
                images, labels = next(loader_iter)
            except StopIteration:
                # Epoch exhausted: restart the loader and keep going.
                loader_iter = iter(loader)
                images, labels = next(loader_iter)

            images = images.to(device)
            labels = labels.to(device).unsqueeze(1).float()

            # Propagate forward
            logits = model(images)

            # Loss
            loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                labels_ = F.interpolate(labels,
                                        logit.shape[2:],
                                        mode="nearest")
                labels_ = labels_.squeeze(1).long()
                # Compute NLL and MPL
                nll_loss = criterion(logit, labels_)
                loss += max_pooling_loss(nll_loss)

            # Backpropagate (just compute gradients wrt the loss)
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        loss_meter.add(iter_loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # Periodic TensorBoard logging: loss, per-group LR, visualizations.
        if iteration % CONFIG.ITER_TB == 0:
            writer.add_scalar("train_loss", loss_meter.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("train_lr_group{}".format(i), o["lr"],
                                  iteration)

            gt_viz, images_viz, predicts_viz = make_vizs(
                images, labels_, logits,
                (CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                 CONFIG.IMAGE.MEAN.R))
            writer.add_image("gt/images", torch.from_numpy(images_viz[0]),
                             iteration)
            writer.add_image("gt/labels", torch.from_numpy(gt_viz[0]),
                             iteration)
            for i, predict_viz in enumerate(predicts_viz):
                writer.add_image("predict/" + str(i),
                                 torch.from_numpy(predict_viz[0]), iteration)

            if False:  # This produces a large log file
                for name, param in model.named_parameters():
                    name = name.replace(".", "/")
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(name + "/grad",
                                             param.grad,
                                             iteration,
                                             bins="auto")

        # Save a model (long-term snapshots)
        if iteration % CONFIG.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR,
                         "checkpoint_{}.pth".format(iteration)),
            )

        # Save a model (short term, overwritten every 100 iterations)
        if iteration % 100 == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.SAVE_DIR, "checkpoint_current.pth"),
            )

    torch.save(model.module.state_dict(),
               osp.join(CONFIG.SAVE_DIR, "checkpoint_final.pth"))
Exemplo n.º 19
0
def main(config, model_path, cuda, crf):
    """Evaluate DeepLabV2 (ResNet-101, MSC) on the validation split.

    Args:
        config: path to a YAML configuration file.
        model_path: trained checkpoint (*.pth); scores are written next to
            it with a .json extension.
        cuda: if True and CUDA is available, run inference on GPU.
        crf: if True, refine the softmax maps with DenseCRF post-processing.
    """
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration.  safe_load avoids arbitrary object construction and
    # PyYAML >= 6 requires an explicit loader; the file is closed promptly.
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # Dataset 10k or 164k (no augmentation at test time)
    dataset = get_dataset(CONFIG.DATASET)(
        root=CONFIG.ROOT,
        split=CONFIG.SPLIT.VAL,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE.TEST,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=False,
    )

    # Inference only: no autograd bookkeeping.
    torch.set_grad_enabled(False)

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.N_CLASSES)
    # map_location keeps the load working on CPU-only hosts.
    state_dict = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model)
    model.eval()
    model.to(device)

    targets, outputs = [], []
    for data, target in tqdm(loader,
                             total=len(loader),
                             leave=False,
                             dynamic_ncols=True):
        # Image
        data = data.to(device)

        # Forward propagation; upsample logits to the input resolution.
        # align_corners=False is the framework default — made explicit to
        # silence the interpolate warning without changing behavior.
        output = model(data)
        output = F.interpolate(output, size=data.shape[2:], mode="bilinear",
                               align_corners=False)
        output = F.softmax(output, dim=1)
        output = output.data.cpu().numpy()

        # Postprocessing: optional per-image DenseCRF refinement.
        if crf:
            crf_output = np.zeros(output.shape)
            images = data.data.cpu().numpy().astype(np.uint8)
            for i, (image, prob_map) in enumerate(zip(images, output)):
                image = image.transpose(1, 2, 0)  # CHW -> HWC for the CRF
                crf_output[i] = dense_crf(image, prob_map)
            output = crf_output

        output = np.argmax(output, axis=1)
        target = target.numpy()

        for o, t in zip(output, target):
            outputs.append(o)
            targets.append(t)

    # Pixel accuracy, mean accuracy, mean IoU, freq-weighted IoU, etc.
    score, class_iou = scores(targets, outputs, n_class=CONFIG.N_CLASSES)

    for k, v in score.items():
        print(k, v)

    score["Class IoU"] = {}
    for i in range(CONFIG.N_CLASSES):
        score["Class IoU"][i] = class_iou[i]

    # Save scores next to the checkpoint.
    with open(model_path.replace(".pth", ".json"), "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)
Exemplo n.º 20
0
def main(config, embedding, model_path, run, imagedataset, local_rank, resnet,
         bkg):
    """Distributed evaluation for (generalized) zero/few-shot segmentation.

    Args:
        config: path to a YAML configuration file.
        embedding: class-embedding source ('word2vec', 'fasttext' or
            'fastnvec' — concatenation of both).
        model_path: checkpoint to evaluate (*.pth.tar); scores are saved
            beside it as JSON.
        run: evaluation protocol ('zlss'/'flss' for novel classes only,
            'gzlss'/'gflss' for the generalized setting).
        imagedataset: dataset directory name under data/datasets.
        local_rank: local process rank for torch.distributed.
        resnet: backbone variant; 'spnet' triggers checkpoint-key remapping.
        bkg: include the background class (id 0) among visible classes.
    """
    rank, world_size, device_id, device = setup(local_rank)
    print("Local rank: {} Rank: {} World Size: {} Device_id: {} Device: {}".
          format(local_rank, rank, world_size, device_id, device))
    pth_extn = '.pth.tar'

    # Configuration.  safe_load avoids arbitrary object construction and
    # PyYAML >= 6 requires an explicit loader; the file is closed promptly.
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    datadir = os.path.join('data/datasets', imagedataset)
    print("Split dir: ", datadir)
    savedir = osp.dirname(model_path)
    # Raw string: '\.' is a regex escape, not a Python string escape.
    epoch = re.findall(r"checkpoint_(.*)\." + pth_extn[1:],
                       osp.basename(model_path))[-1]

    # Pick the set of classes visible at test time for the chosen protocol.
    if run == 'zlss' or run == 'flss':
        val = np.load(datadir + '/split/test_list.npy')
        visible_classes = np.load(datadir + '/split/novel_cls.npy')
        if bkg:
            visible_classes = np.asarray(np.concatenate(
                [np.array([0]), visible_classes]),
                                         dtype=int)
    elif run == 'gzlss' or run == 'gflss':
        val = np.load(datadir + '/split/test_list.npy')

        vals_cls = np.asarray(np.concatenate([
            np.load(datadir + '/split/seen_cls.npy'),
            np.load(datadir + '/split/val_cls.npy')
        ]),
                              dtype=int)

        if bkg:
            vals_cls = np.asarray(np.concatenate([np.array([0]), vals_cls]),
                                  dtype=int)
        valu_cls = np.load(datadir + '/split/novel_cls.npy')
        visible_classes = np.concatenate([vals_cls, valu_cls])
    else:
        print("invalid run ", run)
        sys.exit()

    # Map raw dataset labels to contiguous ids; 255 marks "ignore".
    cls_map = np.array([255] * 256)
    for i, n in enumerate(visible_classes):
        cls_map[n] = i

    if run == 'gzlss' or run == 'gflss':
        # Secondary maps used by the GZSL scorer to separate seen/unseen IoU.
        novel_cls_map = np.array([255] * 256)
        for i, n in enumerate(list(valu_cls)):
            novel_cls_map[cls_map[n]] = i

        seen_cls_map = np.array([255] * 256)
        for i, n in enumerate(list(vals_cls)):
            seen_cls_map[cls_map[n]] = i

    all_labels = np.genfromtxt(datadir + '/labels_2.txt',
                               delimiter='\t',
                               usecols=1,
                               dtype='str')

    print("Visible Classes: ", visible_classes)

    # Dataset (no augmentation at test time)
    dataset = get_dataset(CONFIG.DATASET)(
        train=None,
        test=val,
        root=CONFIG.ROOT,
        split=CONFIG.SPLIT.TEST,
        base_size=CONFIG.IMAGE.SIZE.TEST,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.WARP_IMAGE,
        scale=None,
        flip=False,
    )

    random.seed(42)

    # Load the per-class word embeddings that parameterize the classifier.
    if embedding == 'word2vec':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/word2vec.pkl', "rb"))
    elif embedding == 'fasttext':
        class_emb = pickle.load(
            open(datadir + '/word_vectors/fasttext.pkl', "rb"))
    elif embedding == 'fastnvec':
        class_emb = np.concatenate([
            pickle.load(open(datadir + '/word_vectors/fasttext.pkl', "rb")),
            pickle.load(open(datadir + '/word_vectors/word2vec.pkl', "rb"))
        ],
                                   axis=1)
    else:
        print("invalid emb ", embedding)
        sys.exit()

    # Keep only the visible classes and L2-normalize each embedding.
    class_emb = class_emb[visible_classes]
    class_emb = F.normalize(torch.tensor(class_emb), p=2, dim=1).cuda()

    print("Embedding dim: ", class_emb.shape[1])
    print("# Visible Classes: ", class_emb.shape[0])

    # DataLoader: each rank evaluates a deterministic shard of the split.
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=CONFIG.BATCH_SIZE.TEST,
                                         num_workers=CONFIG.NUM_WORKERS,
                                         shuffle=False,
                                         sampler=DistributedSampler(
                                             dataset,
                                             num_replicas=world_size,
                                             rank=rank,
                                             shuffle=False),
                                         pin_memory=True,
                                         drop_last=True)

    torch.set_grad_enabled(False)

    # Model
    model = DeepLabV2_ResNet101_MSC(class_emb.shape[1],
                                    class_emb,
                                    resnet=resnet)

    state_dict = torch.load(model_path, map_location='cpu')
    model = DistributedDataParallel(model.to(device), device_ids=[rank])
    new_state_dict = OrderedDict()
    if resnet == 'spnet':
        # SPNet checkpoints use different module names; remap them.
        for k, v in state_dict['state_dict'].items():
            name = k.replace("scale", "base")  # 'scale'->base
            name = name.replace("stages.", "")
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict['state_dict']
    model.load_state_dict(new_state_dict)
    del state_dict  # free the raw checkpoint before inference

    model.eval()
    targets, outputs = [], []

    iterations = len(loader)
    print("Iterations: {}".format(iterations))

    pbar = tqdm(loader,
                total=iterations,
                leave=False,
                dynamic_ncols=True,
                position=rank)
    # Iterate the loader exactly once.  The original code also drew batches
    # from a second iter(loader) via next(), which loaded every batch twice
    # (same order, since shuffle=False) and doubled worker I/O.
    for data, target, img_id in pbar:
        # Image
        data = data.to(device)
        # Forward propagation
        output = model(data)
        output = F.interpolate(output,
                               size=data.shape[2:],
                               mode="bilinear",
                               align_corners=False)

        output = F.softmax(output, dim=1)
        target = cls_map[target.numpy()]

        # Rank 0 zeroes its own contribution so the reduce() sum below only
        # carries the other ranks' tensors.
        # NOTE(review): reduce() SUMS label maps across ranks, so the
        # reconstruction is only exact when a single non-zero rank
        # contributes per tensor — confirm the intended world_size.
        remote_target = torch.tensor(target).to(device)
        if rank == 0:
            remote_target = torch.zeros_like(remote_target).to(device)

        output = torch.argmax(output, dim=1).cpu().numpy()

        remote_output = torch.tensor(output).to(device)
        if rank == 0:
            remote_output = torch.zeros_like(remote_output).to(device)

        # Every rank keeps its local predictions ...
        for o, t in zip(output, target):
            outputs.append(o)
            targets.append(t)

        torch.distributed.reduce(remote_output, dst=0)
        torch.distributed.reduce(remote_target, dst=0)

        torch.distributed.barrier()

        # ... and rank 0 additionally collects the reduced remote ones.
        if rank == 0:
            remote_output = remote_output.cpu().numpy()
            remote_target = remote_target.cpu().numpy()
            for o, t in zip(remote_output, remote_target):
                outputs.append(o)
                targets.append(t)

    # Only rank 0 scores and writes results.
    if rank == 0:

        if run == 'gzlss' or run == 'gflss':
            score, class_iou = scores_gzsl(targets,
                                           outputs,
                                           n_class=len(visible_classes),
                                           seen_cls=cls_map[vals_cls],
                                           unseen_cls=cls_map[valu_cls])
        else:
            score, class_iou = scores(targets,
                                      outputs,
                                      n_class=len(visible_classes))

        for k, v in score.items():
            print(k, v)

        score["Class IoU"] = {}
        for i in range(len(visible_classes)):
            score["Class IoU"][all_labels[visible_classes[i]]] = class_iou[i]

        name = model_path.replace(pth_extn, "_" + run + ".json")

        if bkg:
            with open(name.replace('.json', '_bkg.json'), "w") as f:
                json.dump(score, f, indent=4, sort_keys=True)
        else:
            with open(name, "w") as f:
                json.dump(score, f, indent=4, sort_keys=True)

        print(score["Class IoU"])

    return