def vgg13_bn(num_classes=1000, pretrained='imagenet'):
    """VGG 13-layer model (configuration "B") with batch normalization
    """
    model = models.vgg13_bn(pretrained=False)
    if pretrained is not None:
        settings = pretrained_settings['vgg13_bn'][pretrained]
        model = load_pretrained(model, num_classes, settings)
    return model
criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_tensorboard = "vgg13"
model_ft = train_model(model_ft,
                       criterion,
                       optimizer_ft,
                       exp_lr_scheduler,
                       num_epochs=25)

model_ft = models.vgg13_bn(pretrained=True)

model_ft.classifier[6] = nn.Linear(4096, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_tensorboard = "vgg13_bn"
model_ft = train_model(model_ft,
                       criterion,
                       optimizer_ft,
                       exp_lr_scheduler,
                       num_epochs=25)
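# A short evaluation sketch one could append here: once train_model returns, check
# held-out accuracy as below. `val_loader`, `device`, and the torch import are
# assumed to exist in the surrounding script.
model_ft.eval()
correct = total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model_ft(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print('validation accuracy: {:.3f}'.format(correct / total))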
Example #3
    if not os.path.exists(os.path.join(data_dir, test_dir, 'unknown')):
        os.mkdir(os.path.join(data_dir, test_dir, 'unknown'))
        for root, dirs, files in os.walk(os.path.join(data_dir, test_dir),
                                         topdown=False):
            for file in files:
                shutil.move(os.path.join(root, file),
                            os.path.join(root, 'unknown', file))

    pretrained = True

    alexnet = models.alexnet(pretrained=pretrained)

    vgg11 = models.vgg11(pretrained=pretrained)
    vgg11_bn = models.vgg11_bn(pretrained=pretrained)
    vgg13 = models.vgg13(pretrained=pretrained)
    vgg13_bn = models.vgg13_bn(pretrained=pretrained)
    vgg16 = models.vgg16(pretrained=pretrained)
    vgg16_bn = models.vgg16_bn(pretrained=pretrained)
    vgg19 = models.vgg19(pretrained=pretrained)
    vgg19_bn = models.vgg19_bn(pretrained=pretrained)

    resnet18 = models.resnet18(pretrained=pretrained)
    resnet34 = models.resnet34(pretrained=pretrained)
    resnet50 = models.resnet50(pretrained=pretrained)
    resnet101 = models.resnet101(pretrained=pretrained)
    resnet152 = models.resnet152(pretrained=pretrained)

    squeezenet1_0 = models.squeezenet1_0(pretrained=pretrained)
    squeezenet1_1 = models.squeezenet1_1(pretrained=pretrained)

    densenet121 = models.densenet121(pretrained=pretrained)
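    # A small sanity check that could follow the instantiations above (a sketch):
    # collect a few of the instances in a dict keyed by name so a backbone can be
    # picked programmatically, and print their parameter counts.
    backbones = {
        'alexnet': alexnet,
        'vgg13_bn': vgg13_bn,
        'resnet50': resnet50,
        'densenet121': densenet121,
    }
    for name, net in backbones.items():
        print(name, sum(p.numel() for p in net.parameters()))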
Example #4
    def build(self):
        # Transform for input images.
        input_size = 299 if self.predictor_name == 'inception_v3' else 224
        self.transform = T.Compose([
            T.ToPILImage(),
            T.Resize((input_size, input_size)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        if self.predictor_name == 'alexnet':
            model = M.alexnet(pretrained=True)
        elif self.predictor_name == 'vgg11':
            model = M.vgg11(pretrained=True)
        elif self.predictor_name == 'vgg13':
            model = M.vgg13(pretrained=True)
        elif self.predictor_name == 'vgg16':
            model = M.vgg16(pretrained=True)
        elif self.predictor_name == 'vgg19':
            model = M.vgg19(pretrained=True)
        elif self.predictor_name == 'vgg11_bn':
            model = M.vgg11_bn(pretrained=True)
        elif self.predictor_name == 'vgg13_bn':
            model = M.vgg13_bn(pretrained=True)
        elif self.predictor_name == 'vgg16_bn':
            model = M.vgg16_bn(pretrained=True)
        elif self.predictor_name == 'vgg19_bn':
            model = M.vgg19_bn(pretrained=True)
        elif self.predictor_name == 'googlenet':
            model = M.googlenet(pretrained=True, aux_logits=False)
        elif self.predictor_name == 'inception_v3':
            model = M.inception_v3(pretrained=True, aux_logits=False)
        elif self.predictor_name == 'resnet18':
            model = M.resnet18(pretrained=True)
        elif self.predictor_name == 'resnet34':
            model = M.resnet34(pretrained=True)
        elif self.predictor_name == 'resnet50':
            model = M.resnet50(pretrained=True)
        elif self.predictor_name == 'resnet101':
            model = M.resnet101(pretrained=True)
        elif self.predictor_name == 'resnet152':
            model = M.resnet152(pretrained=True)
        elif self.predictor_name == 'resnext50':
            model = M.resnext50_32x4d(pretrained=True)
        elif self.predictor_name == 'resnext101':
            model = M.resnext101_32x8d(pretrained=True)
        elif self.predictor_name == 'wideresnet50':
            model = M.wide_resnet50_2(pretrained=True)
        elif self.predictor_name == 'wideresnet101':
            model = M.wide_resnet101_2(pretrained=True)
        elif self.predictor_name == 'densenet121':
            model = M.densenet121(pretrained=True)
        elif self.predictor_name == 'densenet169':
            model = M.densenet169(pretrained=True)
        elif self.predictor_name == 'densenet201':
            model = M.densenet201(pretrained=True)
        elif self.predictor_name == 'densenet161':
            model = M.densenet161(pretrained=True)
        else:
            raise NotImplementedError(f'Unsupported architecture '
                                      f'`{self.predictor_name}`!')

        model.eval()

        if self.imagenet_logits:
            self.net = model
            self.feature_dim = (1000, )
            return

        if self.architecture_type == 'AlexNet':
            layers = list(model.features.children())
            if not self.spatial_feature:
                layers.append(nn.Flatten())
                self.feature_dim = (256 * 6 * 6, )
            else:
                self.feature_dim = (256, 6, 6)
        elif self.architecture_type == 'VGG':
            layers = list(model.features.children())
            if not self.spatial_feature:
                layers.append(nn.Flatten())
                self.feature_dim = (512 * 7 * 7, )
            else:
                self.feature_dim = (512, 7, 7)
        elif self.architecture_type == 'Inception':
            if self.predictor_name == 'googlenet':
                final_res = 7
                num_channels = 1024
                layers = list(model.children())[:-3]
            elif self.predictor_name == 'inception_v3':
                final_res = 8
                num_channels = 2048
                layers = list(model.children())[:-1]
                layers.insert(3, nn.MaxPool2d(kernel_size=3, stride=2))
                layers.insert(6, nn.MaxPool2d(kernel_size=3, stride=2))
            else:
                raise NotImplementedError(
                    f'Unsupported Inception architecture '
                    f'`{self.predictor_name}`!')
            if not self.spatial_feature:
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
                layers.append(nn.Flatten())
                self.feature_dim = (num_channels, )
            else:
                self.feature_dim = (num_channels, final_res, final_res)
        elif self.architecture_type == 'ResNet':
            if self.predictor_name in ['resnet18', 'resnet34']:
                num_channels = 512
            elif self.predictor_name in [
                    'resnet50', 'resnet101', 'resnet152', 'resnext50',
                    'resnext101', 'wideresnet50', 'wideresnet101'
            ]:
                num_channels = 2048
            else:
                raise NotImplementedError(f'Unsupported ResNet architecture '
                                          f'`{self.predictor_name}`!')
            if not self.spatial_feature:
                layers = list(model.children())[:-1]
                layers.append(nn.Flatten())
                self.feature_dim = (num_channels, )
            else:
                layers = list(model.children())[:-2]
                self.feature_dim = (num_channels, 7, 7)
        elif self.architecture_type == 'DenseNet':
            if self.predictor_name == 'densenet121':
                num_channels = 1024
            elif self.predictor_name == 'densenet169':
                num_channels = 1664
            elif self.predictor_name == 'densenet201':
                num_channels = 1920
            elif self.predictor_name == 'densenet161':
                num_channels = 2208
            else:
                raise NotImplementedError(f'Unsupported DenseNet architecture '
                                          f'`{self.predictor_name}`!')
            layers = list(model.features.children())
            if not self.spatial_feature:
                layers.append(nn.ReLU(inplace=True))
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
                layers.append(nn.Flatten())
                self.feature_dim = (num_channels, )
            else:
                self.feature_dim = (num_channels, 7, 7)
        else:
            raise NotImplementedError(f'Unsupported architecture type '
                                      f'`{self.architecture_type}`!')
        self.net = nn.Sequential(*layers)
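# A standalone sketch of what build() assembles for the ResNet branch (an
# illustration under the same torchvision assumptions, not the original class):
# drop the final fc layer, flatten the pooled features, and confirm the flattened
# dimension matches feature_dim.
import torch
import torch.nn as nn
from torchvision import models as M

resnet = M.resnet50(pretrained=True).eval()
feature_net = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())
with torch.no_grad():
    out = feature_net(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 2048])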
Example #5
def main():
    cudnn.benchmark = True
    global args, best_prec1
    best_error = np.Inf
    args = parser.parse_args()

    today = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M")


    dir_name = args.name + '_' + args.dataset + '_' + 'topcls-' + str(args.topn_class) + \
                '_' + args.model + '_' + 'arg-' + str(args.augment) + \
                '_wtclass-' + str(args.weight_class) + '_imgnetnorm-' + str(args.imagenet) + \
                '_wtdecay-' + str(args.weight_decay) + \
                '_drop-' + str(args.droprate) + '_lr' + str(args.lr) + \
                '_decay-' + str(args.decay_every) + '-' + str(args.lr_decay) + \
                '_' + today

    args_dict = vars(args)
    args_dict['time'] = today
    args_dict['dir_name'] = dir_name

    if args.test:
        dir_name = args.test
        with open('../runs/%s/argparse.json' % (args.trained_model), 'r') as f:
            args_dict = json.load(f)
        args.model = args_dict['model']

    if args.tensorboard and not args.test:
        configure("../runs/%s" % (dir_name))
        with open('../runs/%s/argparse.json' % (dir_name), 'w') as f:
            json.dump(args_dict, f)

    root_dir = args.datapath
    if args.dataset == 'rodent_256_scale':
        data_dir = 'png_mip_256_fit_2d'
    else:
        raise ValueError('Unknown dataset.')

    classes = np.arange(args.topn_class)  # [0,1,...,5]
    metadata = pd.read_pickle(
        '../data/rodent_3d_dendrites_br-ct-filter-3_all_mainclasses_use_filter.pkl'
    )
    metadata = metadata[metadata['label1_id'].isin(classes)]
    neuron_ids = metadata['neuron_id'].values
    labels = metadata[
        'label1_id'].values  # contain the same set of values as classes
    unique, counts = np.unique(labels, return_counts=True)

    if args.imagenet == True:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        normalize = transforms.Lambda(lambda x: x)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])


    train_ids, test_ids, train_y, test_y = \
        train_test_split(neuron_ids, labels, test_size=0.15, random_state=42, stratify=labels)

    train_ids, val_ids, train_y, val_y = \
        train_test_split(train_ids, train_y, test_size=0.15, random_state=42, stratify=train_y)

    kwargs = {'num_workers': 1, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(NeuroMorpho(
        root_dir,
        data_dir,
        train_ids,
        train_y,
        img_size=256,
        transform=transform_train,
        rgb=args.rgb),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(NeuroMorpho(
        root_dir,
        data_dir,
        val_ids,
        val_y,
        img_size=256,
        transform=transform_test,
        rgb=args.rgb),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(NeuroMorpho(
        root_dir,
        data_dir,
        test_ids,
        test_y,
        img_size=256,
        transform=transform_test,
        rgb=args.rgb),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)

    def vgg_classifier_8():
        classifier = nn.Sequential(
            nn.Linear(512 * 8 * 8, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, args.topn_class),
        )
        return classifier

    if args.model == 'vggplus1':
        model = VGGplus1(num_classes=args.topn_class)
    elif args.model == 'resnet18':
        model = models.resnet18(num_classes=args.topn_class, pretrained=False)
    elif args.model == 'resnet34':
        model = models.resnet34(num_classes=args.topn_class, pretrained=False)
    elif args.model == 'vgg13bn':
        model = models.vgg13_bn(num_classes=args.topn_class, pretrained=False)
    elif args.model == 'vgg16bn':
        model = models.vgg16_bn(num_classes=args.topn_class, pretrained=False)
    elif args.model == 'resnet18_pretrained_tuneall':
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(512 * 4,
                             args.topn_class)  # require_grad=True by default
    elif args.model == 'resnet34_pretrained_tuneall':
        model = models.resnet34(pretrained=True)
        model.fc = nn.Linear(512 * 4, args.topn_class)
    elif args.model == 'resnet18_pretrained_tunelast':
        model = models.resnet18(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        model.fc = nn.Linear(512 * 4,
                             args.topn_class)  # require_grad=True by default
    elif args.model == 'resnet34_pretrained_tunelast':
        model = models.resnet34(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        model.fc = nn.Linear(512 * 4, args.topn_class)
    elif args.model == 'vgg13bn_pretrained_tuneall':
        model = models.vgg13_bn(pretrained=True)
        model.classifier = None
        model.classifier = vgg_classifier_8()

    elif args.model == 'vgg13bn_pretrained_tunelast':
        # This actually does not work, since the input size for the classifier is different.
        model = models.vgg13_bn(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        mod = list(model.classifier.children())
        _ = mod.pop()
        mod.append(nn.Linear(4096, args.topn_class))
        model.classifier = torch.nn.Sequential(*mod)
    elif args.model == 'vgg13bn_pretrained_tuneclassifier':
        model = models.vgg13_bn(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        model.classifier = vgg_classifier_8()
    #elif args.model == 'wide_resnet':
    #    model = WideResNet(14, num_classes)
    else:
        raise ValueError('Unknown model type.')

    if args.test:
        model.load_state_dict(
            torch.load("../runs/%s/model_best.pth.tar" %
                       (args.trained_model))['state_dict'])

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if use_gpu:
        model = model.cuda()

    if "tunelast" in args.model:
        if 'vgg' in args.model:
            optimizer = torch.optim.Adam(list(
                model.classifier.children())[-1].parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        elif 'resnet' in args.model:
            optimizer = torch.optim.Adam(model.fc.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown model type for tuning the last layer.")
    elif "tuneclassifier" in args.model:
        if 'vgg' in args.model:
            optimizer = torch.optim.Adam(model.classifier.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown model type for tuning the classifier.")
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.weight_class is not None:
        if args.weight_class == 'linear':
            weight_dict = dict(
                zip(unique,
                    counts.max() / counts.astype('float')))
        elif args.weight_class == 'log2':
            weight_dict = dict(
                zip(unique,
                    np.log2(counts.max() / counts.astype('float') + 1)))
        else:
            raise ValueError("Unknown class weight method.")
        print("Class Weight: " +
              " ".join(['%d: %.2f' % (k, v) for k, v in weight_dict.items()]))
        weight_tensor = torch.FloatTensor([weight_dict[i] for i in classes])
        criterion = nn.CrossEntropyLoss(weight=weight_tensor).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    lr = args.lr

    if args.test:
        print("Testing model: " + args.trained_model)
        val_loss, val_acc_all, val_acc_average, val_acc_each, target_list, pred_list = \
                validate(test_loader, model, criterion, epoch=0, classes=classes)
        return  # skip training

    for epoch in range(args.epochs):
        print('\nEpoch: {0} '.format(epoch) + '\t lr: ' +
              '{0:.3g}\n'.format(lr))
        t_before_epoch = time.time()
        lr = adjust_learning_rate(optimizer, lr, epoch, args.lr_decay,
                                  args.decay_every)

        train_loss, train_acc_all, train_acc_average, train_acc_each = \
                train(train_loader, model, criterion, optimizer, epoch, classes)

        # evaluate on validation set
        val_loss, val_acc_all, val_acc_average, val_acc_each, target_list, pred_list = \
                validate(val_loader, model, criterion, epoch, classes)

        # remember best prec@1 and save checkpoint
        is_best = val_acc_average > best_prec1
        best_prec1 = max(val_acc_average, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, dir_name)
        if args.tensorboard:
            log_value('train_loss', train_loss, epoch)
            log_value('train_acc_all', train_acc_all, epoch)
            log_value('train_acc_average', train_acc_average, epoch)
            log_value('val_loss', val_loss, epoch)
            log_value('val_acc_all', val_acc_all, epoch)
            log_value('val_acc_average', val_acc_average, epoch)
            for k in classes:
                log_value('train_acc_%d' % k, train_acc_each[k], epoch)
                log_value('val_acc_%d' % k, val_acc_each[k], epoch)
    print('Best accuracy: ', best_prec1)
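# adjust_learning_rate is called above but not shown in this snippet; a minimal
# sketch consistent with its call signature (step decay by lr_decay every
# decay_every epochs, returning the current lr) might look like this. This is an
# assumption, not the original implementation.
def adjust_learning_rate(optimizer, lr, epoch, lr_decay, decay_every):
    if epoch > 0 and epoch % decay_every == 0:
        lr = lr * lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr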
Example #6
def vgg13_bn(pretrained=False):
    return models.vgg13_bn(pretrained=pretrained)
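# Hypothetical usage of the one-line wrapper above: swap the classifier head for a
# new class count before fine-tuning (assumes torch.nn is available as nn).
import torch.nn as nn

model = vgg13_bn(pretrained=True)
model.classifier[6] = nn.Linear(4096, 10)  # e.g. 10 target classes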
Example #7
    def __init__(self, layer_num=14):  # 14, 20
        super().__init__()
        self.features = vgg13_bn(pretrained=True).features[0: layer_num]
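# A quick shape check for the truncated feature stack above (a sketch; the enclosing
# class name `Backbone` is assumed here). Slicing features[0:14] keeps the first two
# conv blocks of vgg13_bn, so a 224x224 input comes out spatially downsampled.
import torch

backbone = Backbone(layer_num=14)
with torch.no_grad():
    feats = backbone.features(torch.randn(1, 3, 224, 224))
print(feats.shape)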
Example #8
def instantiate_model(dataset='cifar10',
                      num_classes=10,
                      input_quant='FP',
                      arch='resnet',
                      dorefa=False,
                      abit=32,
                      wbit=32,
                      qin=False,
                      qout=False,
                      suffix='',
                      load=False,
                      torch_weights=False,
                      device=None,
                      normalize=None):
    """Initializes/load network with random weight/saved and return auto generated model name 'dataset_arch_suffix.ckpt'
    
    Args:
        dataset         : mnists/cifar10/cifar100/imagenet/tinyimagenet/simple dataset the netwoek is trained on. Used in model name 
        num_classes     : number of classes in dataset. 
        arch            : resnet/vgg/lenet5/basicnet/slpconv model architecture the network to be instantiated with 
        suffix          : str appended to the model name 
        load            : boolean variable to indicate load pretrained model from ./pretrained/dataset/
        torch_weights   : boolean variable to indicate load weight from torchvision for imagenet dataset
    Returns:
        model           : models with desired weight (pretrained / random )
        model_name      : str 'dataset_arch_suffix.ckpt' used to save/load model in ./pretrained/dataset
    """
    if normalize is None:
        un_normalize = True
    else:
        un_normalize = False

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Select the input transformation
    if input_quant is None:
        input_quant = ''
        Q = PreProcess()
    elif input_quant.lower() == 'q1':
        Q = Quantise2d(n_bits=1, un_normalized=un_normalize,
                       device=device).to(device)
    elif input_quant.lower() == 'q2':
        Q = Quantise2d(n_bits=2, un_normalized=un_normalize,
                       device=device).to(device)
    elif input_quant.lower() == 'q4':
        Q = Quantise2d(n_bits=4, un_normalized=un_normalize,
                       device=device).to(device)
    elif input_quant.lower() == 'q6':
        Q = Quantise2d(n_bits=6, un_normalized=un_normalize,
                       device=device).to(device)
    elif input_quant.lower() == 'q8':
        Q = Quantise2d(n_bits=8, un_normalized=un_normalize,
                       device=device).to(device)
    elif input_quant.lower() == 'fp':
        Q = Quantise2d(n_bits=1,
                       quantise=False,
                       un_normalized=un_normalize,
                       device=device).to(device)
    else:
        raise ValueError("Unsupported input_quant '{}'".format(input_quant))

    # Instantiate model1
    # RESNET IMAGENET
    if (arch == 'torch_resnet18'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet18(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnet34'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet34(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnet50'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet50(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnet101'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet101(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnet152'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnet152(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnext50_32x4d'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnext50_32x4d(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_resnext101_32x8d'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.resnext101_32x8d(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_wide_resnet50_2'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.wide_resnet50_2(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_wide_resnet101_2'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.wide_resnet101_2(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    #VGG IMAGENET
    elif (arch == 'torch_vgg11'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg11(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg11bn'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg11_bn(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg13'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg13(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg13bn'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg13_bn(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg16'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg16(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg16bn'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg16_bn(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg19'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg19(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_vgg19bn'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.vgg19_bn(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    #MOBILENET IMAGENET
    elif (arch == 'torch_mobnet'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.mobilenet_v2(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    #DENSENET IMAGENET
    elif (arch == 'torch_densenet121'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet121(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_densenet169'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet169(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_densenet201'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet201(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    elif (arch == 'torch_densenet161'):
        if dorefa:
            raise ValueError("Dorefa net unsupported for {}".format(arch))
        else:
            model = models.densenet161(pretrained=torch_weights)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    #RESNET CIFAR
    elif (arch[0:6] == 'resnet'):
        cfg = arch[6:]
        if dorefa:
            model = ResNet_Dorefa_(cfg=cfg,
                                   num_classes=num_classes,
                                   a_bit=abit,
                                   w_bit=wbit)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + "_a" + str(abit) + 'w' + str(
                wbit) + suffix

        else:
            model = ResNet_(cfg=cfg, num_classes=num_classes)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix

    #VGG CIFAR
    elif (arch[0:3] == 'vgg'):
        len_arch = len(arch)
        if (arch[len_arch - 2:len_arch] == 'bn'
                and arch[len_arch - 4:len_arch - 2] == 'bn'):
            batch_norm_conv = True
            batch_norm_linear = True
            cfg = arch[3:len_arch - 4]
        elif arch[len_arch - 2:len_arch] == 'bn':
            batch_norm_conv = True
            batch_norm_linear = False
            cfg = arch[3:len_arch - 2]
        else:
            batch_norm_conv = False
            batch_norm_linear = False
            cfg = arch[3:len_arch]
        if dorefa:
            model = vgg_Dorefa(cfg=cfg,
                               batch_norm_conv=batch_norm_conv,
                               batch_norm_linear=batch_norm_linear,
                               num_classes=num_classes,
                               a_bit=abit,
                               w_bit=wbit)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + "_a" + str(abit) + 'w' + str(
                wbit) + suffix

        else:
            model = vgg(cfg=cfg,
                        batch_norm_conv=batch_norm_conv,
                        batch_norm_linear=batch_norm_linear,
                        num_classes=num_classes)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    # LENET MNIST
    elif (arch == 'lenet5'):
        if dorefa:
            model = LeNet5_Dorefa(num_classes=num_classes,
                                  abit=abit,
                                  wbit=wbit)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + "_a" + str(abit) + 'w' + str(
                wbit) + suffix
        else:
            model = LeNet5(num_classes=num_classes)
            model_name = dataset.lower(
            ) + "_" + input_quant + "_" + arch + suffix
    else:
        # On raising exceptions in Python, see
        # https://stackoverflow.com/questions/2052390/manually-raising-throwing-an-exception-in-python
        raise ValueError("Unsupported neural net architecture")
    model = model.to(device)

    if load and not torch_weights:
        print(" Using Model: " + arch)
        if model_name[-4:] == '_tfr':
            model_path = os.path.join('./pretrained/', dataset.lower(),
                                      model_name + '.tfr')
        else:
            model_path = os.path.join('./pretrained/', dataset.lower(),
                                      model_name + '.ckpt')
        model.load_state_dict(torch.load(model_path, map_location='cuda:0'))
        print(' Loaded trained model from: ' + model_path)
        print(' {}'.format(Q))

    else:
        if model_name[-4:] == '_tfr':
            model_path = os.path.join('./pretrained/', dataset.lower(),
                                      model_name + '.tfr')
        else:
            model_path = os.path.join('./pretrained/', dataset.lower(),
                                      model_name + '.ckpt')
        print(' Training model will be saved at: ' + model_path)
    print('')
    return model, model_name, Q
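# Hypothetical call into the helper above: build a full-precision CIFAR-10 VGG-13
# with batch norm on the conv layers and get back the auto-generated checkpoint name
# plus the input pre-processing module.
model, model_name, Q = instantiate_model(dataset='cifar10',
                                         num_classes=10,
                                         input_quant='FP',
                                         arch='vgg13bn',
                                         suffix='_demo',
                                         load=False)
print(model_name)  # cifar10_FP_vgg13bn_demo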
Example #9
elif (modelName == 'resnet34'):
    model = models.resnet34(pretrained=True, progress=False)
elif (modelName == 'resnet50'):
    model = models.resnet50(pretrained=True, progress=False)
elif (modelName == 'resnet101'):
    model = models.resnet101(pretrained=True, progress=False)
elif (modelName == 'resnet152'):
    model = models.resnet152(pretrained=True, progress=False)
elif (modelName == 'vgg11'):
    model = models.vgg11(pretrained=True, progress=False)
elif (modelName == 'vgg11_bn'):
    model = models.vgg11_bn(pretrained=True, progress=False)
elif (modelName == 'vgg13'):
    model = models.vgg13(pretrained=True, progress=False)
elif (modelName == 'vgg13_bn'):
    model = models.vgg13_bn(pretrained=True, progress=False)
elif (modelName == 'squeezenet1_0'):
    model = models.squeezenet1_0(pretrained=True, progress=False)
elif (modelName == 'squeezenet1_1'):
    model = models.squeezenet1_1(pretrained=True, progress=False)
elif (modelName == 'densenet161'):
    model = models.densenet161(pretrained=True, progress=False)
elif (modelName == 'shufflenet_v2_x0_5'):
    model = models.shufflenet_v2_x0_5(pretrained=True, progress=False)
elif (modelName == 'mobilenet_v2'):
    model = models.mobilenet_v2(pretrained=True, progress=False)
elif (modelName == 'mnasnet1_0'):
    model = models.mnasnet1_0(pretrained=True, progress=False)
elif (modelName == 'googlenet'):
    model = models.googlenet(pretrained=True, progress=False)
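# An equivalent, more compact way to express the chain above (a sketch under the
# same torchvision assumptions): map each supported name to its constructor and
# instantiate once. `modelName` is assumed to be defined as in the surrounding code.
from torchvision import models

_CONSTRUCTORS = {
    'resnet50': models.resnet50,
    'vgg13_bn': models.vgg13_bn,
    'mobilenet_v2': models.mobilenet_v2,
    'googlenet': models.googlenet,
}
model = _CONSTRUCTORS[modelName](pretrained=True, progress=False)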
Example #10
    def Model(self):
        if self.arch == 'alexnet':
            model = models.alexnet(num_classes=self.num_classes)
        if self.arch == 'vgg11':
            model = models.vgg11(num_classes=self.num_classes)
        if self.arch == 'vgg13':
            model = models.vgg13(num_classes=self.num_classes)
        if self.arch == 'vgg16':
            model = models.vgg16(num_classes=self.num_classes)
        if self.arch == 'vgg19':
            model = models.vgg19(num_classes=self.num_classes)
        if self.arch == 'vgg11_bn':
            model = models.vgg11_bn(num_classes=self.num_classes)
        if self.arch == 'vgg13_bn':
            model = models.vgg13_bn(num_classes=self.num_classes)
        if self.arch == 'vgg16_bn':
            model = models.vgg16_bn(num_classes=self.num_classes)
        if self.arch == 'vgg19_bn':
            model = models.vgg19_bn(num_classes=self.num_classes)
        if self.arch == 'resnet18':
            model = models.resnet18(num_classes=self.num_classes)
        if self.arch == 'resnet34':
            model = models.resnet34(num_classes=self.num_classes)
        if self.arch == 'resnet50':
            model = models.resnet50(num_classes=self.num_classes)
        if self.arch == 'resnet101':
            model = models.resnet101(num_classes=self.num_classes)
        if self.arch == 'resnet152':
            model = models.resnet152(num_classes=self.num_classes)
        if self.arch == 'squeezenet1_0':
            model = models.squeezenet1_0(num_classes=self.num_classes)
        if self.arch == 'squeezenet1_1':
            model = models.squeezenet1_1(num_classes=self.num_classes)
        if self.arch == 'densenet121':
            model = models.densenet121(num_classes=self.num_classes)
        if self.arch == 'densenet161':
            model = models.densenet161(num_classes=self.num_classes)
        if self.arch == 'densenet169':
            model = models.densenet169(num_classes=self.num_classes)
        if self.arch == 'densenet201':
            model = models.densenet201(num_classes=self.num_classes)
        if self.arch == 'inception_v1':
            # the 'aux_logits' parameter may keep the model from working
            model = models.googlenet(num_classes=self.num_classes)
        if self.arch == 'inception_v3':
            # the 'aux_logits' parameter may keep the model from working
            model = models.inception_v3(num_classes=self.num_classes)
        if self.arch == 'shufflenet_v2_x0_5':
            model = models.shufflenet_v2_x0_5(num_classes=self.num_classes)
        if self.arch == 'shufflenet_v2_x1_0':
            model = models.shufflenet_v2_x1_0(num_classes=self.num_classes)
        if self.arch == 'shufflenet_v2_x1_5':
            model = models.shufflenet_v2_x1_5(num_classes=self.num_classes)
        if self.arch == 'shufflenet_v2_x2_0':
            model = models.shufflenet_v2_x2_0(num_classes=self.num_classes)
        if self.arch == 'mobilenet_v2':
            model = models.mobilenet_v2(num_classes=self.num_classes)
        if self.arch == 'resnext50_32x4d':
            model = models.resnext50_32x4d(num_classes=self.num_classes)
        if self.arch == 'resnext101_32x8d':
            # torchvision provides resnext101_32x8d; there is no 32x4d variant of resnext101
            model = models.resnext101_32x8d(num_classes=self.num_classes)
        if self.arch == 'wide_resnet50_2':
            model = models.wide_resnet50_2(num_classes=self.num_classes)
        if self.arch == 'wide_resnet101_2':
            model = models.wide_resnet101_2(num_classes=self.num_classes)
        if self.arch == 'mnasnet1_0':
            model = models.mnasnet1_0(num_classes=self.num_classes)

        model = torch.nn.DataParallel(model, device_ids=self.gups).cuda()
        return model
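# A compact alternative dispatch covering most of the names above (a sketch, with
# hypothetical helper name build_by_name; it assumes the name matches a torchvision
# constructor exactly, which does not hold for aliases such as 'inception_v1'):
import torch
from torchvision import models

def build_by_name(arch, num_classes, gpu_ids):
    builder = getattr(models, arch)
    model = builder(num_classes=num_classes)
    return torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()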
Example #11
    def get_model(self):

        if self.args.architecture in ['resnet34', 'resnet50', 'resnet101']:
            if self.args.architecture == 'resnet34':
                model = torchmodels.resnet34(pretrained=True)
            if self.args.architecture == 'resnet50':
                model = torchmodels.resnet50(pretrained=True)
            if self.args.architecture == 'resnet101':
                model = torchmodels.resnet101(pretrained=True)

            num_classes = self.args.num_classes
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, num_classes)

        if self.args.architecture in [
                'densenet121', 'densenet169', 'densenet201', 'densenet161'
        ]:

            if self.args.architecture == 'densenet121':
                model = torchmodels.densenet121(pretrained=True)
            if self.args.architecture == 'densenet169':
                model = torchmodels.densenet169(pretrained=True)
            if self.args.architecture == 'densenet201':
                model = torchmodels.densenet201(pretrained=True)
            if self.args.architecture == 'densenet161':
                model = torchmodels.densenet161(pretrained=True)

            num_classes = self.args.num_classes
            num_ftrs = model.classifier.in_features
            model.classifier = nn.Linear(num_ftrs, num_classes)

        if self.args.architecture in ['vgg11_ad', 'vgg11_bn_ad']:
            if self.args.architecture == 'vgg11_ad':
                model = torchmodels.vgg11(pretrained=True)

        if self.args.architecture in [
                'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
                'vgg19_bn', 'vgg19'
        ]:
            if self.args.architecture == 'vgg11':
                model = torchmodels.vgg11(pretrained=True)
            if self.args.architecture == 'vgg11_bn':
                model = torchmodels.vgg11_bn(pretrained=True)
            if self.args.architecture == 'vgg13':
                model = torchmodels.vgg13(pretrained=True)
            if self.args.architecture == 'vgg13_bn':
                model = torchmodels.vgg13_bn(pretrained=True)
            if self.args.architecture == 'vgg16':
                model = torchmodels.vgg16(pretrained=True)
            if self.args.architecture == 'vgg16_bn':
                model = torchmodels.vgg16_bn(pretrained=True)
            if self.args.architecture == 'vgg19_bn':
                model = torchmodels.vgg19_bn(pretrained=True)
            if self.args.architecture == 'vgg19':
                model = torchmodels.vgg19(pretrained=True)

            num_classes = self.args.num_classes
            in_features = model.classifier[6].in_features
            n_module = nn.Linear(in_features, num_classes)
            n_classifier = list(model.classifier.children())[:-1]
            n_classifier.append(n_module)
            model.classifier = nn.Sequential(*n_classifier)

        if self.args.cuda:
            model.cuda()

        return model
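# The VGG head replacement above, shown standalone (a sketch): only the final Linear
# layer of `classifier` is swapped; the rest of the head is kept as-is.
import torch.nn as nn
from torchvision import models as torchmodels

vgg = torchmodels.vgg16_bn(pretrained=True)
in_features = vgg.classifier[6].in_features
vgg.classifier[6] = nn.Linear(in_features, 5)  # e.g. 5 target classes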
Example #12
    # this step is imperative
    if torch.cuda.is_available():
        model.cuda()
    return model


if not os.path.exists('model_params/vggb_cal256.pth'):
    # LOAD DATA for CALTECH256
    cal256 = MYCAL256(batch_size)
    train_loader = cal256.get_train_loader(parallel=True).get_generator()
    valid_loader = cal256.get_valid_loader(parallel=True).get_generator()
    test_loader = cal256.get_test_loader(parallel=True).get_generator()
    print('DATA LOADED')

    # TRAIN VGGB_S on CALTECH256
    model = vgg13_bn(pretrained=False, num_classes=256)
    model.classifier = nn.Sequential(
        nn.Linear(512 * 2 * 2, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 256),
    )
    model._initialize_weights()
    if torch.cuda.is_available():
        model.cuda()
    num_params = 0
    for param in model.parameters():
        num_params += np.prod(param.size())
    print(num_params)
    print('START TRAINING ON CALTECH256')
    model, best_path = train_and_save_weights(model, train_loader,
Example #13
def main():
    early_stopping = EarlyStopping(5, 0.0)
    opts = get_train_args()
    print("load data ...")
    train_data = datasets.ImageFolder(
        root="data/train",
        transform=transforms.Compose([
            transforms.Resize((256, 256)),  # resize both sides to 256
            #transforms.CenterCrop(256),  # then crop to a square
            transforms.ToTensor(),  # convert to a Tensor (scaled to 0~1 automatically)
            transforms.Normalize(
                (0.5, 0.5, 0.5),  # normalize to the -1 ~ 1 range
                (0.5, 0.5, 0.5)),  # since the transform is (c - m)/s
        ]))

    valid_data = datasets.ImageFolder(
        root="data/val",
        transform=transforms.Compose([
            transforms.Resize((256, 256)),  # resize both sides to 256
            #transforms.CenterCrop(128),  # then crop to a square
            transforms.ToTensor(),  # convert to a Tensor (scaled to 0~1 automatically)
            transforms.Normalize(
                (0.5, 0.5, 0.5),  # normalize to the -1 ~ 1 range
                (0.5, 0.5, 0.5)),  # since the transform is (c - m)/s
        ]))
    train_loader = DataLoader(train_data,
                              batch_size=opts.batch_size,
                              shuffle=True,
                              num_workers=opts.num_processes)

    valid_loader = DataLoader(valid_data,
                              batch_size=opts.batch_size,
                              shuffle=True,
                              num_workers=opts.num_processes)

    classes = train_data.classes
    print(classes)

    print("load model ...")
    if opts.model == 'resnet':
        model = models.resnet50(progress=True)
    elif opts.model == 'vggnet':
        model = models.vgg13_bn(progress=True)
    elif opts.model == 'googlenet':
        model = models.googlenet(progress=True)
    elif opts.model == 'densenet':
        model = models.densenet121(progress=True)
    else:
        model = models.resnext50_32x4d(progress=True)
    print(opts.model)
    optimizer = optim.Adam(model.parameters(), lr=opts.lr)
    model.cuda()
    loss = torch.nn.CrossEntropyLoss()
    batch_nums = np.round(14400 / opts.batch_size)
    valid_nums = np.round(1600 / opts.batch_size)

    print("start training")
    for epoch in range(1, opts.epochs + 1):
        print("epoch : " + str(epoch))
        model.train()
        epoch_loss = 0
        tot = 0
        cnt = 0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, labels = inputs.cuda(), labels.cuda()
            train_loss = loss(model(inputs), labels)
            train_loss.backward()
            optimizer.step()
            batch_loss = train_loss.item()
            epoch_loss += batch_loss
            cnt += 1
            tot += labels.size(0)  # track samples seen so the progress fraction advances
            print('\r{:>10} epoch {} progress {} loss: {}\n'.format(
                '', epoch, tot / 14400, batch_loss))

        with open(str(opts.model) + ' log.txt', 'a') as f:
            f.write(
                str(epoch) + ' loss : ' + str(epoch_loss / batch_nums) + '\n')
        model.eval()
        valid_loss = 0

        total = 0
        correct = 0
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(valid_loader):
                inputs, labels = inputs.cuda(), labels.cuda()
                outputs = model(inputs)
                batch_loss = loss(outputs, labels)
                batch_loss = batch_loss.item()
                valid_loss += batch_loss
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            acc = 100 * correct / total

        with open(str(opts.model) + ' log.txt', 'a') as f:
            f.write(
                str(epoch) + ' loss : ' + str(valid_loss / valid_nums) +
                ' acc : ' + str(acc) + '\n')

        # check early stopping
        if early_stopping(valid_loss):
            print("[Training is early stopped in %d Epoch.]" % epoch)
            torch.save(model.state_dict(), str(opts.model) + '_model.pt')
            print("[Saved the trained model successfully.]")
            break

        if epoch % opts.save_step == 0:
            print("save model...")
            torch.save(model.state_dict(), str(opts.model) + '_model.pt')

    print("save model...")
    torch.save(model.state_dict(), str(opts.model) + '_model.pt')
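# EarlyStopping is used above but not defined in this snippet; a minimal sketch that
# is consistent with how it is constructed and called (patience in epochs, minimum
# improvement delta, returns True when training should stop) could look like this.
# This is an assumption, not the original class.
class EarlyStopping:
    def __init__(self, patience, min_delta):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, valid_loss):
        if valid_loss < self.best_loss - self.min_delta:
            self.best_loss = valid_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience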
Example #14
    def test_vgg13_bn(self):
        process_model(models.vgg13_bn(self.pretrained), self.image,
                      _C_tests.forward_vgg13bn, 'VGG13BN')
Example #15
def set_model(model_name,
              num_class,
              neurons_reducer_block=0,
              comb_method=None,
              comb_config=None,
              pretrained=True,
              freeze_conv=False,
              p_dropout=0.5):

    if pretrained:
        pre_ptm = 'imagenet'
        pre_torch = True
    else:
        pre_torch = False
        pre_ptm = None

    if model_name not in _MODELS:
        raise Exception("The model {} is not available!".format(model_name))

    model = None
    if model_name == 'resnet-50':
        model = MyResnet(models.resnet50(pretrained=pre_torch),
                         num_class,
                         neurons_reducer_block,
                         freeze_conv,
                         comb_method=comb_method,
                         comb_config=comb_config)

    elif model_name == 'resnet-101':
        model = MyResnet(models.resnet101(pretrained=pre_torch),
                         num_class,
                         neurons_reducer_block,
                         freeze_conv,
                         comb_method=comb_method,
                         comb_config=comb_config)

    elif model_name == 'densenet-121':
        model = MyDensenet(models.densenet121(pretrained=pre_torch),
                           num_class,
                           neurons_reducer_block,
                           freeze_conv,
                           comb_method=comb_method,
                           comb_config=comb_config)

    elif model_name == 'vgg-13':
        model = MyVGGNet(models.vgg13_bn(pretrained=pre_torch),
                         num_class,
                         neurons_reducer_block,
                         freeze_conv,
                         comb_method=comb_method,
                         comb_config=comb_config)

    elif model_name == 'vgg-16':
        model = MyVGGNet(models.vgg16_bn(pretrained=pre_torch),
                         num_class,
                         neurons_reducer_block,
                         freeze_conv,
                         comb_method=comb_method,
                         comb_config=comb_config)

    elif model_name == 'vgg-19':
        model = MyVGGNet(models.vgg19_bn(pretrained=pre_torch),
                         num_class,
                         neurons_reducer_block,
                         freeze_conv,
                         comb_method=comb_method,
                         comb_config=comb_config)

    elif model_name == 'mobilenet':
        model = MyMobilenet(models.mobilenet_v2(pretrained=pre_torch),
                            num_class,
                            neurons_reducer_block,
                            freeze_conv,
                            comb_method=comb_method,
                            comb_config=comb_config)

    elif model_name == 'efficientnet-b4':
        if pretrained:
            model = MyEffnet(EfficientNet.from_pretrained(model_name),
                             num_class,
                             neurons_reducer_block,
                             freeze_conv,
                             comb_method=comb_method,
                             comb_config=comb_config)
        else:
            model = MyEffnet(EfficientNet.from_name(model_name),
                             num_class,
                             neurons_reducer_block,
                             freeze_conv,
                             comb_method=comb_method,
                             comb_config=comb_config)

    elif model_name == 'inceptionv4':
        model = MyInceptionV4(ptm.inceptionv4(num_classes=1000,
                                              pretrained=pre_ptm),
                              num_class,
                              neurons_reducer_block,
                              freeze_conv,
                              comb_method=comb_method,
                              comb_config=comb_config)

    elif model_name == 'senet':
        model = MySenet(ptm.senet154(num_classes=1000, pretrained=pre_ptm),
                        num_class,
                        neurons_reducer_block,
                        freeze_conv,
                        comb_method=comb_method,
                        comb_config=comb_config)

    return model
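# Hypothetical usage of set_model: build a frozen-backbone VGG-13 classifier for
# 8 classes (the wrapper classes such as MyVGGNet are assumed to ship with this code).
model = set_model('vgg-13', num_class=8, pretrained=True, freeze_conv=True)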
Example #16
    def vgg13(self):
        return models.vgg13_bn(pretrained=True)
Example #17
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet18":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet152":
        """ resnet152
        """
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet101":
        """ resnet101
        """
        model_ft = models.resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet50":
        """ resnet50
        """
        model_ft = models.resnet50(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "resnet34":
        """ resnet34
        """
        model_ft = models.resnet34(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg11_bn":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg13_bn":
        """ vgg13_bn
        """
        model_ft = models.vgg13_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg16_bn":
        """ vgg16_bn
        """
        model_ft = models.vgg16_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg19_bn":
        """ vgg19_bn
        """
        model_ft = models.vgg19_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg11":
        """ VGG11
        """
        model_ft = models.vgg11(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg13":
        """ vgg13
        """
        model_ft = models.vgg13(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg16":
        """ vgg16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg19":
        """ vgg19
        """
        model_ft = models.vgg19(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet1_0":
        """ squeezenet1_0
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512,
                                           num_classes,
                                           kernel_size=(1, 1),
                                           stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "squeezenet1_1":
        """ squeezenet1_1
        """
        model_ft = models.squeezenet1_1(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512,
                                           num_classes,
                                           kernel_size=(1, 1),
                                           stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet121":
        """ densenet121
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "densenet161":
        """ Densenet161
        """
        model_ft = models.densenet161(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "densenet169":
        """ Densenet169
        """
        model_ft = models.densenet169(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "densenet201":
        """ densenet201
        """
        model_ft = models.densenet201(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
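# --- Usage sketch (not part of the original example; the values below are
# --- illustrative assumptions): build a feature-extraction ResNet-34 for a
# --- hypothetical 5-class dataset and optimize only the re-initialized layer.
import torch.optim as optim
from torchvision import transforms

model_ft, input_size = initialize_model("resnet34",
                                        num_classes=5,
                                        feature_extract=True,
                                        use_pretrained=True)

# With feature_extract=True the backbone is expected to be frozen by
# set_parameter_requires_grad, so only the new fc layer still requires grad.
params_to_update = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Preprocessing matched to the input size the chosen architecture expects.
preprocess = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])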
Example no. 18
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True,
                     only_bn=False):
    # model_ft is set in the if/elif chain below; each branch is model-specific.
    model_ft = None

    if 'resnet' in model_name:
        # Resnet family

        # Map names to constructors so only the requested model is instantiated
        resnet_model_mapping = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
            'resnet152': models.resnet152
        }

        model_ft = resnet_model_mapping[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        if only_bn:
            model_ft.fc = Identity()
        else:
            # reshape the network
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)

    elif model_name == 'alexnet':
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # reshape the network
        # (classifier): Sequential(
        # 	...
        # 	(6): Linear(in_features=4096, out_features=1000, bias=True)
        # )
        if only_bn:
            model_ft.classifier[6] = Identity()
        else:
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif 'densenet' in model_name:
        # Densenet family

        densenet_model_mapping = {
            'densenet121': models.densenet121,
            'densenet161': models.densenet161,
            'densenet169': models.densenet169,
            'densenet201': models.densenet201
        }

        model_ft = densenet_model_mapping[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # reshape the network
        # (classifier): Linear(in_features=1024, out_features=1000, bias=True)
        if only_bn:
            model_ft.classifier = Identity()
        else:
            num_ftrs = model_ft.classifier.in_features
            model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif 'vgg' in model_name:
        # vgg family

        vgg_model_mapping = {
            'vgg11': models.vgg11,
            'vgg11_bn': models.vgg11_bn,
            'vgg13': models.vgg13,
            'vgg13_bn': models.vgg13_bn,
            'vgg16': models.vgg16,
            'vgg16_bn': models.vgg16_bn,
            'vgg19': models.vgg19,
            'vgg19_bn': models.vgg19_bn
        }

        model_ft = vgg_model_mapping[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # reshape the network
        # (classifier): Sequential(
        # 	...
        # 	(6): Linear(in_features=4096, out_features=1000, bias=True)
        # )

        if only_bn:
            model_ft.classifier[6] = Identity()
        else:
            num_ftrs = model_ft.classifier[6].in_features
            model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif 'squeezenet' in model_name:
        # Squeezenet family

        squeezenet_model_mapping = {
            'squeezenet1_0': models.squeezenet1_0,
            'squeezenet1_1': models.squeezenet1_1
        }

        model_ft = squeezenet_model_mapping[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # reshape the network
        # (classifier): Sequential(
        # 	(0): Dropout(p=0.5)
        # 	(1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
        # 	(2): ReLU(inplace)
        # 	(3): AvgPool2d(kernel_size=13, stride=1, padding=0)
        # )
        if only_bn:
            model_ft.classifier[1] = Identity()
        else:
            model_ft.classifier[1] = nn.Conv2d(512,
                                               num_classes,
                                               kernel_size=(1, 1),
                                               stride=(1, 1))
            model_ft.num_classes = num_classes

    elif 'inception' in model_name:
        # Inception v3
        # Be careful, expects (299,299) sized images and has auxiliary output

        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # reshape the model
        # Handle the auxiliary net
        if only_bn:
            model_ft.AuxLogits.fc = Identity()
            model_ft.fc = Identity()
        else:
            num_ftrs = model_ft.AuxLogits.fc.in_features
            model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
            # Handle the primary net
            num_ftrs = model_ft.fc.in_features
            model_ft.fc = nn.Linear(num_ftrs, num_classes)

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft
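# --- Usage sketch (not part of the original example): with only_bn=True the
# --- classification head is replaced by Identity, so the network returns the
# --- penultimate ("bottleneck") features instead of class scores.
import torch

extractor = initialize_model('resnet50',
                             num_classes=0,       # ignored when only_bn=True
                             feature_extract=True,
                             use_pretrained=True,
                             only_bn=True)
extractor.eval()

with torch.no_grad():
    dummy_batch = torch.randn(4, 3, 224, 224)     # four RGB images
    features = extractor(dummy_batch)             # (4, 2048) for resnet50
print(features.shape)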
Example no. 19
def get_model(model_name, in_ch, num_classes, pretrained):
    if in_ch == 1:
        gray = True
    elif in_ch == 3:
        gray = False
    else:
        raise ValueError('in_ch must be 1 (grayscale) or 3 (RGB)')
    ################ SqueezeNet ################
    if (model_name == 'SqueezeNet 1.0'):
        model = models.squeezenet1_0(pretrained=pretrained)
    if (model_name == 'SqueezeNet 1.1'):
        model = models.squeezenet1_1(pretrained=pretrained)

    ################ VGG ################
    if (model_name == 'VGG11'):
        model = models.vgg11(pretrained=pretrained)
    if (model_name == 'VGG11 with batch normalization'):
        model = models.vgg11_bn(pretrained=pretrained)
    if (model_name == 'VGG13'):
        model = models.vgg13(pretrained=pretrained)
    if (model_name == 'VGG13 with batch normalization'):
        model = models.vgg13_bn(pretrained=pretrained)
    if (model_name == 'VGG16'):
        model = models.vgg16(pretrained=pretrained)
    if (model_name == 'VGG16 with batch normalization'):
        model = models.vgg16_bn(pretrained=pretrained)
    if (model_name == 'VGG19'):
        model = models.vgg19(pretrained=pretrained)
    if (model_name == 'VGG19 with batch normalization'):
        model = models.vgg19_bn(pretrained=pretrained)

    ################ ResNet ################
    if (model_name == 'ResNet-18'):
        model = models.resnet18(pretrained=pretrained)
    if (model_name == 'ResNet-34'):
        model = models.resnet34(pretrained=pretrained)
    if (model_name == 'ResNet-50'):
        model = models.resnet50(pretrained=pretrained)
    if (model_name == 'ResNet-101'):
        model = models.resnet101(pretrained=pretrained)
    if (model_name == 'ResNet-152'):
        model = models.resnet152(pretrained=pretrained)

    ################ DenseNet ################
    if (model_name == 'DenseNet-121'):
        model = models.densenet121(pretrained=pretrained)
    if (model_name == 'DenseNet-161'):
        model = models.densenet161(pretrained=pretrained)
    if (model_name == 'DenseNet-169'):
        model = models.densenet169(pretrained=pretrained)
    if (model_name == 'DenseNet-201'):
        model = models.densenet201(pretrained=pretrained)

    ################ Other Networks ################
    if (model_name == 'AlexNet'):
        model = models.alexnet(pretrained=pretrained)
    if (model_name == 'Inception v3'):
        model = models.inception_v3(pretrained=pretrained)
    if (model_name == 'GoogLeNet'):
        model = models.googlenet(pretrained=pretrained)
    if (model_name == 'ShuffleNet v2'):
        model = models.shufflenet_v2_x1_0(pretrained=pretrained)
    if (model_name == 'MobileNet v2'):
        model = models.mobilenet_v2(pretrained=pretrained)
    if (model_name == 'MNASNet 1.0'):
        model = models.mnasnet1_0(pretrained=pretrained)
    if (model_name == 'ResNeXt-50-32x4d'):
        model = models.resnext50_32x4d(pretrained=pretrained)
    if (model_name == 'ResNeXt-101-32x8d'):
        model = models.resnext101_32x8d(pretrained=pretrained)
    if (model_name == 'Wide ResNet-50-2'):
        model = models.wide_resnet50_2(pretrained=pretrained)
    if (model_name == 'Wide ResNet-101-2'):
        model = models.wide_resnet101_2(pretrained=pretrained)

    if ('VGG' in model_name) or ('Dense' in model_name) or (
            'Squeeze' in model_name) or (model_name in [
                'AlexNet', 'MobileNet v2', 'MNASNet 1.0'
            ]):
        if pretrained and not gray:  # freeze feature layers (unless grayscale input)
            for param in model.features.parameters():
                param.requires_grad = False
        # swap the first conv layer to accept 1-channel (grayscale) input
        if gray:
            out_channels = model.features[0].out_channels
            kernel_size = model.features[0].kernel_size[0]
            stride = model.features[0].stride[0]
            padding = model.features[0].padding[0]
            model.features[0] = nn.Conv2d(1, out_channels, kernel_size, stride,
                                          padding)

        if 'Squeeze' in model_name:
            in_channels = model.classifier[1].in_channels
            model.classifier[1] = nn.Conv2d(in_channels, num_classes,
                                            kernel_size=1)
        elif 'Dense' in model_name:
            # DenseNet's classifier is a single Linear layer, not a Sequential
            in_features = model.classifier.in_features
            model.classifier = nn.Linear(in_features, num_classes)
        else:
            in_features = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(in_features, num_classes)

    else:
        if pretrained and not gray:  # freeze all layers (unless grayscale input)
            for param in model.parameters():
                param.requires_grad = False
        # swap the first conv for grayscale input, then replace the final fc layer
        if gray:
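            # Note: this branch assumes a ResNet-style stem where model.conv1 is a
            # plain nn.Conv2d; Inception v3, GoogLeNet and ShuffleNet expose their
            # first convolution differently, so the grayscale swap would need
            # adjusting for those architectures.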
            out_channels = model.conv1.out_channels
            kernel_size = model.conv1.kernel_size[0]
            stride = model.conv1.stride[0]
            padding = model.conv1.padding[0]
            model.conv1 = nn.Conv2d(1, out_channels, kernel_size, stride,
                                    padding)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)

    print('\nLoaded model:', model_name)
    print(model, '\n')
    return model
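# --- Usage sketch (not part of the original example): adapt a pretrained VGG11
# --- to single-channel (grayscale) input with 10 output classes.
import torch

net = get_model('VGG11', in_ch=1, num_classes=10, pretrained=True)

x = torch.randn(2, 1, 224, 224)    # two grayscale images
logits = net(x)                    # -> shape (2, 10)
print(logits.shape)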