def get_model(model):
    """Build the policy network ("agent") for one of the flat-ResNet setups.

    NOTE: a second ``def get_model`` appears later in this module; since it is
    defined later it rebinds the name, so callers of ``get_model`` receive that
    version rather than this one.

    Args:
        model: one of 'R32_C10', 'R110_C10', 'R32_C100', 'R110_C100',
            'R101_ImgNet'.

    Returns:
        The agent. Its ``num_blocks`` equals ``sum(layer_config)`` for the
        chosen setup (15, 54 or 33).

    Raises:
        ValueError: if ``model`` is not one of the names above.
    """
    # name -> (checkpoint path, layer_config, num_classes)
    specs = {
        'R32_C10': ('cv/pretrained/R32_C10/pk_E_164_A_0.923.t7', [5, 5, 5], 10),
        'R110_C10': ('cv/pretrained/R110_C10/pk_E_130_A_0.932.t7', [18, 18, 18], 10),
        'R32_C100': ('cv/pretrained/R32_C100/pk_E_164_A_0.693.t7', [5, 5, 5], 100),
        'R110_C100': ('cv/pretrained/R110_C100/pk_E_160_A_0.723.t7', [18, 18, 18], 100),
        'R101_ImgNet': ('cv/pretrained/R101_ImgNet/ImageNet_R101_224_76.464',
                        [3, 4, 23, 3], 1000),
    }
    supported = tuple(specs)
    # Validate before importing the project models so a typo fails fast with a
    # clear message instead of an UnboundLocalError at ``return agent``.
    if model not in supported:
        raise ValueError('unknown model {!r}; expected one of: {}'.format(
            model, ', '.join(supported)))
    rnet_checkpoint, layer_config, num_classes = specs[model]

    from models import resnet, base

    if model == 'R101_ImgNet':
        rnet = resnet.FlatResNet224(base.Bottleneck, layer_config, num_classes=num_classes)
        agent = resnet.Policy224([1, 1, 1, 1], num_blocks=sum(layer_config))
    else:
        rnet = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=num_classes)
        agent = resnet.Policy32([1, 1, 1], num_blocks=sum(layer_config))

    # NOTE(review): ``rnet`` and ``rnet_checkpoint`` are currently unused because
    # weight loading (below) is disabled. ``rnet`` is still built before
    # ``agent`` so construction order -- and any RNG draws during parameter
    # init -- stays the same.
    # Disabled variants (only an agent was defined for these):
    #   'ResNext_C100' / 'ResNext_C10': resnet.Policy32([1,1,1], num_blocks=4*18)
    # load pretrained weights into flat ResNet
    # rnet_checkpoint = torch.load(rnet_checkpoint)
    # load_weights_to_flatresnet(rnet_checkpoint, rnet)
    return agent
def _build_resnet_classifier(arch, num_classes, pretrained):
    """Build ``torchmodels.<arch>`` with its ``fc`` head resized to ``num_classes``.

    Args:
        arch: attribute name of a ResNet constructor on ``torchmodels``
            (presumably ``torchvision.models``), e.g. 'resnet34'.
        num_classes: output width of the replacement ``fc`` layer.
        pretrained: forwarded to the constructor as ``pretrained=``.

    Returns:
        The network with ``fc`` replaced by a new ``torch.nn.Linear`` whose
        input width is the old head's ``in_features``.
    """
    net = getattr(torchmodels, arch)(pretrained=pretrained)
    # What the False flag does depends on set_parameter_requires_grad, which is
    # defined elsewhere in this file -- presumably it leaves all parameters
    # trainable when passed False; confirm against its definition.
    set_parameter_requires_grad(net, False)
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    return net


def get_model(model):
    """Build the high-res classifier, low-res classifier and policy agent.

    Args:
        model: one of 'policy_satellite', 'R32_C10', 'R32_C100',
            'R50_ImgNet', 'R34_fMoW', 'R34_CARS'.

    Returns:
        Tuple ``(rnet_hr, rnet_lr, agent)``. Both classifiers are None for
        'policy_satellite'; otherwise they are two separately constructed
        networks with identical architecture.

    Raises:
        ValueError: if ``model`` is not one of the names above.
    """
    cifar_classes = {'R32_C10': 10, 'R32_C100': 100}
    # name -> (torchmodels constructor name, pretrained flag, num_classes)
    backbone_specs = {
        'R50_ImgNet': ('resnet50', False, 1000),
        'R34_fMoW': ('resnet34', True, 62),
        'R34_CARS': ('resnet34', True, 196),
    }
    supported = ('policy_satellite',) + tuple(cifar_classes) + tuple(backbone_specs)
    # Validate before importing anything so an unknown name fails with a clear
    # message rather than an UnboundLocalError at the final return.
    if model not in supported:
        raise ValueError('unknown model {!r}; expected one of: {}'.format(
            model, ', '.join(supported)))

    from models import resnet, base

    if model == 'policy_satellite':
        # agent = resnet.Policy224([1,1,1,1], num_blocks=289, num_feat=20)
        # agent = resnet.Policy224GRU([1,1,1,1], num_blocks=289, num_feat=128)
        # agent = resnet.PolicySeq()
        agent = resnet.Policy2x2([1, 1, 1, 1], num_blocks=4)
        rnet_hr = None
        rnet_lr = None
    elif model in cifar_classes:
        num_classes = cifar_classes[model]
        layer_config = [5, 5, 5]
        rnet_hr = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=num_classes)
        rnet_lr = resnet.FlatResNet32(base.BasicBlock, layer_config, num_classes=num_classes)
        agent = resnet.Policy32([1, 1, 1], num_blocks=16)
    else:
        arch, pretrained, num_classes = backbone_specs[model]
        agent = resnet.Policy224([1, 1, 1, 1], num_blocks=16)
        # High-res classifier, then low-res classifier (same order as before).
        rnet_hr = _build_resnet_classifier(arch, num_classes, pretrained)
        rnet_lr = _build_resnet_classifier(arch, num_classes, pretrained)
    return rnet_hr, rnet_lr, agent
def get_agent(blocks):
    """Construct a ``resnet.Policy32`` agent with ``num_blocks=blocks``.

    Args:
        blocks: forwarded unchanged as ``num_blocks``.

    Returns:
        A new ``resnet.Policy32([1, 1, 1], num_blocks=blocks)`` instance.
    """
    # ``base`` is not used here but is still imported, exactly as before.
    from models import resnet, base
    stage_layout = [1, 1, 1]
    return resnet.Policy32(stage_layout, num_blocks=blocks)