Ejemplo n.º 1
0
 def __init__(self):
     """Expose the stages of a DRN-D-54 network as attributes of this module.

     A ``drn_d_54`` instance is created with ``pretrained=True`` and its
     ``layer0`` .. ``layer8`` attributes are re-bound on ``self`` under the
     same names.
     """
     self.inplanes = 64
     super(DRNRes, self).__init__()
     backbone = drn_d_54(pretrained=True)
     # Bind layer0..layer8 in ascending order so the attributes are
     # registered in the same sequence as explicit assignments would be.
     for stage in range(9):
         attr = 'layer{}'.format(stage)
         setattr(self, attr, getattr(backbone, attr))
Ejemplo n.º 2
0
 def __init__(self, options):
     """Build the DRN backbone together with the plane, segmentation and depth heads.

     Args:
         options: configuration object. This constructor reads ``height``,
             ``width``, ``numOutputPlanes``, ``outputHeight`` and
             ``outputWidth`` from it, stores it as ``self.options`` and
             forwards it to ``PyramidModule``.
     """
     super(PlaneNet, self).__init__()

     self.options = options
     self.drn = drn_d_54(pretrained=True, out_map=32, num_classes=-1, out_middle=False)
     # Floor division keeps the pooling kernel height integral. With plain
     # `/` Python 3 yields a float here, whereas Python 2 floored integer
     # operands; `//` gives the Python 2 result on both interpreters.
     self.pool = torch.nn.AvgPool2d((32 * options.height // options.width, 32))
     # Three output values per plane.
     self.plane_pred = nn.Linear(512, options.numOutputPlanes * 3)
     self.pyramid = PyramidModule(options, 512, 128)
     self.feature_conv = ConvBlock(1024, 512)
     # One channel per plane plus one additional channel.
     self.segmentation_pred = nn.Conv2d(512, options.numOutputPlanes + 1, kernel_size=1)
     self.depth_pred = nn.Conv2d(512, 1, kernel_size=1)
     self.upsample = torch.nn.Upsample(size=(options.outputHeight, options.outputWidth), mode='bilinear')
Ejemplo n.º 3
0
    def __init__(self, options):
        """Set up the DRN backbone, pyramid module and segmentation head.

        Args:
            options: configuration object. ``height`` and ``width`` are read
                for the upsampling target size; the object is also stored as
                ``self.options`` and passed on to ``PyramidModule``.
        """
        super(Model, self).__init__()
        self.options = options

        self.drn = drn_d_54(
            pretrained=True, out_map=32, num_classes=-1, out_middle=False)
        self.pyramid = PyramidModule(options, 512, 128)
        self.feature_conv = ConvBlock(1024, 512)

        # Output channel count of the 1x1 prediction convolution.
        out_channels = NUM_CORNERS + NUM_ICONS + 2 + NUM_ROOMS + 2
        self.segmentation_pred = nn.Conv2d(512, out_channels, kernel_size=1)

        target_size = (options.height, options.width)
        self.upsample = torch.nn.Upsample(size=target_size, mode='bilinear')
Ejemplo n.º 4
0
    def __init__(self, type, num_classes):
        """Create a DRN backbone with a 1x1-convolution classification head.

        Args:
            type: name of the backbone to build; one of ``"drn_d_38"``,
                ``"drn_d_54"`` or ``"drn_d_105"``. The name shadows the
                builtin but is kept for caller compatibility.
            num_classes: number of output channels of ``self.fc``.

        Raises:
            ValueError: if ``type`` is not one of the supported names.
        """
        super().__init__()

        if type == "drn_d_38":
            self.drn = drn_d_38(pretrained=True)
        elif type == "drn_d_54":
            self.drn = drn_d_54(pretrained=True)
        elif type == "drn_d_105":
            self.drn = drn_d_105(pretrained=True)
        else:
            # ValueError is a subclass of Exception, so handlers written
            # against the previous generic Exception still match.
            raise ValueError(
                "Unsupported drn model type: '{}'".format(type))

        self.expand_channels = ExpandChannels2d(3)
        self.bn = nn.BatchNorm2d(3)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        # 1x1 convolution mapping the backbone's out_dim channels to
        # num_classes channels.
        self.fc = nn.Conv2d(self.drn.out_dim,
                            num_classes,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=True)
Ejemplo n.º 5
0
                                              batch_size=test_batch_size,
                                              sampler=vali_sampler)

    test_dataset = image_list_folder.ImageListFolder(root='data/dev_data/',
                                                     transform=test_transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              num_workers=16,
                                              batch_size=test_batch_size,
                                              shuffle=True,
                                              drop_last=False)

    if model_name in pretrainedmodels.model_names:
        backbone = pretrainedmodels.__dict__[model_name]()
        backbone = nn.Sequential(*list(backbone.children())[:-2])
    elif model_name == 'drn_d_54':
        backbone = drn.drn_d_54(True, out_feat=True)
    else:
        raise Exception('\nModel {} not exist'.format(model_name))
#     net = GAIN(backbone, num_classes, in_channels=in_channels)
    solver = GAINSolver(backbone,
                        num_classes,
                        in_channels,
                        train_loader,
                        test_loader,
                        test_batch_size,
                        lr=lr,
                        loss_weights=loss_weights,
                        checkpoint_name=checkpoint_name,
                        devices=devices,
                        area_threshold=area_threshold,
                        optimizer=optimizer,