Example #1
class ResNetLSTM1(nn.Module):
    def __init__(self, num_classes, is_test=False, config=None, num_lstm=5):
        """Compose a SSD model using the given components.
		"""
        super(ResNetLSTM1, self).__init__()

        # alpha = 1
        # alpha_base = alpha
        # alpha_ssd = 0.5 * alpha
        # alpha_lstm = 0.25 * alpha

        resnet = resnet101(pretrained=True)
        all_modules = list(resnet.children())
        modules = all_modules[:-4]
        self.base_net = nn.Sequential(*modules)

        modules = all_modules[6:7]
        self.conv_final = nn.Sequential(*modules)

        self.num_classes = num_classes
        self.is_test = is_test
        self.config = config

        lstm_layers = [
            BottleNeckLSTM(1024, 256),
            BottleNeckLSTM(256, 64),
            BottleNeckLSTM(64, 16),
            ConvLSTMCell(16, 16),
            ConvLSTMCell(16, 16)
        ]

        self.lstm_layers = nn.ModuleList(
            [lstm_layers[i] for i in range(num_lstm)])

        self.extras = ModuleList([
            Sequential(
                conv_depthwise_seperable(in_channels=256,
                                         out_channels=128,
                                         kernel_size=1), ReLU(),
                conv_depthwise_seperable(in_channels=128,
                                         out_channels=256,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1), ReLU()),
            Sequential(
                conv_depthwise_seperable(in_channels=64,
                                         out_channels=32,
                                         kernel_size=1), ReLU(),
                conv_depthwise_seperable(in_channels=32,
                                         out_channels=64,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1), ReLU()),
            Sequential(
                conv_depthwise_seperable(in_channels=16,
                                         out_channels=8,
                                         kernel_size=1), ReLU(),
                conv_depthwise_seperable(in_channels=8,
                                         out_channels=16,
                                         kernel_size=3), ReLU()),
            Sequential(
                conv_depthwise_seperable(in_channels=16,
                                         out_channels=8,
                                         kernel_size=1), ReLU(),
                conv_depthwise_seperable(in_channels=8,
                                         out_channels=16,
                                         kernel_size=3), ReLU())
        ])

        self.regression_headers = ModuleList([
            conv_depthwise_seperable(in_channels=512,
                                     out_channels=4 * 4,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=256,
                                     out_channels=6 * 4,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=64,
                                     out_channels=6 * 4,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=16,
                                     out_channels=6 * 4,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=16,
                                     out_channels=4 * 4,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(
                in_channels=16, out_channels=4 * 4, kernel_size=3,
                padding=1),  # TODO: change to kernel_size=1, padding=0?
        ])

        self.classification_headers = ModuleList([
            conv_depthwise_seperable(in_channels=512,
                                     out_channels=4 * num_classes,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=256,
                                     out_channels=6 * num_classes,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=64,
                                     out_channels=6 * num_classes,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=16,
                                     out_channels=6 * num_classes,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(in_channels=16,
                                     out_channels=4 * num_classes,
                                     kernel_size=3,
                                     padding=1),
            conv_depthwise_seperable(
                in_channels=16,
                out_channels=4 * num_classes,
                kernel_size=3,
                padding=1),  # TODO: change to kernel_size=1, padding=0?
        ])

        # The init helpers below expect source_layer_add_ons; this variant
        # defines none, so register an empty container.
        self.source_layer_add_ons = nn.ModuleList()
        # Use the first CUDA device when available.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)

    def forward(self, x):
        confidences = []
        locations = []
        header_index = 0

        x = self.base_net(x)
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.conv_final(x)
        x, _ = self.lstm_layers[0](x)
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        for i in range(len(self.extras)):
            if (i < len(self.lstm_layers) - 1):
                x = self.extras[i](x)
                x, _ = self.lstm_layers[i + 1](x)
                confidence, location = self.compute_header(header_index, x)
                header_index += 1
                confidences.append(confidence)
                locations.append(location)
            else:
                x = self.extras[i](x)
                confidence, location = self.compute_header(header_index, x)
                header_index += 1
                confidences.append(confidence)
                locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        return confidence, location

    def init_from_base_net(self, model):
        self.base_net.load_state_dict(torch.load(
            model, map_location=lambda storage, loc: storage),
                                      strict=True)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def detach_all(self):
        for layer in self.lstm_layers:
            layer.hidden_state.detach_()
            layer.cell_state.detach_()

    def init_from_pretrained_ssd(self, model):
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("classification_headers")
                    or k.startswith("regression_headers"))
        }
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)
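
The helpers used throughout this snippet (conv_depthwise_seperable, _xavier_init_, BottleNeckLSTM, box_utils) come from the surrounding project and are not shown here. A minimal sketch of what the depthwise-separable convolution block and the Xavier initializer could look like, with names mirroring the calls above but the exact structure assumed rather than taken from the source:

import torch.nn as nn

def conv_depthwise_seperable(in_channels, out_channels, kernel_size=1,
                             stride=1, padding=0):
    # Assumed structure: a depthwise convolution (one filter per input
    # channel) followed by a 1x1 pointwise convolution that mixes channels.
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                  stride=stride, padding=padding, groups=in_channels),
        nn.Conv2d(in_channels, out_channels, kernel_size=1),
    )

def _xavier_init_(m):
    # Xavier-initialize convolution weights; applied via module.apply(...).
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
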
Example #2
class SSD(nn.Module):
    def __init__(self, num_classes, is_test=False, config=None, device=None):
        """Compose a SSD model using the given components.
		"""
        super(SSD, self).__init__()

        # alpha = 1
        # alpha_base = alpha
        # alpha_ssd = 0.5 * alpha
        # alpha_lstm = 0.25 * alpha

        self.num_classes = num_classes
        self.base_net = MobileNetV1()
        self.is_test = is_test
        self.config = config

        self.BottleneckLSTM_1 = ConvLSTMCell(1024, 256)
        self.BottleneckLSTM_2 = ConvLSTMCell(256, 64)
        self.BottleneckLSTM_3 = ConvLSTMCell(64, 16)
        self.BottleneckLSTM_4 = ConvLSTMCell(16, 4)
        self.BottleneckLSTM_5 = ConvLSTMCell(4, 1)

        self.extras = ModuleList([
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(),
                conv_dw_1(inp=128, oup=256, kernel_size=3, stride=2,
                          padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=64, out_channels=32, kernel_size=1), ReLU(),
                conv_dw_1(inp=32, oup=64, kernel_size=3, stride=2, padding=1),
                ReLU()),
            Sequential(
                Conv2d(in_channels=16, out_channels=8, kernel_size=1), ReLU(),
                conv_dw_1(inp=8, oup=16, kernel_size=3, stride=2, padding=1),
                ReLU()),
            Sequential(
                Conv2d(in_channels=4, out_channels=2, kernel_size=1), ReLU(),
                conv_dw_1(inp=2, oup=4, kernel_size=3, stride=2, padding=1),
                ReLU())
        ])

        self.regression_headers = ModuleList([
            conv_dw_1(inp=512, oup=6 * 4, kernel_size=3, padding=1),
            conv_dw_1(inp=256, oup=6 * 4, kernel_size=3, padding=1),
            conv_dw_1(inp=64, oup=6 * 4, kernel_size=3, padding=1),
            conv_dw_1(inp=16, oup=6 * 4, kernel_size=3, padding=1),
            conv_dw_1(inp=4, oup=6 * 4, kernel_size=3, padding=1),
            conv_dw_1(inp=1, oup=6 * 4, kernel_size=3, padding=1),
        ])

        self.classification_headers = ModuleList([
            conv_dw_1(inp=512, oup=6 * num_classes, kernel_size=3, padding=1),
            conv_dw_1(inp=256, oup=6 * num_classes, kernel_size=3, padding=1),
            conv_dw_1(inp=64, oup=6 * num_classes, kernel_size=3, padding=1),
            conv_dw_1(inp=16, oup=6 * num_classes, kernel_size=3, padding=1),
            conv_dw_1(inp=4, oup=6 * num_classes, kernel_size=3, padding=1),
            conv_dw_1(inp=1, oup=6 * num_classes, kernel_size=3, padding=1),
        ])

        self.conv_13 = conv_dw(512, 1024, 2)

        # The init helpers below expect source_layer_add_ons; this model
        # defines none, so register an empty container.
        self.source_layer_add_ons = nn.ModuleList()

        if device:
            self.device = device
        else:
            self.device = torch.device(
                "cuda:1" if torch.cuda.is_available() else "cpu")
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)

    def forward(self, x):
        confidences = []
        locations = []
        header_index = 0

        # 12 conv features
        x = self.base_net(x)
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.conv_13(x)
        state = self.BottleneckLSTM_1(x)
        x = state[0]
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.extras[0](x)
        state = self.BottleneckLSTM_2(x)
        x = state[0]
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.extras[1](x)
        state = self.BottleneckLSTM_3(x)
        x = state[0]
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.extras[2](x)
        state = self.BottleneckLSTM_4(x)
        x = state[0]
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        x = self.extras[3](x)
        state = self.BottleneckLSTM_5(x)
        x = state[0]
        confidence, location = self.compute_header(header_index, x)
        header_index += 1
        confidences.append(confidence)
        locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        return confidence, location

    def init_from_base_net(self, model):
        self.base_net.load_state_dict(torch.load(
            model, map_location=lambda storage, loc: storage),
                                      strict=True)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init_from_pretrained_ssd(self, model):
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("classification_headers")
                    or k.startswith("regression_headers"))
        }
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)
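
ConvLSTMCell and BottleNeckLSTM are also project-specific. A minimal convolutional-LSTM cell consistent with how these snippets call it (single input, returns a (hidden, cell) pair, keeps its state as attributes so a detach_all-style method can truncate backpropagation through time) could look like the sketch below; it is an assumption about the interface, not the project's implementation:

import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.hidden_channels = hidden_channels
        # One convolution produces all four LSTM gates at once.
        self.gates = nn.Conv2d(in_channels + hidden_channels,
                               4 * hidden_channels, kernel_size,
                               padding=kernel_size // 2)
        self.hidden_state = None
        self.cell_state = None

    def forward(self, x):
        b, _, h, w = x.shape
        if self.hidden_state is None:
            # Lazily create the recurrent state with the input's dtype/device.
            self.hidden_state = x.new_zeros(b, self.hidden_channels, h, w)
            self.cell_state = x.new_zeros(b, self.hidden_channels, h, w)
        gates = self.gates(torch.cat([x, self.hidden_state], dim=1))
        i, f, o, g = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        self.cell_state = f * self.cell_state + i * torch.tanh(g)
        self.hidden_state = o * torch.tanh(self.cell_state)
        return self.hidden_state, self.cell_state
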
Example #3
class SSD(nn.Module):
    def __init__(self,
                 num_classes: int,
                 is_test=False,
                 config=None,
                 device=None):
        """ Create default SSD model.
        """
        super(SSD, self).__init__()

        self.num_classes = num_classes
        self.base_net = MobileNetV1(self.num_classes).model
        self.source_layer_indexes = [
            12,
            14,
        ]
        self.extras = ModuleList([
            Sequential(
                Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=256,
                       out_channels=512,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=512, out_channels=128, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=128,
                       out_channels=256,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=128,
                       out_channels=256,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=128,
                       out_channels=256,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU())
        ])

        self.regression_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            # TODO: change to kernel_size=1, padding=0?
        ])

        self.classification_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            # TODO: change to kernel_size=1, padding=0?
        ])

        self.is_test = is_test
        self.config = config

        # register layers in source_layer_indexes by adding them to a module list
        self.source_layer_add_ons = nn.ModuleList(
            [t[1] for t in self.source_layer_indexes if isinstance(t, tuple)])
        if device:
            self.device = device
        else:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        if is_test:
            self.config = config
            self.priors = config.priors.to(self.device)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        for end_layer_index in self.source_layer_indexes:

            if isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
            else:
                added_layer = None
            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)
            start_layer_index = end_layer_index
            if added_layer:
                y = added_layer(x)
            else:
                y = x
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        for layer in self.base_net[end_layer_index:]:
            x = layer(x)

        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)

        return confidence, location

    def init_from_base_net(self, model):
        self.base_net.load_state_dict(torch.load(
            model, map_location=lambda storage, loc: storage),
                                      strict=True)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init_from_pretrained_ssd(self, model):
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("classification_headers")
                    or k.startswith("regression_headers"))
        }
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)
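
The compute_header reshaping is identical in all of these examples: each convolutional head emits k * num_classes (or k * 4) channels per spatial cell, and the permute/view pair flattens the feature map into one row per anchor. A standalone illustration on a dummy feature map (the sizes are illustrative, not taken from the snippets):

import torch
import torch.nn as nn

num_classes, k = 21, 6                          # 6 anchors per cell
head = nn.Conv2d(256, k * num_classes, kernel_size=3, padding=1)

feature_map = torch.randn(1, 256, 10, 10)       # (batch, channels, H, W)
out = head(feature_map)                         # (1, k*num_classes, 10, 10)
out = out.permute(0, 2, 3, 1).contiguous()      # (1, 10, 10, k*num_classes)
out = out.view(out.size(0), -1, num_classes)    # (1, 10*10*k, num_classes)
print(out.shape)                                # torch.Size([1, 600, 21])
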
Example #4
class VGGSSD(nn.Module):
    def __init__(self, num_classes, device, is_test=False, config=None):
        super(VGGSSD, self).__init__()

        vgg_config = [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
            512, 512, 512
        ]

        self.num_classes = num_classes
        self.device = device

        self.base_net = ModuleList(self.vgg(vgg_config))

        self.source_layer_indexes = [
            (23, BatchNorm2d(512)),
            len(self.base_net),
        ]
        self.extras = ModuleList([
            Sequential(
                Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=256,
                       out_channels=512,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=512, out_channels=128, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=128,
                       out_channels=256,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(), Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(), Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3), ReLU())
        ])

        self.classification_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=4 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * num_classes,
                   kernel_size=1,
                   padding=0),
        ])

        self.regression_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=4 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * 4,
                   kernel_size=1,
                   padding=0),
        ])

        self.is_test = is_test
        self.config = config
        self.source_layer_add_ons = nn.ModuleList([
            t[1] for t in self.source_layer_indexes
            if isinstance(t, tuple) and not isinstance(t, GraphPath)
        ])
        self.priors = config.priors.to(self.device)

    def forward(self, x):
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        for end_layer_index in self.source_layer_indexes:
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                end_layer_index = end_layer_index.s0
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
                path = None
            else:
                added_layer = None
                path = None
            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)
            if added_layer:
                y = added_layer(x)
            else:
                y = x
            if path:
                sub = getattr(self.base_net[end_layer_index], path.name)
                for layer in sub[:path.s1]:
                    x = layer(x)
                y = x
                for layer in sub[path.s1:]:
                    x = layer(x)
                end_layer_index += 1
            start_layer_index = end_layer_index
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        for layer in self.base_net[end_layer_index:]:
            x = layer(x)

        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)

        return confidence, location

    def init_from_base_net(self, model):
        self.base_net.load_state_dict(torch.load(
            model, map_location=lambda storage, loc: storage),
                                      strict=True)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init_from_pretrained_ssd(self, model):
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("classification_headers")
                    or k.startswith("regression_headers"))
        }
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)

    # referenced https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py
    def vgg(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'C':
                layers += [
                    nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
                ]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(v),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
        conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
        layers += [
            pool5, conv6,
            nn.ReLU(inplace=True), conv7,
            nn.ReLU(inplace=True)
        ]
        return layers
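
GraphPath is referenced by the forward pass to tap an activation from inside a sub-module of the base network, but it is not defined in the snippet. In SSD implementations of this style it is typically a plain named tuple; a sketch under that assumption:

from collections import namedtuple

# s0: index of the base_net layer to enter; name: attribute holding its
# sub-layers; s1: split point where the detection header taps the activation.
GraphPath = namedtuple("GraphPath", ["s0", "name", "s1"])

With this definition, a source-layer entry such as GraphPath(14, "conv", 3) would run base_net[14].conv[:3], hand that activation to a header, and then continue through the remaining layers (the indexes here are hypothetical).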