def get_network(args):
    """Return the network selected by ``args.net``.

    The factory for the requested architecture is imported lazily from the
    project's ``models`` package, so only the chosen model's module is loaded.

    Args:
        args: Object exposing ``net`` (architecture name; must be one of the
            keys of the table below) and ``gpu`` (truthy to move the network
            to CUDA with ``net.cuda()``).

    Returns:
        The network produced by calling the selected factory with no
        arguments, moved to the GPU when ``args.gpu`` is truthy.

    Raises:
        SystemExit: If ``args.net`` is not a supported name. The message is
            printed and the process exits with status 1 so calling scripts
            can detect the failure.
    """
    import importlib
    import sys

    # architecture name -> (module inside ``models``, factory function name)
    factories = {
        'vgg16': ('vgg', 'vgg16_bn'),
        'vgg13': ('vgg', 'vgg13_bn'),
        'vgg11': ('vgg', 'vgg11_bn'),
        'vgg19': ('vgg', 'vgg19_bn'),
        'densenet121': ('densenet', 'densenet121'),
        'densenet161': ('densenet', 'densenet161'),
        'densenet169': ('densenet', 'densenet169'),
        'densenet201': ('densenet', 'densenet201'),
        'googlenet': ('googlenet', 'googlenet'),
        'inceptionv3': ('inceptionv3', 'inceptionv3'),
        'inceptionv4': ('inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
        'xception': ('xception', 'xception'),
        'resnet18': ('resnet', 'resnet18'),
        'resnet34': ('resnet', 'resnet34'),
        'resnet50': ('resnet', 'resnet50'),
        'resnet101': ('resnet', 'resnet101'),
        'resnet152': ('resnet', 'resnet152'),
        'preactresnet18': ('preactresnet', 'preactresnet18'),
        'preactresnet34': ('preactresnet', 'preactresnet34'),
        'preactresnet50': ('preactresnet', 'preactresnet50'),
        'preactresnet101': ('preactresnet', 'preactresnet101'),
        'preactresnet152': ('preactresnet', 'preactresnet152'),
        'resnext50': ('resnext', 'resnext50'),
        'resnext101': ('resnext', 'resnext101'),
        'resnext152': ('resnext', 'resnext152'),
        'shufflenet': ('shufflenet', 'shufflenet'),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('squeezenet', 'squeezenet'),
        'mobilenet': ('mobilenet', 'mobilenet'),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
        'nasnet': ('nasnet', 'nasnet'),
        'attention56': ('attention', 'attention56'),
        'attention92': ('attention', 'attention92'),
        'seresnet18': ('senet', 'seresnet18'),
        'seresnet34': ('senet', 'seresnet34'),
        'seresnet50': ('senet', 'seresnet50'),
        'seresnet101': ('senet', 'seresnet101'),
        'seresnet152': ('senet', 'seresnet152'),
    }

    try:
        module_name, factory_name = factories[args.net]
    except (KeyError, TypeError):  # unknown (or unhashable) name
        print('the network name you have entered is not supported yet')
        # Non-zero status: an unsupported name is an error, not a success.
        sys.exit(1)

    module = importlib.import_module('models.' + module_name)
    net = getattr(module, factory_name)()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
def generate_model(opt):
    """Build the network selected by ``opt`` and choose its trainable parameters.

    Only ``'resnetl'`` (depth 10) and ``'resnext'`` (depth 50/101/152) have a
    builder here; the remaining names accepted by the first assertion raise
    ``NotImplementedError`` instead of failing later on an unbound ``model``.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``resnext_cardinality`` (resnext only), ``no_cuda``, ``gpus``,
            ``local_rank``, ``pretrain_path``, ``pretrain_modality``,
            ``modality``, ``n_finetune_classes`` and ``ft_portion``.
            ``opt.gpus`` is overwritten with ``opt.local_rank`` when CUDA is
            used and ``opt.gpus`` is not ``'0'``.

    Returns:
        A ``(model, parameters)`` tuple, where ``parameters`` is whatever the
        architecture's ``get_fine_tuning_parameters(model, opt.ft_portion)``
        returns.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not in the
            accepted lists.
        NotImplementedError: If ``opt.model`` is accepted but has no builder.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'resnetl':
        assert opt.model_depth in [10]  # only depth 10 is available
        # Used for transfer learning (fine-tuning parameter selection).
        from models.resnetl import get_fine_tuning_parameters

        build = getattr(resnetl, 'resnetl{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,  # number of classes
            shortcut_type=opt.resnet_shortcut,  # original note: default 'B'
            sample_size=opt.sample_size,  # original note: default 112
            # original note: default 16, the number of input frames
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnext{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    else:
        # The name passed the assertion above but nothing constructs it;
        # fail here with a clear message rather than on an unbound ``model``.
        raise NotImplementedError(
            "model '{}' has no builder in generate_model".format(opt.model))

    use_cuda = not opt.no_cuda
    if use_cuda:
        if opt.gpus != '0':
            # Select the device matching this process's local rank.
            opt.gpus = opt.local_rank
            torch.cuda.set_device(opt.gpus)
        model = model.cuda()
        # nn.DataParallel wrapping is deliberately not applied here (the
        # original author noted it would probably not be used).

    # Count only the parameters that receive gradients.
    pytorch_total_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print("Total number of trainable parameters: ", pytorch_total_params)

    if opt.pretrain_path:  # transfer learning
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path,
                              map_location=torch.device('cpu'))
        # Shape the input kernels like the checkpoint before loading it.
        model = modify_kernels(opt, model, opt.pretrain_modality)
        model.load_state_dict(pretrain['state_dict'])

        # The head lives on ``model.module`` only when a DataParallel-style
        # wrapper is present; since wrapping is disabled above, fall back to
        # the model itself instead of raising AttributeError.
        core = getattr(model, 'module', model)
        # NOTE(review): the first two branches are reachable only once
        # builders for those architectures are added above.
        if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet',
                         'shufflenetv2']:
            core.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(core.classifier[1].in_features,
                          opt.n_finetune_classes))
            if use_cuda:
                core.classifier = core.classifier.cuda()
        elif opt.model == 'squeezenet':
            core.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv3d(core.classifier[1].in_channels,
                          opt.n_finetune_classes, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.AvgPool3d((1, 4, 4), stride=1))
            if use_cuda:
                core.classifier = core.classifier.cuda()
        else:
            core.fc = nn.Linear(core.fc.in_features, opt.n_finetune_classes)
            if use_cuda:
                core.fc = core.fc.cuda()

    # Adapt the input kernels to the training modality exactly once, whether
    # or not a checkpoint was loaded.
    model = modify_kernels(opt, model, opt.modality)

    # Original author's note: this selection only matters for transfer
    # learning; otherwise the parameters are expected to come back unchanged.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
def get_network(args):
    """Return the network selected by ``args.net``, sized for ``args.task``.

    The factory for the requested architecture is imported lazily from the
    project's ``models`` package. Factories flagged ``True`` in the table
    below are called with ``num_classes`` set from the task (10 for
    ``'cifar10'``, 100 for ``'cifar100'``); the others are called without
    arguments, exactly as before.
    NOTE(review): the factories called without arguments keep whatever class
    count they default to — confirm that suits both tasks.

    Args:
        args: Object exposing ``task``, ``net`` (architecture name; must be a
            key of the table below) and ``gpu`` (truthy to move the network
            to CUDA with ``net.cuda()``).

    Returns:
        The constructed network, moved to the GPU when ``args.gpu`` is truthy.

    Raises:
        SystemExit: If ``args.net`` is not a supported name. The message is
            printed and the process exits with status 1.
        ValueError: If the selected factory needs a class count and
            ``args.task`` is neither ``'cifar10'`` nor ``'cifar100'``
            (previously this surfaced as an unbound-local error).
    """
    import importlib
    import sys

    task = args.task
    class_counts = {'cifar10': 10, 'cifar100': 100}

    # name -> (module inside ``models``, factory name, takes ``num_classes``)
    factories = {
        # VGG variants without batch norm (added by Yang)
        'vgg16': ('vgg', 'vgg16', True),
        'vgg13': ('vgg', 'vgg13', True),
        'vgg11': ('vgg', 'vgg11', True),
        'vgg19': ('vgg', 'vgg19', True),
        'vgg16bn': ('vgg', 'vgg16_bn', True),
        'vgg13bn': ('vgg', 'vgg13_bn', True),
        'vgg11bn': ('vgg', 'vgg11_bn', True),
        'vgg19bn': ('vgg', 'vgg19_bn', True),
        'densenet121': ('densenet', 'densenet121', False),
        'densenet161': ('densenet', 'densenet161', False),
        'densenet169': ('densenet', 'densenet169', False),
        'densenet201': ('densenet', 'densenet201', False),
        'googlenet': ('googlenet', 'googlenet', True),
        'inceptionv3': ('inceptionv3', 'inceptionv3', False),
        'inceptionv4': ('inceptionv4', 'inceptionv4', False),
        'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2', False),
        'xception': ('xception', 'xception', True),
        'scnet': ('sphereconvnet', 'sphereconvnet', True),
        'sphereresnet18': ('sphereconvnet', 'resnet18', True),
        'sphereresnet32': ('sphereconvnet', 'sphereresnet32', True),
        'plainresnet32': ('sphereconvnet', 'plainresnet32', True),
        'ynet18': ('ynet', 'resnet18', True),
        'ynet34': ('ynet', 'resnet34', True),
        'ynet50': ('ynet', 'resnet50', True),
        'ynet101': ('ynet', 'resnet101', True),
        'ynet152': ('ynet', 'resnet152', True),
        'resnet18': ('resnet', 'resnet18', True),
        'resnet34': ('resnet', 'resnet34', True),
        'resnet50': ('resnet', 'resnet50', True),
        'resnet101': ('resnet', 'resnet101', True),
        'resnet152': ('resnet', 'resnet152', True),
        'preactresnet18': ('preactresnet', 'preactresnet18', True),
        'preactresnet34': ('preactresnet', 'preactresnet34', True),
        'preactresnet50': ('preactresnet', 'preactresnet50', True),
        'preactresnet101': ('preactresnet', 'preactresnet101', True),
        'preactresnet152': ('preactresnet', 'preactresnet152', True),
        'resnext50': ('resnext', 'resnext50', True),
        'resnext101': ('resnext', 'resnext101', True),
        'resnext152': ('resnext', 'resnext152', True),
        'shufflenet': ('shufflenet', 'shufflenet', False),
        'shufflenetv2': ('shufflenetv2', 'shufflenetv2', False),
        'squeezenet': ('squeezenet', 'squeezenet', False),
        'mobilenet': ('mobilenet', 'mobilenet', True),
        'mobilenetv2': ('mobilenetv2', 'mobilenetv2', True),
        'nasnet': ('nasnet', 'nasnet', True),
        'attention56': ('attention', 'attention56', False),
        'attention92': ('attention', 'attention92', False),
        'seresnet18': ('senet', 'seresnet18', True),
        'seresnet34': ('senet', 'seresnet34', True),
        'seresnet50': ('senet', 'seresnet50', True),
        'seresnet101': ('senet', 'seresnet101', True),
        'seresnet152': ('senet', 'seresnet152', True),
    }

    try:
        module_name, factory_name, takes_num_classes = factories[args.net]
    except (KeyError, TypeError):  # unknown (or unhashable) name
        print('the network name you have entered is not supported yet')
        # Non-zero status: an unsupported name is an error, not a success.
        sys.exit(1)

    factory_kwargs = {}
    if takes_num_classes:
        # Validate before importing so a bad task fails fast and clearly.
        if task not in class_counts:
            raise ValueError(
                "unsupported task {!r}: expected 'cifar10' or 'cifar100'"
                .format(task))
        factory_kwargs['num_classes'] = class_counts[task]

    module = importlib.import_module('models.' + module_name)
    net = getattr(module, factory_name)(**factory_kwargs)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    Args:
        opt: Options object. Reads ``model`` (``'xcresnet'``, ``'resnet'``,
            ``'resnext'`` or ``'i6f_resnet'``), ``model_depth``,
            ``n_classes``, ``sample_duration``, ``resnet_shortcut`` and
            ``sample_size`` (all but xcresnet), ``resnext_cardinality``
            (resnext only), ``no_cuda``, ``pretrain_path``, and, when a
            checkpoint is given, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple.
        * With ``opt.pretrain_path``: the checkpoint is loaded, ``model.fc``
          is replaced by a new ``nn.Linear`` with ``opt.n_finetune_classes``
          outputs, and ``parameters`` comes from the architecture's
          ``get_fine_tuning_parameters(model, opt.ft_begin_index)``. The
          model is wrapped in ``nn.DataParallel`` only when CUDA is used.
        * Without it: the model is wrapped in ``nn.DataParallel`` and
          ``parameters`` is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not in the
            accepted lists, or the checkpoint's ``'arch'`` differs from
            ``opt.arch``.
        NotImplementedError: If fine-tuning is requested for
            ``'i6f_resnet'``, for which no ``get_fine_tuning_parameters`` is
            imported.
    """
    assert opt.model in ['xcresnet', 'resnet', 'resnext', 'i6f_resnet']

    # Bound by whichever architecture branch provides a fine-tuning helper.
    get_fine_tuning_parameters = None

    if opt.model == 'xcresnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152]
        from models.x_channel_resnet import get_fine_tuning_parameters

        build = getattr(x_channel_resnet,
                        'xcresnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      image_nums=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnext{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'i6f_resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        # No fine-tuning helper is imported for this architecture (the
        # original import from models.resnet was commented out).
        build = getattr(i6f_resnet,
                        'i6f_resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()

    if opt.pretrain_path:
        if get_fine_tuning_parameters is None:
            # Fail before loading anything, with a message that names the
            # real problem instead of an unbound-local error at the end.
            raise NotImplementedError(
                "fine-tuning is not available for model '{}'"
                .format(opt.model))

        print('loading pretrained model {}'.format(opt.pretrain_path))
        # On the CPU path, remap storages so checkpoints saved from a GPU
        # can still be loaded; on the CUDA path keep torch.load's default.
        pretrain = torch.load(opt.pretrain_path,
                              map_location=None if use_cuda else 'cpu')
        assert opt.arch == pretrain['arch']
        model.load_state_dict(pretrain['state_dict'])

        # The model is not wrapped yet, so the head is ``model.fc`` on both
        # the CUDA and the CPU path.
        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)
        if use_cuda:
            model.fc = model.fc.cuda()

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        if use_cuda:
            # Wrap only after the head swap so ``model.fc`` stays reachable.
            # The CPU fine-tuning path returns the bare model, as before.
            model = nn.DataParallel(model, device_ids=None)
        return model, parameters

    model = nn.DataParallel(model, device_ids=None)
    return model, model.parameters()