# Load the run configuration. TOML is UTF-8 by specification, so decode it as
# UTF-8 explicitly instead of relying on the platform's locale encoding.
# (To configure the lip model instead, load 'options_lip.toml'.)
with open('options_slice.toml', 'r', encoding='utf-8') as optionsFile:
    options = toml.loads(optionsFile.read())

# Enable cudnn autotuning only when both flags are set.
if (options["general"]["usecudnnbenchmark"]
        and options["general"]["usecudnn"]):
    print("Running cudnn benchmark...")
    torch.backends.cudnn.benchmark = True

# os.environ only accepts str values; str() also tolerates a bare integer
# such as `gpuid = 0` in the TOML file.
os.environ['CUDA_VISIBLE_DEVICES'] = str(options["general"]['gpuid'])

torch.manual_seed(options["general"]['random_seed'])

# Create the model.
# The 'R' variant is decided first so that a resnet152 / resnet152_plus
# backbone is not constructed only to be discarded immediately.
if options['general']['use_3d']:
    model = Dense3D(options)  # TODO:1
elif options['general']['use_slice']:
    if 'R' in options['general']:
        # NOTE(review): selection is by key presence, not by value, so
        # `R = false` still picks resnet152_R (as before) — confirm intended.
        model = resnet152_R(options['general']['class_num'])
    elif options['general']['use_plus']:
        model = resnet152_plus(options['general']['class_num'],
                               asinput=options['general']['plus_as_input'],
                               USE_25D=options['general']['use25d'])
    else:
        model = resnet152(options['general']['class_num'],
                          USE_25D=options['general']['use25d'])
else:
    # NOTE(review): hard-codes 2 classes instead of reading class_num like
    # the other branches — confirm this is intended.
    model = densenet161(2)
if (options["general"]["loadpretrainedmodel"]):
# Example #2
def _match_pretrained(pretrained_dict, model_dict):
    """Return the entries of *pretrained_dict* that fit *model_dict*.

    A checkpoint saved from a ``torch.nn.DataParallel``-wrapped model has
    every key prefixed with ``'module.'``; that prefix is stripped when the
    prefixed name is unknown to the model. Entries whose (possibly stripped)
    name is not in *model_dict*, or whose tensor size differs, are dropped.
    """
    prefix = 'module.'
    matched = {}
    for key, value in pretrained_dict.items():
        if key not in model_dict and key.startswith(prefix):
            key = key[len(prefix):]
        if key in model_dict and value.size() == model_dict[key].size():
            matched[key] = value
    return matched


# The options file path is the first command-line argument.
if len(sys.argv) < 2:
    sys.exit("usage: pass the path to an options .toml file as the first argument")

print("Loading options...")
# TOML is UTF-8 by specification; don't depend on the locale encoding.
with open(sys.argv[1], 'r', encoding='utf-8') as optionsFile:
    options = toml.loads(optionsFile.read())

# Enable cudnn autotuning only when both flags are set.
if (options["general"]["usecudnnbenchmark"]
        and options["general"]["usecudnn"]):
    print("Running cudnn benchmark...")
    torch.backends.cudnn.benchmark = True

# os.environ only accepts str values; str() also tolerates `gpuid = 0`.
os.environ['CUDA_VISIBLE_DEVICES'] = str(options["general"]['gpuid'])

torch.manual_seed(options["general"]['random_seed'])

# Create the model.
model = Dense3D(options)

if options["general"]["loadpretrainedmodel"]:
    # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine;
    # the model has not been moved to a GPU above, and load_state_dict copies
    # the values into the model's own parameters anyway.
    checkpoint = torch.load(options["general"]["pretrainedmodelpath"],
                            map_location='cpu')
    # Load only weights that exist in this model with the same shape,
    # stripping the DataParallel 'module.' prefix where needed.
    model_dict = model.state_dict()
    pretrained_dict = _match_pretrained(checkpoint, model_dict)
    print('matched keys:', len(pretrained_dict))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)