示例#1
0
def get_model(model):
    """Return the policy network (``agent``) for a named ResNet configuration.

    The matching flat ResNet backbone is also constructed, but it is not
    returned and no pretrained weights are loaded into it (that step is
    currently disabled, see the note at the end of the body).

    Args:
        model: One of 'R32_C10', 'R110_C10', 'R32_C100', 'R110_C100',
            'R101_ImgNet'.

    Returns:
        A ``resnet.Policy32`` (10/100-class configs) or ``resnet.Policy224``
        ('R101_ImgNet') whose ``num_blocks`` equals the total number of
        blocks in the backbone's ``layer_config``.

    Raises:
        ValueError: if ``model`` is not a supported name (previously this
            surfaced as an UnboundLocalError on ``return agent``).
    """
    # name -> (pretrained checkpoint path, layer_config, num_classes)
    configs = {
        'R32_C10': ('cv/pretrained/R32_C10/pk_E_164_A_0.923.t7',
                    [5, 5, 5], 10),
        'R110_C10': ('cv/pretrained/R110_C10/pk_E_130_A_0.932.t7',
                     [18, 18, 18], 10),
        'R32_C100': ('cv/pretrained/R32_C100/pk_E_164_A_0.693.t7',
                     [5, 5, 5], 100),
        'R110_C100': ('cv/pretrained/R110_C100/pk_E_160_A_0.723.t7',
                      [18, 18, 18], 100),
        'R101_ImgNet': ('cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464',
                        [3, 4, 23, 3], 1000),
    }
    # Validate before importing project modules so a bad name fails fast
    # with a clear message.
    if model not in configs:
        raise ValueError('Unknown model {!r}; expected one of {}'.format(
            model, ', '.join(sorted(configs))))
    rnet_checkpoint, layer_config, num_classes = configs[model]
    # Each policy is sized to the backbone's total block count
    # (15, 54, 15, 54 and 33 for the configs above).
    num_blocks = sum(layer_config)

    from models import resnet, base

    if model == 'R101_ImgNet':
        rnet = resnet.FlatResNet224(base.Bottleneck,
                                    layer_config,
                                    num_classes=num_classes)
        agent = resnet.Policy224([1, 1, 1, 1], num_blocks=num_blocks)
    else:
        rnet = resnet.FlatResNet32(base.BasicBlock,
                                   layer_config,
                                   num_classes=num_classes)
        agent = resnet.Policy32([1, 1, 1], num_blocks=num_blocks)

    # NOTE(review): loading pretrained weights
    # (`load_weights_to_flatresnet(torch.load(rnet_checkpoint), rnet)`) is
    # disabled, so `rnet` and `rnet_checkpoint` are unused here. `rnet` is
    # still built before `agent`, as in the original, in case its
    # construction (e.g. RNG use during weight init) affects the agent's
    # initialisation -- confirm whether it can be dropped.
    return agent
示例#2
0
def get_model(model):
    """Build the high-/low-resolution classifiers and the policy network.

    Args:
        model: One of 'policy_satellite', 'R32_C10', 'R32_C100',
            'R50_ImgNet', 'R34_fMoW', 'R34_CARS'.

    Returns:
        Tuple ``(rnet_hr, rnet_lr, agent)``: the high-res classifier, the
        low-res classifier and the policy. For 'policy_satellite' both
        classifiers are None and ``agent`` is a ``resnet.Policy2x2``.

    Raises:
        ValueError: if ``model`` is not a supported name (previously this
            surfaced as an UnboundLocalError on the return statement).
    """
    # name -> num_classes for the FlatResNet32 configurations
    flat_resnet_classes = {'R32_C10': 10, 'R32_C100': 100}
    # name -> (torchvision constructor name, `pretrained` flag, num_classes)
    torchvision_specs = {
        'R50_ImgNet': ('resnet50', False, 1000),
        'R34_fMoW': ('resnet34', True, 62),
        'R34_CARS': ('resnet34', True, 196),
    }
    supported = (['policy_satellite'] + list(flat_resnet_classes)
                 + list(torchvision_specs))
    # Validate before importing project modules so a bad name fails fast
    # with a clear message.
    if model not in supported:
        raise ValueError('Unknown model {!r}; expected one of {}'.format(
            model, ', '.join(supported)))

    from models import resnet, base

    if model == 'policy_satellite':
        agent = resnet.Policy2x2([1, 1, 1, 1], num_blocks=4)
        rnet_hr = None
        rnet_lr = None

    elif model in flat_resnet_classes:
        num_classes = flat_resnet_classes[model]
        layer_config = [5, 5, 5]
        rnet_hr = resnet.FlatResNet32(base.BasicBlock,
                                      layer_config,
                                      num_classes=num_classes)
        rnet_lr = resnet.FlatResNet32(base.BasicBlock,
                                      layer_config,
                                      num_classes=num_classes)
        agent = resnet.Policy32([1, 1, 1], num_blocks=16)

    else:
        arch, pretrained, num_classes = torchvision_specs[model]
        agent = resnet.Policy224([1, 1, 1, 1], num_blocks=16)
        # High-res classifier, then low-res classifier (same recipe).
        rnet_hr = _torchvision_classifier(arch, pretrained, num_classes)
        rnet_lr = _torchvision_classifier(arch, pretrained, num_classes)

    return rnet_hr, rnet_lr, agent


def _torchvision_classifier(arch, pretrained, num_classes):
    """Build torchvision's ``arch`` ResNet with a fresh ``num_classes``-way fc head.

    Relies on the module-level names ``torchmodels``, ``torch`` and
    ``set_parameter_requires_grad`` (defined outside this function).
    """
    net = getattr(torchmodels, arch)(pretrained=pretrained)
    set_parameter_requires_grad(net, False)
    # Replace the final layer so the output width matches the dataset.
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    return net
示例#3
0
def get_agent(blocks):
    """Return a new ``resnet.Policy32`` policy created with ``num_blocks=blocks``.

    The policy's layer configuration argument is fixed at ``[1, 1, 1]``.
    """
    # `base` is unused here but still imported, exactly as before, so the
    # set of modules loaded by this call is unchanged.
    from models import resnet, base  # noqa: F401
    policy_layers = [1, 1, 1]
    return resnet.Policy32(policy_layers, num_blocks=blocks)