Exemple #1
0
    # ---- build models ----
    # Make the GPU index given in the options the current CUDA device, so the
    # `.cuda()` call below places the model on it.
    torch.cuda.set_device(opt.gpu_device)
    # - Please assign your preferred backbone in `opt.backbone`.
    # The import is done inside each branch so only the module of the selected
    # backbone is loaded; every branch binds the same name, `Inf_Net`.
    if opt.backbone == 'Res2Net50':
        print('Backbone loading: Res2Net50')
        from Code.model_lung_infection.InfNet_Res2Net import Inf_Net
    elif opt.backbone == 'ResNet50':
        print('Backbone loading: ResNet50')
        from Code.model_lung_infection.InfNet_ResNet import Inf_Net
    elif opt.backbone == 'VGGNet16':
        print('Backbone loading: VGGNet16')
        from Code.model_lung_infection.InfNet_VGGNet import Inf_Net
    else:
        # Fail fast on an unknown backbone instead of hitting a NameError on
        # `Inf_Net` below.
        raise ValueError('Invalid backbone parameters: {}'.format(
            opt.backbone))
    model = Inf_Net(channel=opt.net_channel, n_class=opt.n_classes).cuda()

    # ---- load pre-trained weights (mode=Semi-Inf-Net) ----
    # - See Sec.2.3 of `README.md` to learn how to generate your own img/pseudo-label from scratch.
    # The pseudo-label checkpoint is loaded only for semi-supervised mode with
    # the Res2Net50 backbone; every other combination starts without it.
    # NOTE(review): the path below uses a lowercase `snapshots` directory —
    # confirm it matches the on-disk name (matters on case-sensitive filesystems).
    if opt.is_semi and opt.backbone == 'Res2Net50':
        print('Loading weights from weights file trained on pseudo label')
        model.load_state_dict(
            torch.load(
                './snapshots/save_weights/Inf-Net_Pseudo/Inf-Net-100.pth'))
    else:
        print('Not loading weights from weights file')

    # ---- weights file save path ----
    # Pseudo-label training (and not semi-supervised) saves under 'Inf-Net_Pseudo'.
    if opt.is_pseudo and (not opt.is_semi):
        print("Inf-Net_Pseudo")
        train_save = 'Inf-Net_Pseudo'
Exemple #2
0
def inference():
    """Run Inf-Net on a folder of test images and save the predicted maps.

    Options are parsed from the command line:

    * ``--testsize``  -- size handed to ``test_dataset`` (default 352).
    * ``--data_path`` -- dataset root; images are read from ``<data_path>/Imgs/``.
    * ``--pth_path``  -- weights file loaded into the network.
    * ``--save_path`` -- directory receiving one prediction per input image;
      it is created if it does not exist.

    For each image, the network's ``lateral_map_2`` output is passed through a
    sigmoid, min-max normalised, and written into ``--save_path`` under the
    name returned by the loader. No resizing is applied here before saving.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--testsize',
                        type=int,
                        default=352,
                        help='testing size')
    parser.add_argument('--data_path',
                        type=str,
                        default='./Dataset/TestingSet/LungInfection-Test/',
                        help='Path to test data')
    parser.add_argument(
        '--pth_path',
        type=str,
        default='./Snapshots/save_weights/Semi-Inf-Net/Semi-Inf-Net-100.pth',
        help=
        'Path to weights file. If `semi-sup`, edit it to `Semi-Inf-Net/Semi-Inf-Net-100.pth`'
    )
    parser.add_argument(
        '--save_path',
        type=str,
        default='./Results/Lung infection segmentation/Semi-Inf-Net/',
        help=
        'Path to save the predictions. if `semi-sup`, edit it to `Semi-Inf-Net`'
    )
    opt = parser.parse_args()

    print(
        "#" * 20,
        "\nStart Testing (Inf-Net)\n{}\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
        "Infection Segmentation from CT Scans', 2020, TMI.\n"
        "----\nPlease cite the paper if you use this code and dataset. "
        "And any questions feel free to contact me "
        "via E-mail ([email protected])\n----\n".format(opt), "#" * 20)

    model = Network()
    # model = torch.nn.DataParallel(model, device_ids=[0, 1]) # uncomment it if you have multiple GPUs.
    # Remap tensors that were saved on 'cuda:1' onto 'cuda:0' while loading.
    model.load_state_dict(
        torch.load(opt.pth_path, map_location={'cuda:1': 'cuda:0'}))
    model.cuda()
    model.eval()

    image_root = '{}/Imgs/'.format(opt.data_path)
    test_loader = test_dataset(image_root, opt.testsize)
    os.makedirs(opt.save_path, exist_ok=True)

    # Inference only: disable autograd so no graph is built per image, which
    # saves GPU memory and time without changing the outputs.
    with torch.no_grad():
        for _ in range(test_loader.size):
            image, name = test_loader.load_data()
            image = image.cuda()

            # The network returns five maps; only `lateral_map_2` is saved.
            _, _, _, lateral_map_2, _ = model(image)

            res = lateral_map_2.sigmoid().detach().cpu().numpy().squeeze()
            # Min-max normalise; the epsilon guards against a constant map
            # (max == min) causing a division by zero.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            # os.path.join keeps the output inside `save_path` even when the
            # user-supplied directory has no trailing separator.
            # NOTE(review): `misc` is presumably `scipy.misc`, whose `imsave`
            # was removed in SciPy 1.2 — confirm the pinned SciPy version or
            # plan a migration to an image-writing library.
            misc.imsave(os.path.join(opt.save_path, name), res)

    print('Test Done!')