示例#1
0
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token (including ``"False"``) becomes ``True``. This
    converter interprets the usual spellings instead.

    Args:
        value: A ``bool``, or a string such as ``"true"``/``"false"``,
            ``"yes"``/``"no"``, ``"1"``/``"0"`` (case-insensitive).

    Returns:
        The corresponding ``bool``. An empty string maps to ``False``, which
        is what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Run the test-set evaluation configured by ``test.toml``.

    Parses the command-line options, reads ``test.toml`` from the current
    working directory, builds the ResNet-152 variant selected in the
    ``[general]`` table, copies in every checkpoint tensor whose name and
    shape match the model, then runs ``Validator`` and prints the overall
    value it returns ('All acc') followed by a per-class 'Accuracy' table.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-d", "--deepsave", help="A path to save deepfeature", type=str,
                        # default='re/cap_vs_covid.npy')
                        default='deep_f')
    parser.add_argument("-e", "--exclude_list",
                        help="A path to a txt file for excluded data list. If no file need to be excluded, "
                             "it should be 'none'.", type=str,
                        default='none')
    # type=bool would turn "-v False" into True; str2bool parses the text.
    parser.add_argument("-v", "--invert_exclude", help="Whether to invert exclude to include",
                        type=str2bool, default=False)
    parser.add_argument("-k", "--topk", help="Value of k for top-k results", type=int,
                        default=5)
    parser.add_argument("-s", "--savenpy", help="A path to save record", type=str,
                        default='top1.npy')
    args = parser.parse_args()
    os.makedirs(args.deepsave, exist_ok=True)

    print("Loading options...")
    with open('test.toml', 'r', encoding='utf-8') as options_file:
        options = toml.loads(options_file.read())

    if options["general"]["usecudnnbenchmark"] and options["general"]["usecudnn"]:
        print("Running cudnn benchmark...")
        torch.backends.cudnn.benchmark = True

    # Environment values must be strings; TOML may deliver gpuid as an integer.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(options["general"]['gpuid'])

    torch.manual_seed(options["general"]['random_seed'])

    # Create the model.
    if options['general']['use_plus']:
        model = resnet152_plus(options['general']['class_num'])
    else:
        model = resnet152(options['general']['class_num'])
    # The mere presence of an 'R' key replaces the model chosen above.
    if 'R' in options['general']:
        model = resnet152_R(options['general']['class_num'])

    # Load onto the CPU first so the checkpoint does not depend on the device
    # it was saved from; load_state_dict copies values into the model's own
    # parameters afterwards.
    pretrained_dict = torch.load(options['general']['pretrainedmodelpath'],
                                 map_location='cpu')
    # Keep only checkpoint entries whose name and shape match the model.
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and v.size() == model_dict[k].size()}
    print('matched keys:', len(pretrained_dict))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    tester = Validator(options, 'test', model, options['validation']['saves'], args)

    result, re_all = tester()
    print(tester.savenpy)
    print('-' * 21)
    print('All acc:' + str(re_all))
    print('{:<10}|{:>10}'.format('Cls #', 'Accuracy'))
    for i in range(result.shape[0]):
        print('{:<10}|{:>10}'.format(i, result[i]))
    print('-' * 21)
# Seed the torch RNG before the model is built.
torch.manual_seed(options["general"]['random_seed'])

# Create the model selected by the [general] options.
if options['general']['use_3d']:
    model = Dense3D(options)  # TODO:1
elif options['general']['use_slice']:
    if options['general']['use_plus']:
        model = resnet152_plus(options['general']['class_num'],
                               asinput=options['general']['plus_as_input'],
                               USE_25D=options['general']['use25d'])
    else:
        # Alternatives tried previously: vgg19_bn(2), squeezenet1_1(2)
        model = resnet152(options['general']['class_num'],
                          USE_25D=options['general']['use25d'])
    # The mere presence of an 'R' key replaces the slice model chosen above.
    if 'R' in options['general']:
        model = resnet152_R(options['general']['class_num'])
else:
    model = densenet161(2)

if options["general"]["loadpretrainedmodel"]:
    # NOTE: keys are not renamed here; checkpoint entries whose names differ
    # from the model's (e.g. carrying a wrapper prefix) are simply dropped by
    # the filter below.
    pretrained_dict = torch.load(options["general"]["pretrainedmodelpath"])
    # Keep only checkpoint entries whose name and shape match the model.
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.size() == model_dict[k].size()
    }
    print('matched keys:', len(pretrained_dict))
    model_dict.update(pretrained_dict)
    # state_dict() returns a separate mapping, so the merged weights must be
    # loaded back explicitly; without this the checkpoint is never applied.
    model.load_state_dict(model_dict)
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token (including ``"False"``) becomes ``True``. This
    converter interprets the usual spellings instead.

    Args:
        value: A ``bool``, or a string such as ``"true"``/``"false"``,
            ``"yes"``/``"no"``, ``"1"``/``"0"`` (case-insensitive).

    Returns:
        The corresponding ``bool``. An empty string maps to ``False``, which
        is what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Run the test-set evaluation with a checkpoint given on the command line.

    Parses the command-line options, reads ``test.toml`` from the current
    working directory, builds the ResNet-152 variant selected in the
    ``[general]`` table, copies in every tensor from ``--model_path`` whose
    name and shape match the model, then runs ``Validator`` and prints the
    overall value it returns ('All acc') followed by a per-class 'Accuracy'
    table.

    Only ``--savenpy`` (plus the fixed mode string) is handed to ``Validator``
    by this function; the other parsed options are not forwarded here.
    """
    parser = argparse.ArgumentParser()
    mod = 'AB-in'
    parser.add_argument("-m", "--maskpath", help="A list of paths for lung segmentation data",  # type=list,
                        default=['/mnt/data6/CAP/seg_test',
                                 '/mnt/data7/ILD/resampled_seg',
                                 # '/mnt/data7/examples/seg',
                                 # '/mnt/data7/reader_ex/resampled_seg',
                                 # '/mnt/data7/LIDC/resampled_seg',
                                 '/mnt/data7/resampled_seg/test1', '/mnt/data7/resampled_seg/test2',
                                 '/mnt/data7/resampled_seg/test3'
                                 # '/mnt/data7/slice_test_seg/mask_re',
                                 # '/mnt/data7/resampled_seg/test3']
                                 ])
    parser.add_argument("-i", "--imgpath", help="A list of paths for image data",
                        default=['/mnt/data6/CAP/data_test',
                                 '/mnt/data7/ILD/resampled_data',
                                 # '/mnt/data7/examples/data',
                                 # '/mnt/data7/reader_ex/resampled_data',
                                 # '/mnt/data7/LIDC/resampled_data',
                                 '/mnt/data7/resampled_data/test1', '/mnt/data7/resampled_data/test2',
                                 '/mnt/data7/resampled_data/test3'
                                 # '/mnt/data7/slice_test_seg/data_re',
                                 # '/mnt/data7/resampled_data/resampled_test_3']
                                 ])
    parser.add_argument("-o", "--savenpy", help="A path to save record", type=str,
                        # default='re/reader_healthy_vs_ill.npy')
                        # default='re/reader_cap_vs_covid.npy')
                        default='re/reader_influenza_vs_covid.npy')
                        # default='re/test_2.npy')
    parser.add_argument("-d", "--deepsave", help="A path to save deepfeature", type=str,
                        # default='re/cap_vs_covid.npy')
                        default='deep_f')
    parser.add_argument("-e", "--exclude_list",
                        help="A path to a txt file for excluded data list. If no file need to be excluded, "
                             "it should be 'none'.", type=str,
                        default='none')
    # type=bool would turn "-v False" into True; str2bool parses the text.
    parser.add_argument("-v", "--invert_exclude", help="Whether to invert exclude to include",
                        type=str2bool, default=False)
    parser.add_argument("-p", "--model_path", help="A path to the pretrained model weights", type=str,
                        default='weights/new_4cls_pure.pt')
    # default='weights/healthy_or_not.pt')
    parser.add_argument("-g", "--gpuid", help="gpuid", type=str,
                        default='1')
    args = parser.parse_args()
    os.makedirs(args.deepsave, exist_ok=True)

    print("Loading options...")
    with open('test.toml', 'r', encoding='utf-8') as options_file:
        options = toml.loads(options_file.read())

    if options["general"]["usecudnnbenchmark"] and options["general"]["usecudnn"]:
        print("Running cudnn benchmark...")
        torch.backends.cudnn.benchmark = True

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuid

    # A value given on the command line arrives as a string holding a Python
    # list expression, whereas the defaults are already lists. Each option is
    # converted independently so that overriding only one of them works
    # (previously maskpath was evaluated only when imgpath was a string).
    # NOTE(security): eval() executes arbitrary expressions taken from the
    # command line. If only list literals are expected, ast.literal_eval is
    # the safer choice; eval is kept so the accepted input does not change.
    if isinstance(args.imgpath, str):
        args.imgpath = eval(args.imgpath)
    if isinstance(args.maskpath, str):
        args.maskpath = eval(args.maskpath)

    torch.manual_seed(options["general"]['random_seed'])

    # Create the model.
    if options['general']['use_plus']:
        model = resnet152_plus(options['general']['class_num'])
    else:
        model = resnet152(options['general']['class_num'])
    # The mere presence of an 'R' key replaces the model chosen above.
    if 'R' in options['general']:
        model = resnet152_R(options['general']['class_num'])

    # Load onto the CPU first so the checkpoint does not depend on the device
    # it was saved from; load_state_dict copies values into the model's own
    # parameters afterwards.
    pretrained_dict = torch.load(args.model_path, map_location='cpu')
    # Keep only checkpoint entries whose name and shape match the model.
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and v.size() == model_dict[k].size()}
    print('matched keys:', len(pretrained_dict))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    tester = Validator(options, 'test', model, mod, savenpy=args.savenpy)

    result, re_all = tester()
    print(tester.savenpy)
    print('-' * 21)
    print('All acc:' + str(re_all))
    print('{:<10}|{:>10}'.format('Cls #', 'Accuracy'))
    for i in range(result.shape[0]):
        print('{:<10}|{:>10}'.format(i, result[i]))
    print('-' * 21)