def define_model(is_resnet, is_densenet, is_senet):
    """Build a ``net.model`` on top of the encoder selected by the flags.

    The flags are checked in order (ResNet, DenseNet, SENet). Every enabled
    flag builds a model, so when several are True the last one wins.

    Args:
        is_resnet: use a ResNet-18 encoder. With the default ``use18``
            switch this is ``Resnet18Encoder(3)`` initialised from
            ``./models/monodepth_resnet18_001.pth``.
        is_densenet: use ``densenet.densenet161(pretrained=True)``.
        is_senet: use ``senet.senet154(pretrained='imagenet')``.

    Returns:
        The ``net.model`` instance wrapping the chosen encoder.

    Raises:
        ValueError: if none of the flags is set.
    """
    # Switch between the monodepth-initialised Resnet18Encoder (True) and the
    # resnet.resnet18 + modules.E_resnet path (False).
    use18 = True
    model = None

    if is_resnet:
        if use18:
            stereo_model = Resnet18Encoder(3)
            model_keys = stereo_model.state_dict()
            encoder_dict = torch.load('./models/monodepth_resnet18_001.pth',
                                      map_location='cpu')
            # Keep only the checkpoint entries the encoder actually defines.
            new_dict = {key: value for key, value in encoder_dict.items()
                        if key in model_keys}
            stereo_model.load_state_dict(new_dict)
            encoder = stereo_model
        else:
            original_model = resnet.resnet18(pretrained=True)
            encoder = modules.E_resnet(original_model)
        # Both ResNet-18 variants expose the same feature widths.
        model = net.model(encoder, num_features=512,
                          block_channel=[64, 128, 256, 512])

    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        model = net.model(encoder, num_features=2208,
                          block_channel=[192, 384, 1056, 2208])

    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        encoder = modules.E_senet(original_model)
        model = net.model(encoder, num_features=2048,
                          block_channel=[256, 512, 1024, 2048])

    if model is None:
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet or is_senet")
    return model
Ejemplo n.º 2
0
def define_model(encoder='resnet'):
    """Build a ``net.model`` on top of the named encoder backbone.

    Args:
        encoder: one of
            'resnet'   -> ``resnet.resnet50(pretrained=True)``,
            'densenet' -> ``densenet.densenet161(pretrained=True)``,
            'senet'    -> ``senet.senet154(pretrained='imagenet')``,
            'resnet4'  -> ``resnet4.resnet50(pretrained=True)``.

    Returns:
        The ``net.model`` instance wrapping the chosen encoder.

    Raises:
        ValueError: if ``encoder`` is not one of the names above.
    """
    # Use ``==``, not ``is``: identity only happens to hold for interned
    # literals and fails for equal strings built at runtime (e.g. argparse).
    if encoder == 'resnet':
        original_model = resnet.resnet50(pretrained=True)
        backbone = modules.E_resnet(original_model)
        model = net.model(backbone,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif encoder == 'densenet':
        original_model = densenet.densenet161(pretrained=True)
        backbone = modules.E_densenet(original_model)
        model = net.model(backbone,
                          num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif encoder == 'senet':
        original_model = senet.senet154(pretrained='imagenet')
        backbone = modules.E_senet(original_model)
        model = net.model(backbone,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif encoder == 'resnet4':
        original_model = resnet4.resnet50(pretrained=True)
        backbone = modules.E_resnet(original_model)
        model = net.model(backbone,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    else:
        raise ValueError(
            "Unknown encoder {!r}; expected one of "
            "'resnet', 'densenet', 'senet', 'resnet4'".format(encoder))

    return model
Ejemplo n.º 3
0
def define_test_model():
    """Build the evaluation network selected by the module-level ``args``.

    Reads ``args.arch`` to pick the backbone ("Resnet", "Densenet" or
    "SEnet"). For "Resnet" the model is moved to CUDA as float32, and its
    trained weights are restored from
    ``<args.load_dir>/original_model_epoch_<args.load_epoch>.pth``.

    Returns:
        The ``net.model`` instance for the selected architecture.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture
            ("Custom" is a recognised option but has no builder here).
    """
    arch = args.arch

    if arch == "Resnet":
        stereo_model = Resnet18Encoder(3)
        model_keys = stereo_model.state_dict()
        encoder_dict = torch.load('./models/monodepth_resnet18_001.pth',
                                  map_location='cpu')
        # Keep only the checkpoint entries the encoder actually defines.
        new_dict = {key: value for key, value in encoder_dict.items()
                    if key in model_keys}
        stereo_model.load_state_dict(new_dict)

        model = net.model(stereo_model,
                          num_features=512,
                          block_channel=[64, 128, 256, 512])

        # Report the exact file being restored, not a different name.
        checkpoint_path = (
            args.load_dir +
            "/original_model_epoch_{}.pth".format(str(args.load_epoch)))
        print("Loading a model...")
        print(checkpoint_path)
        model = model.cuda().float()
        # NOTE(review): if checkpoints were saved from a DataParallel wrapper
        # their keys carry a "module." prefix that must be stripped first.
        model.load_state_dict(torch.load(checkpoint_path))

    elif arch == "Densenet":
        # TODO: no dot bug
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        # NOTE(review): no trained checkpoint is loaded for this backbone.
        model = net.model(encoder,
                          num_features=2208,
                          block_channel=[192, 384, 1056, 2208])

    elif arch == "SEnet":
        original_model = senet.senet154(pretrained='imagenet')
        encoder = modules.E_senet(original_model)
        # NOTE(review): no trained checkpoint is loaded for this backbone.
        model = net.model(encoder,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])

    else:
        raise ValueError(
            "Unsupported architecture {!r}; expected one of "
            "'Resnet', 'Densenet', 'SEnet'".format(arch))

    return model
Ejemplo n.º 4
0
def define_model(is_resnet, is_densenet, is_senet):
    """Build a ``net.model`` on top of the encoder selected by the flags.

    The flags are checked in order (ResNet, DenseNet, SENet). Every enabled
    flag builds a model, so when several are True the last one wins.

    Args:
        is_resnet: use ``resnet.resnet50(pretrained=True)``.
        is_densenet: use ``densenet.densenet161(pretrained=True)``.
        is_senet: use ``senet.senet154(pretrained=None)``, i.e. no pretrained
            weights are requested for SENet here.

    Returns:
        The ``net.model`` instance wrapping the chosen encoder.

    Raises:
        ValueError: if none of the flags is set.
    """
    model = None
    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        encoder = modules.E_resnet(original_model)
        model = net.model(encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        model = net.model(encoder, num_features=2208, block_channel=[192, 384, 1056, 2208])
    if is_senet:
        original_model = senet.senet154(pretrained=None)
        encoder = modules.E_senet(original_model)
        model = net.model(encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])

    if model is None:
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet or is_senet")
    return model
Ejemplo n.º 5
0
def define_model(is_resnet,
                 is_densenet,
                 is_senet,
                 model='tbdp',
                 parallel=False,
                 semff=False,
                 pcamff=False):
    """Build a TBDPNet or Hu network on top of the encoder selected by the flags.

    The encoder flags are checked in order (ResNet, DenseNet, SENet). Every
    enabled flag builds a network, so when several are True the last one wins.

    Args:
        is_resnet: use ``resnet.resnet50(pretrained=True)``.
        is_densenet: use ``densenet.densenet161(pretrained=True)``.
        is_senet: use ``senet.senet154(pretrained='imagenet')``.
        model: decoder type, 'tbdp' (``net.TBDPNet``) or 'hu' (``net.Hu``).
        parallel: forwarded to ``net.TBDPNet`` only.
        semff: forwarded to ``net.Hu`` only.
        pcamff: forwarded to both decoder types.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: if ``model`` is not 'tbdp' or 'hu'.
        ValueError: if none of the encoder flags is set.
    """
    # Keep the requested decoder type in its own name. Rebinding ``model`` to
    # the built network would break the type check for any later flag and
    # would return the bare string when no encoder is selected.
    model_type = model
    # Validate before constructing (and possibly downloading) any backbone.
    if model_type not in ('tbdp', 'hu'):
        raise NotImplementedError(
            "Select model type in [\'tbdp\', \'hu\']")

    def build(encoder, num_features, block_channel):
        """Wrap ``encoder`` in the decoder selected by ``model_type``."""
        if model_type == 'tbdp':
            return net.TBDPNet(encoder,
                               num_features=num_features,
                               block_channel=block_channel,
                               parallel=parallel,
                               pcamff=pcamff)
        return net.Hu(encoder,
                      num_features=num_features,
                      block_channel=block_channel,
                      semff=semff,
                      pcamff=pcamff)

    network = None
    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        network = build(modules.E_resnet(original_model),
                        2048, [256, 512, 1024, 2048])
    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        network = build(modules.E_densenet(original_model),
                        2208, [192, 384, 1056, 2208])
    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        network = build(modules.E_senet(original_model),
                        2048, [256, 512, 1024, 2048])

    if network is None:
        raise ValueError(
            "No encoder selected: set one of is_resnet, is_densenet or is_senet")
    return network