Example #1
from typing import List

from torch import nn


def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList,
                layers_to_copy: List[int]) -> None:
    # Gather the selected teacher layers and load their weights into the student.
    layers_to_copy = nn.ModuleList(
        [l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
    assert len(student_layers) == len(
        layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
    student_layers.load_state_dict(layers_to_copy.state_dict())
Example #2
from typing import List

from torch import nn


def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList,
                layers_to_copy: List[int]) -> None:
    # Same routine with generic source/destination naming: select layers by
    # index, then copy their state dict into the destination ModuleList.
    layers_to_copy = nn.ModuleList(
        [l for i, l in enumerate(src_layers) if i in layers_to_copy])
    assert len(dest_layers) == len(
        layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
    dest_layers.load_state_dict(layers_to_copy.state_dict())
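
Both variants do the same thing: collect the source layers selected by index into a temporary nn.ModuleList and load its state dict into the destination list, so the i-th destination layer receives the weights of the i-th selected source layer. A minimal usage sketch (the toy Linear blocks and the chosen indices below are illustrative only, not part of the original code):

import torch.nn as nn

# A toy "teacher" with 6 blocks and a "student" with 3 blocks.
teacher = nn.ModuleList([nn.Linear(16, 16) for _ in range(6)])
student = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])

# Copy every other teacher block into the student.
copy_layers(teacher, student, layers_to_copy=[0, 2, 4])

# Student block i now holds the weights of teacher block layers_to_copy[i].
assert all(
    (s.weight == t.weight).all()
    for s, t in zip(student, (teacher[i] for i in (0, 2, 4))))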
Example #3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d, Conv2d, ModuleList, ReLU, Sequential

# GraphPath, box_utils and _xavier_init_ are provided by the surrounding SSD codebase.


class VGGSSD(nn.Module):
    def __init__(self, num_classes, device, is_test=False, config=None):
        super(VGGSSD, self).__init__()

        # VGG-16 configuration: numbers are conv output channels, 'M' and 'C'
        # are max-pool layers ('C' uses ceil_mode).
        vgg_config = [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
            512, 512, 512
        ]

        self.num_classes = num_classes
        self.device = device

        self.base_net = ModuleList(self.vgg(vgg_config))

        # The first detection feature map is taken after the first 23 base
        # layers (the conv4_3 ReLU) and passed through a BatchNorm2d; the
        # second is the output of the full base net.
        self.source_layer_indexes = [
            (23, BatchNorm2d(512)),
            len(self.base_net),
        ]
        self.extras = ModuleList([
            Sequential(
                Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=256,
                       out_channels=512,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=512, out_channels=128, kernel_size=1),
                ReLU(),
                Conv2d(in_channels=128,
                       out_channels=256,
                       kernel_size=3,
                       stride=2,
                       padding=1), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(), Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3), ReLU()),
            Sequential(
                Conv2d(in_channels=256, out_channels=128, kernel_size=1),
                ReLU(), Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3), ReLU())
        ])

        self.classification_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=4 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * num_classes,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * num_classes,
                   kernel_size=1,
                   padding=0),
        ])

        self.regression_headers = ModuleList([
            Conv2d(in_channels=512,
                   out_channels=4 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=1024,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=512,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=6 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * 4,
                   kernel_size=3,
                   padding=1),
            Conv2d(in_channels=256,
                   out_channels=4 * 4,
                   kernel_size=1,
                   padding=0),
        ])

        self.is_test = is_test
        self.config = config
        self.source_layer_add_ons = nn.ModuleList([
            t[1] for t in self.source_layer_indexes
            if isinstance(t, tuple) and not isinstance(t, GraphPath)
        ])
        self.priors = config.priors.to(self.device)

    def forward(self, x):
        confidences = []
        locations = []
        start_layer_index = 0
        header_index = 0
        # Run the base net up to each source layer, apply any add-on module or
        # GraphPath detour, then feed the resulting feature map to a detection head.
        for end_layer_index in self.source_layer_indexes:
            if isinstance(end_layer_index, GraphPath):
                path = end_layer_index
                end_layer_index = end_layer_index.s0
                added_layer = None
            elif isinstance(end_layer_index, tuple):
                added_layer = end_layer_index[1]
                end_layer_index = end_layer_index[0]
                path = None
            else:
                added_layer = None
                path = None
            for layer in self.base_net[start_layer_index:end_layer_index]:
                x = layer(x)
            if added_layer:
                y = added_layer(x)
            else:
                y = x
            if path:
                sub = getattr(self.base_net[end_layer_index], path.name)
                for layer in sub[:path.s1]:
                    x = layer(x)
                y = x
                for layer in sub[path.s1:]:
                    x = layer(x)
                end_layer_index += 1
            start_layer_index = end_layer_index
            confidence, location = self.compute_header(header_index, y)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        for layer in self.base_net[end_layer_index:]:
            x = layer(x)

        # Each extra block produces an additional, smaller feature map with its own head.
        for layer in self.extras:
            x = layer(x)
            confidence, location = self.compute_header(header_index, x)
            header_index += 1
            confidences.append(confidence)
            locations.append(location)

        confidences = torch.cat(confidences, 1)
        locations = torch.cat(locations, 1)

        if self.is_test:
            confidences = F.softmax(confidences, dim=2)
            boxes = box_utils.convert_locations_to_boxes(
                locations, self.priors, self.config.center_variance,
                self.config.size_variance)
            boxes = box_utils.center_form_to_corner_form(boxes)
            return confidences, boxes
        else:
            return confidences, locations

    def compute_header(self, i, x):
        # Reshape the head outputs to (batch, anchors_at_this_scale, num_classes)
        # and (batch, anchors_at_this_scale, 4).
        confidence = self.classification_headers[i](x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)

        location = self.regression_headers[i](x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)

        return confidence, location

    def init_from_base_net(self, model):
        self.base_net.load_state_dict(torch.load(
            model, map_location=lambda storage, loc: storage),
                                      strict=True)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init_from_pretrained_ssd(self, model):
        state_dict = torch.load(model,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("classification_headers")
                    or k.startswith("regression_headers"))
        }
        model_dict = self.state_dict()
        model_dict.update(state_dict)
        self.load_state_dict(model_dict)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def init(self):
        self.base_net.apply(_xavier_init_)
        self.source_layer_add_ons.apply(_xavier_init_)
        self.extras.apply(_xavier_init_)
        self.classification_headers.apply(_xavier_init_)
        self.regression_headers.apply(_xavier_init_)

    def load(self, model):
        self.load_state_dict(
            torch.load(model, map_location=lambda storage, loc: storage))

    def save(self, model_path):
        torch.save(self.state_dict(), model_path)

    # referenced https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py
    def vgg(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'C':
                layers += [
                    nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
                ]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(v),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
        conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
        layers += [
            pool5, conv6,
            nn.ReLU(inplace=True), conv7,
            nn.ReLU(inplace=True)
        ]
        return layers
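
For a quick shape check, the class can be exercised with a stand-in config: a SimpleNamespace carrying priors and the two variances, which the real project supplies from its SSD configuration module. The sketch below assumes GraphPath and the other helpers the class references are importable from that codebase; the 8732 anchors follow from the six feature maps produced for a 300x300 input, 21 classes is a VOC-style choice, and the variance values are placeholders.

import types

import torch

# Stand-in for the project's SSD config object; only the attributes read by
# VGGSSD.__init__ and the test-time branch are provided.
cfg = types.SimpleNamespace(priors=torch.zeros(8732, 4),
                            center_variance=0.1,
                            size_variance=0.2)

model = VGGSSD(num_classes=21, device=torch.device("cpu"), config=cfg)

x = torch.randn(1, 3, 300, 300)
confidences, locations = model(x)
print(confidences.shape)  # torch.Size([1, 8732, 21])
print(locations.shape)    # torch.Size([1, 8732, 4])

With is_test=True the same forward pass would instead return softmaxed class scores and corner-form boxes decoded against cfg.priors.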