Esempio n. 1
0
def convert_to_mobile(model_name, path, dest):
    """Build a 2-class model, pass it through ``convert`` and save it as mobile TorchScript.

    Args:
        model_name: 'concat' selects ConcatModel, 'vgg' selects Vgg16; any other
            value falls back to ResNet50.
        path: forwarded to ``convert`` together with the freshly built model
            (presumably a checkpoint location -- confirm against ``convert``).
        dest: output location handed to ``torch.jit.save``.
    """
    if model_name == 'concat':
        architecture = ConcatModel
    elif model_name == 'vgg':
        architecture = Vgg16
    else:
        architecture = ResNet50

    prepared = convert(architecture(2), path)
    # Script first, then apply the mobile optimization passes to the scripted module.
    optimized = mobile_optimizer.optimize_for_mobile(torch.jit.script(prepared))
    torch.jit.save(optimized, dest)
Esempio n. 2
0
 def __init__(self, num_classes):
     """Assemble two backbone feature extractors and a fully-connected head.

     Args:
         num_classes: forwarded to both backbones and used as the width of
             the final linear layer.
     """
     super().__init__()
     # Both backbones are created with feature=True. The head's input width
     # (8192 + 2048) presumably equals their concatenated feature sizes --
     # TODO confirm against forward().
     self.resnet = ResNet50(num_classes, feature=True)
     self.vgg = Vgg16(num_classes, feature=True)

     # Hidden MLP layers, registered in order as fc0, act0, fc1, act1, fc2, act2:
     # (8192 + 2048) -> 2048 -> 1024 -> 512, each Linear followed by a ReLU.
     widths = (8192 + 2048, 2048, 1024, 512)
     for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
         setattr(self, 'fc{}'.format(idx), nn.Linear(fan_in, fan_out))
         setattr(self, 'act{}'.format(idx), nn.ReLU())

     # Final projection to num_classes followed by an element-wise Sigmoid.
     self.output = nn.Linear(widths[-1], num_classes)
     self.output_act = nn.Sigmoid()
    def __init__(self, nInputChannels, n_classes, os, backbone_type):
        """Build a ResNet backbone, an ASPP module and a convolutional classifier.

        Args:
            nInputChannels: input channel count, forwarded to the backbone.
            n_classes: number of output channels of the final 1x1 convolution.
            os: output stride; must be 8, 16 or 32. It selects the rates passed
                to ASPP and is forwarded to the backbone. (The name shadows the
                ``os`` module inside this method only.)
            backbone_type: one of 'resnet18', 'resnet34' or 'resnet50'.

        Raises:
            NotImplementedError: if ``os`` or ``backbone_type`` is not supported.
        """
        super(ResNet_ASPP, self).__init__()

        self.os = os
        self.backbone_type = backbone_type

        # Rates handed to ASPP (presumably atrous/dilation rates, DeepLab-style).
        # Output strides 8 and 32 share the same, larger rate set.
        if os == 16:
            rates = [1, 6, 12, 18]
        elif os == 8 or os == 32:
            rates = [1, 12, 24, 36]
        else:
            raise NotImplementedError(
                "unsupported output stride os={!r}; expected 8, 16 or 32".format(os))

        # Only the resnet50 backbone is built with pretrained=True.
        if backbone_type == 'resnet18':
            self.backbone_features = ResNet18(nInputChannels, os, pretrained=False)
        elif backbone_type == 'resnet34':
            self.backbone_features = ResNet34(nInputChannels, os, pretrained=False)
        elif backbone_type == 'resnet50':
            self.backbone_features = ResNet50(nInputChannels, os, pretrained=True)
        else:
            raise NotImplementedError(
                "unsupported backbone_type={!r}; expected 'resnet18', 'resnet34' "
                "or 'resnet50'".format(backbone_type))

        # Channel count fed into ASPP: 2048 for the resnet50 backbone, 512 otherwise.
        asppInputChannels = 2048 if backbone_type == 'resnet50' else 512
        asppOutputChannels = 256

        self.aspp = ASPP(asppInputChannels, asppOutputChannels, rates)
        # Two 3x3 conv-BN-ReLU stages followed by a 1x1 projection to n_classes.
        self.last_conv = nn.Sequential(
                nn.Conv2d(asppOutputChannels, 256, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
            )
Esempio n. 4
0
                    help="specific run train or evaluate")

if __name__ == "__main__":
    # load params
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(
        json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    params.cuda = torch.cuda.is_available()
    seed_everything(params.seed)
    utils.set_logger(
        os.path.join(args.model_dir, "log/" + args.mode + VERSION + ".log"))

    model = ResNet50((224, 224), 5)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=params.lr, eps=1e-4, amsgrad=True)
    dataloaders = fetch_dataloader(args.data_dir, [0.8, 0.1, 0.1], params)

    if (args.restore_file):
        model.load_state_dict(torch.load(args.restore_file))
    if (torch.cuda.is_available()):
        model = model.cuda()
    if (args.mode == 'train'):
        train_losses, train_accs, val_losses, val_accs = train_and_eval(
            model, loss_fn, dataloaders['train'], dataloaders['val'],
            optimizer, params.epoch, accuracy,
            os.path.join(args.model_dir, "model"))
        plot_result(
            train_losses, val_losses, "loss",
Esempio n. 5
0
import io
from torchvision import transforms
import torch
from PIL import Image
import os
from model.multi_model import ConcatModel
from model.vgg16 import Vgg16
from model.ResNet import ResNet50
from flask import Flask, jsonify, request

# Flask application object plus the three 2-class models used by the
# prediction helpers. Models are constructed in the order concat, vgg, resNet.
# NOTE(review): these lines neither load weights nor call .eval(); confirm
# that happens elsewhere before the models are served.
app = Flask(__name__)
concat, vgg, resNet = ConcatModel(2), Vgg16(2), ResNet50(2)


def transform_image(image_data):
    """Decode raw image bytes into a batched tensor resized to 128x128.

    Args:
        image_data: encoded image file contents, opened with PIL.

    Returns:
        The transformed image with a leading batch dimension of size 1.

    NOTE(review): the image mode is not normalized (no ``convert('RGB')``), so
    the channel count follows the uploaded file -- confirm this matches the
    model's expected input.
    """
    picture = Image.open(io.BytesIO(image_data))
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((128, 128)),
    ])
    single = pipeline(picture)
    return single.unsqueeze(0)


def get_prediction(model, image_data):
    """Return the highest-scoring output index for one encoded image.

    The model is put into evaluation mode and run without autograd. Layers
    whose behaviour depends on train/eval mode (e.g. BatchNorm, Dropout) then
    use inference behaviour, and no gradient graph is built for this pure
    inference call.

    Args:
        model: module whose output holds per-class scores along dimension 1.
        image_data: encoded image bytes, decoded by ``transform_image``.

    Returns:
        The argmax index for the single image in the batch, as produced by
        ``res.numpy()[0]`` (a NumPy integer scalar).
    """
    tensor = transform_image(image_data)
    # eval() persists on the model object; intended for serving.
    model.eval()
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks still run.
        _, res = model(tensor).max(1)
    return res.numpy()[0]
Esempio n. 6
0
def train(model_name, train_dataset, test_dataset, lr=2e-3, epoches=200, batch_size=64, cuda=True):
    # device = torch.device('cuda:1') if cuda else torch.device('cpu')
    # distributed.init_process_group(backend='nccl', init_method='tcp://*****:*****@ep{}, loss:{}\nsaving model to {}".format(model_name, ep + 1, test_loss / test_count, model_name))
            writer.add_scalar('{}_test_loss'.format(model_name), test_loss / test_count, ep)
            # pickle.dump(model.state_dict(), open(model_name, 'wb'))
            torch.save(model.module.state_dict(), model_path)
    writer.close()
Esempio n. 7
0
# Maps each supported ``args.net`` name to the (module path, factory name) that
# builds it. Modules are imported lazily, only for the selected architecture.
_NETWORK_FACTORIES = {
    'vgg16': ('model.VGGNet', 'VGG16'),
    'vgg19': ('model.VGGNet', 'VGG19'),
    'alexnet': ('model.AlexNet', 'alexnet'),
    'densenet121': ('model.DenseNet', 'DenseNet121'),
    'densenet169': ('model.DenseNet', 'DenseNet169'),
    'densenet201': ('model.DenseNet', 'DenseNet201'),
    'densenet264': ('model.DenseNet', 'DenseNet264'),
    'googlenet': ('model.GoogleNet', 'googlenet'),
    'inceptionv1': ('model.InceptionV1', 'inceptionv1'),
    'inceptionv2': ('model.InceptionV2', 'inceptionv2'),
    'inceptionv3': ('model.InceptionV3', 'inceptionv3'),
    'inceptionv4': ('model.InceptionV4', 'inceptionv4'),
    'resnet50': ('model.ResNet', 'ResNet50'),
    'resnet101': ('model.ResNet', 'ResNet101'),
    'resnet152': ('model.ResNet', 'ResNet152'),
    'preactresnet18': ('model.preactresnet', 'preactresnet18'),
    'preactresnet34': ('model.preactresnet', 'preactresnet34'),
    'preactresnet50': ('model.preactresnet', 'preactresnet50'),
    'preactresnet101': ('model.preactresnet', 'preactresnet101'),
    'preactresnet152': ('model.preactresnet', 'preactresnet152'),
    'resnext50': ('model.ResNeXt', 'ResNeXt50'),
    'resnext101': ('model.ResNeXt', 'ResNeXt101'),
    'resnext152': ('model.ResNeXt', 'ResNeXt152'),
    'mobilenet': ('model.mobilNet', 'mobilenet'),
    'mobilenetv2': ('model.mobileNetv2', 'mobilenetv2'),
    'nasnet': ('model.NasNet', 'nasnet'),
    'attention56': ('model.attention', 'attention56'),
    'attention92': ('model.attention', 'attention92'),
    'seresnet18': ('model.SeNet', 'seresnet18'),
    'seresnet34': ('model.SeNet', 'seresnet34'),
    'seresnet50': ('model.SeNet', 'seresnet50'),
    'seresnet101': ('model.SeNet', 'seresnet101'),
    'seresnet152': ('model.SeNet', 'seresnet152'),
    'rirnet': ('model.RiR', 'resnet_in_resnet'),
}


def get_network(args):
    """Return the network selected by ``args.net``.

    Args:
        args: parsed arguments exposing
            ``net``: architecture name, a key of ``_NETWORK_FACTORIES``;
            ``class_num``: number of classes. When it is falsy, the factory is
            called without ``class_num`` so its own default (if any) applies;
            ``gpu``: when truthy, the network is moved to CUDA.

    Returns:
        The constructed network, on CUDA if ``args.gpu`` is set.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` when ``args.net``
        is not a supported name. The old code exited with status 0, which
        signalled success.
    """
    import importlib
    import sys

    entry = _NETWORK_FACTORIES.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet', file=sys.stderr)
        sys.exit(1)
    module_name, factory_name = entry

    # Only pass class_num when it was given. The previous code left it unbound
    # in that case and crashed with UnboundLocalError.
    factory_kwargs = {}
    if args.class_num:
        factory_kwargs['class_num'] = args.class_num

    factory = getattr(importlib.import_module(module_name), factory_name)
    net = factory(**factory_kwargs)

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net