Example #1
def run_inference(args):
    model = UNet(input_channels=3, num_classes=3)
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
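    # strict=False skips checkpoint keys that don't match the model's state dict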
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    # annus/Desktop/palsar/
    test_image_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/palsar_{}_region_{}.tif'.format(
        args.year, args.region)
    test_label_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/fnf_{}_region_{}.tif'.format(
        args.year, args.region)
    inference_loader = get_inference_loader(image_path=test_image_path,
                                            label_path=test_label_path,
                                            model_input_size=128,
                                            num_classes=4,
                                            one_hot=True,
                                            batch_size=args.bs,
                                            num_workers=4)
    # output buffer for the generated test map
    generated_map = np.empty(shape=inference_loader.dataset.get_image_size())
    weights = torch.Tensor([1, 1, 2])
    focal_criterion = FocalLoss2d(weight=weights)
    un_confusion_meter = tnt.meter.ConfusionMeter(2, normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(2, normalized=True)
    total_correct, total_examples = 0, 0
    net_loss = []
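    # each batch carries (x, x_, y, y_) tile coordinates used below to stitch
    # predictions back into generated_map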
    for idx, data in enumerate(inference_loader):
        coordinates = data['coordinates'].tolist()
        test_x, label = data['input'], data['label']
        with torch.no_grad():  # inference only; no autograd bookkeeping needed
            out_x, softmaxed = model(test_x)
        pred = torch.argmax(softmaxed, dim=1)
        not_one_hot_target = torch.argmax(label, dim=1)
        # convert to binary classes
        # 0-> noise, 1-> forest, 2-> non-forest, 3-> water
        pred[pred == 0] = 2
        pred[pred == 3] = 2
        not_one_hot_target[not_one_hot_target == 0] = 2
        not_one_hot_target[not_one_hot_target == 3] = 2
        # now convert 1, 2 to 0, 1
        pred -= 1
        not_one_hot_target -= 1
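        # reorder (batch, H, W) -> (H, W, batch) so tile k is pred_numpy[:, :, k]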
        pred_numpy = pred.numpy().transpose(1, 2, 0)
        for k in range(test_x.shape[0]):
            x, x_, y, y_ = coordinates[k]
            generated_map[x:x_, y:y_] = pred_numpy[:, :, k]
        loss = focal_criterion(softmaxed, not_one_hot_target)  # dice_criterion(softmaxed, label)
        accurate = (pred == not_one_hot_target).sum().item()
        numerator = float(accurate)
        denominator = float(pred.numel())  # == test_x.size(0) * dimension ** 2
        total_correct += numerator
        total_examples += denominator
        net_loss.append(loss.item())
        un_confusion_meter.add(predicted=pred.view(-1),
                               target=not_one_hot_target.view(-1))
        confusion_meter.add(predicted=pred.view(-1),
                            target=not_one_hot_target.view(-1))
        # if idx % 5 == 0:
        accuracy = float(numerator) * 100 / denominator
        print(
            '{}, {} -> ({}/{}) output size = {}, loss = {}, accuracy = {}/{} = {:.2f}%'
            .format(args.year, args.region, idx, len(inference_loader),
                    out_x.size(), loss.item(), numerator, denominator,
                    accuracy))
        #################################
    mean_accuracy = total_correct * 100 / total_examples
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(
        mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    # class_names = ['background/clutter', 'buildings', 'trees', 'cars',
    #                'low_vegetation', 'impervious_surfaces', 'noise']
    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(),
                 this,
                 protocol=pkl.HIGHEST_PROTOCOL)

    # save_path = 'generated_maps/generated_{}_{}.npy'.format(args.year, args.region)
    save_path = '/home/annus/Desktop/palsar/generated_maps/using_separate_models/generated_{}_{}.npy'.format(
        args.year, args.region)
    np.save(save_path, generated_map)
    #########################################################################################
    inference_loader.dataset.clear_mem()
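
Every example on this page instantiates a FocalLoss2d criterion whose definition is not included, and the call sites differ (Example #1 feeds it softmax output, Example #4 passes an ignore_index). As a minimal sketch only, not the module these repos actually use: a 2-D focal loss taking raw logits of shape (N, C, H, W) and integer targets of shape (N, H, W).

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss2d(nn.Module):
    # minimal sketch, assuming raw logits; real implementations vary
    def __init__(self, gamma=2.0, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        # per-pixel cross entropy, kept unreduced so the focal term
        # can modulate each pixel before averaging
        ce = F.cross_entropy(logits, target, weight=self.weight,
                             ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce)  # model's probability for the true class
        # (1 - pt) ** gamma down-weights easy, well-classified pixels
        return ((1.0 - pt) ** self.gamma * ce).mean()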
Example #2
            raise Exception("model type is invalid : " + m_name)

        class_weights = None
        if float(class_weight_addings[i]) > 0:
            class_weights = torch.tensor(get_class_weights(float(class_weight_addings[i]))).cuda()

        if loss_type == "cross_entropy":
            criterion = nn.CrossEntropyLoss(class_weights)
        elif loss_type == "bce":
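            # note: the first positional argument of BCEWithLogitsLoss is a
            # per-element rescaling weight; per-class weighting would normally
            # go through its pos_weight keyword instead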
            criterion = torch.nn.BCEWithLogitsLoss(class_weights)
        elif loss_type == "multi_soft_margin":
            criterion = nn.MultiLabelSoftMarginLoss(class_weights)
        elif loss_type == "multi_margin":
            criterion = nn.MultiLabelMarginLoss()
        elif loss_type == "focal_loss":
            criterion = FocalLoss2d(weight=class_weights)
        elif loss_type == "kldiv":
            criterion = torch.nn.KLDivLoss()
        else:
            raise Exception("loss type is invalid : " + loss_type)

        model = model.to(device)

        # DONOTCHANGE: They are reserved for nsml
        bind_model(model, args)
        if args.pause:
            nsml.paused(scope=locals())

        model.eval()
        transform = None
        batch_size = (256 if m_name == "Resnet18" else 32)
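
The fragment above calls a get_class_weights helper that is not shown; its float argument looks like an additive smoothing term. A hypothetical sketch in the spirit of the ENet weighting scheme w = 1 / ln(adding + p), where p is a class's pixel frequency (the body and the placeholder counts below are assumptions, not the original code):

import numpy as np

def get_class_weights(adding, class_pixel_counts=None):
    # hypothetical: inverse-log-frequency class weights, ENet-style
    if class_pixel_counts is None:
        # placeholder counts; the real code would derive these from the dataset
        class_pixel_counts = np.array([1.0e6, 2.5e5, 7.5e4])
    freq = class_pixel_counts / class_pixel_counts.sum()
    return 1.0 / np.log(adding + freq)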
Example #3
def train():
    train_transform = MyTransform(Config.f, Config.fish_size)
    train_transform.set_ext_params(Config.ext_param)
    train_transform.set_ext_param_range(Config.ext_range)
    if Config.rand_f:
        train_transform.rand_f(f_range=Config.f_range)
    if Config.rand_ext:
        train_transform.rand_ext_params()
    train_transform.set_bkg(bkg_label=20, bkg_color=[0, 0, 0])
    train_transform.set_crop(rand=Config.crop, rate=Config.crop_rate)

    # train_transform = RandOneTransform(Config.f, Config.fish_size)
    # train_transform.set_ext_params(Config.ext_param)
    # train_transform.set_ext_param_range(Config.ext_range)
    # train_transform.set_f_range(Config.f_range)
    # train_transform.set_bkg(bkg_label=20, bkg_color=[0, 0, 0])
    # train_transform.set_crop(rand=Config.crop, rate=Config.crop_rate)

    train_set = CityScape(Config.train_img_dir,
                          Config.train_annot_dir,
                          transform=train_transform)
    train_loader = DataLoader(
        train_set,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.dataloader_num_worker,
    )

    validation_set = CityScape(Config.valid_img_dir, Config.valid_annot_dir)
    validation_loader = DataLoader(validation_set,
                                   batch_size=Config.val_batch_size,
                                   shuffle=False)

    # model = ERFPSPNet(shapeHW=[640, 640], num_classes=21)
    resnet = resnet18(pretrained=True)
    model = SwiftNet(resnet, num_classes=21)
    model.to(MyDevice)

    class_weights = torch.tensor([
        8.6979065, 8.497886, 8.741297, 5.983605, 8.662319, 8.681756,
        8.683093, 8.763641, 8.576978, 2.7114885, 6.237076, 3.582358,
        8.439253, 8.316548, 8.129169, 4.312109, 8.170293, 6.91469,
        8.135018, 0.0, 3.6,
    ]).cuda()

    # criterion = CrossEntropyLoss2d(weight=class_weights)
    criterion = FocalLoss2d(weight=class_weights)

    lr = Config.learning_rate

    # Pretrained SwiftNet optimizer
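    # two parameter groups: freshly initialized layers train at the base
    # lr (4e-4); pretrained backbone layers fine-tune at a lower lr (1e-4)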
    optimizer = torch.optim.Adam(
        [
            {
                "params": model.random_init_params()
            },
            {
                "params": model.fine_tune_params(),
                "lr": 1e-4,
                "weight_decay": 2.5e-5
            },
        ],
        lr=4e-4,
        weight_decay=1e-4,
    )

    # ERFNetPSP optimizer
    # optimizer = torch.optim.Adam(model.parameters(),
    #                              lr=1e-3,
    #                              betas=(0.9, 0.999),
    #                              eps=1e-08,
    #                              weight_decay=2e-4)

    # scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer, step_size=90, gamma=0.1)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 200, 1e-6)

    start_epoch = 0
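    # 2975 = number of images in the Cityscapes training split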
    step_per_epoch = math.ceil(2975 / Config.batch_size)
    writer = SummaryWriter(Config.logdir)
    # writer.add_graph(model)

    if Config.train_with_ckpt:
        checkpoint = torch.load(Config.ckpt_path)
        print("Load", Config.ckpt_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        val(model, validation_loader, is_print=True)
        # loss = checkpoint['loss']
        model.train()

    start_time = None
    for epoch in range(start_epoch, Config.max_epoch):

        for i, (image, annot) in enumerate(train_loader):
            if start_time is None:
                start_time = time.time()
            input = image.to(MyDevice)
            target = annot.to(MyDevice, dtype=torch.long)
            model.train()
            optimizer.zero_grad()
            score = model(input)
            # predict = torch.argmax(score, 1)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            global_step = step_per_epoch * epoch + i

            if i % 20 == 0:
                predict = torch.argmax(score, 1).to(MyCPU, dtype=torch.uint8)
                writer.add_image("Images/original_image",
                                 image[0],
                                 global_step=global_step)
                writer.add_image(
                    "Images/segmentation_output",
                    predict[0].view(1, 640, 640) * 10,
                    global_step=global_step,
                )
                writer.add_image(
                    "Images/segmentation_ground_truth",
                    annot[0].view(1, 640, 640) * 10,
                    global_step=global_step,
                )

            if i % 20 == 0 and global_step > 0:
                writer.add_scalar("Monitor/Loss",
                                  loss.item(),
                                  global_step=global_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()
            print(
                f"{epoch}/{Config.max_epoch-1} epoch, {i}/{step_per_epoch} step, loss:{loss.item()}, "
                f"{time_elapsed} sec/step; global step={global_step}")

        scheduler.step()
        if epoch > 20:
            (
                mean_precision,
                mean_recall,
                mean_iou,
                m_precision_19,
                m_recall_19,
                m_iou_19,
            ) = val(model, validation_loader, is_print=True)

            writer.add_scalar("Monitor/precision20",
                              mean_precision,
                              global_step=epoch)
            writer.add_scalar("Monitor/recall20",
                              mean_recall,
                              global_step=epoch)
            writer.add_scalar("Monitor/mIOU20", mean_iou, global_step=epoch)
            writer.add_scalar("Monitor1/precision19",
                              m_precision_19,
                              global_step=epoch)
            writer.add_scalar("Monitor1/recall19",
                              m_recall_19,
                              global_step=epoch)
            writer.add_scalar("Monitor1/mIOU19", m_iou_19, global_step=epoch)

            print(epoch, "/", Config.max_epoch, " loss:", loss.item())
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "loss": loss.item(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                Config.ckpt_name + "_" + str(epoch) + ".pth",
            )
            print("model saved!")

    val(model, validation_loader, is_print=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        Config.model_path,
    )
    print("Save model to disk!")
    writer.close()
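
The val(...) helper above returns mean precision, recall, and IoU over the 20 scored classes plus the same metrics over the 19 standard Cityscapes evaluation classes, but its body is not shown. A minimal sketch of how such metrics fall out of a confusion matrix (the rows-are-ground-truth convention is an assumption):

import numpy as np

def metrics_from_confusion(cm):
    # cm[i, j] = number of pixels of true class i predicted as class j
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as the class, actually something else
    fn = cm.sum(axis=1) - tp   # belongs to the class, predicted as something else
    eps = 1e-12                # avoid division by zero for empty classes
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    iou = tp / (tp + fp + fn + eps)
    return precision.mean(), recall.mean(), iou.mean()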
Example #4
print('mean and std: ', datas['mean'], datas['std'])

# define the loss function
weight = torch.from_numpy(datas['classWeights'])

# if(lossFunc == 'ohem'):
#     min_kept = int(batch_size // len(gpus) * h * w // 16)
#     criteria = ProbOhemCrossEntropy2d(use_weight=True, ignore_label=ignore_label, thresh=0.7, min_kept=min_kept)
# elif(lossFunc == 'label_smoothing'):
#     criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=ignore_label)
if lossFunc == 'CrossEntropy':
    criteria = CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
elif lossFunc == 'LovaszSoftmax':
    criteria = LovaszSoftmax(ignore_index=ignore_label)
elif lossFunc == 'focal':
    criteria = FocalLoss2d(weight=weight, ignore_index=ignore_label)
else:
    raise NotImplementedError(
        'We only support CrossEntropy, LovaszSoftmax and focal as loss function.'
    )

if use_cuda:
    criteria = criteria.to(device)
    if torch.cuda.device_count() > 1:
        print("torch.cuda.device_count()=", torch.cuda.device_count())
        gpu_nums = torch.cuda.device_count()
        model = nn.DataParallel(model).to(device)  # multi-card data parallel
    else:
        gpu_nums = 1
        print("single GPU for training")
        model = model.to(device)  # single GPU, no data parallelism
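
A quick way to sanity-check whichever criterion was selected above is a dummy forward pass. The shapes and the ignore value of 255 below are assumptions for illustration:

import torch
from torch.nn import CrossEntropyLoss

ignore_label = 255  # assumed Cityscapes-style ignore value
criteria = CrossEntropyLoss(ignore_index=ignore_label)

logits = torch.randn(2, 19, 64, 64)          # 2 images, 19 classes, 64x64 crops
target = torch.randint(0, 19, (2, 64, 64))   # random ground-truth labels
target[:, :4, :4] = ignore_label             # these pixels are excluded from the loss
print('loss on dummy batch:', criteria(logits, target).item())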