def forward(self, signal):
        """Compute class logits for a batch of signals with the backbone network.

        Depending on ``self.config.data.features`` the input is first converted
        by ``compute_torch_stft`` and log-compressed (STFT features), or
        additionally projected through ``self.filterbanks`` before the log
        (mel features). Any other feature type is passed through unchanged.

        Args:
            signal: Input tensor. For STFT/mel features its last dimension is
                squeezed before the transform (presumably raw audio shaped
                (batch, samples, 1) -- TODO confirm against the data loader).

        Returns:
            dict with the single key ``"class_logits"`` holding the output of
            ``self.output_transform``.
        """
        feature_cfg = self.config.data.features
        wants_stft = is_stft(feature_cfg)
        wants_mel = is_mel(feature_cfg)

        # Both spectral feature types start from the same STFT computation.
        if wants_stft or wants_mel:
            signal = compute_torch_stft(signal.squeeze(-1), feature_cfg)

        if wants_stft:
            # Small offset avoids log(0).
            signal = torch.log(signal + 1e-4)

        if wants_mel:
            # A conv1d with a trailing kernel dimension of size 1; presumably
            # a per-frame projection onto the mel bands -- TODO confirm the
            # filterbank layout.
            mel_kernel = self.filterbanks.unsqueeze(-1)
            signal = torch.log(nn.functional.conv1d(signal, mel_kernel) + 1e-4)

        # Insert a channel axis and replicate it three times (presumably
        # because the backbone expects 3-channel input), then normalise.
        image_like = signal.unsqueeze(1).repeat(1, 3, 1, 1)
        image_like = self.input_norm(image_like)

        feature_map = self.backbone.features(image_like)

        # Drop the two trailing singleton dimensions left by the pooling layer.
        pooled = self.global_maxpool(feature_map).squeeze(-1).squeeze(-1)

        return {"class_logits": self.output_transform(pooled)}
    def __init__(self, experiment, device="cuda"):
        """Build the backbone classifier described by ``experiment.config``.

        Config fields read here: ``data.features``, ``data._n_classes``,
        ``network.backbone`` and ``network.output_dropout``.

        Args:
            experiment: Object exposing a ``config`` attribute; it is also
                kept on the module as ``self.experiment``.
            device: Device the mel filterbanks and the module itself are
                moved to.

        Raises:
            ValueError: If ``config.network.backbone`` is neither
                ``"resnet18"`` nor ``"resnet34"``.
        """
        super().__init__()

        self.device = device

        self.experiment = experiment
        self.config = experiment.config

        if is_mel(self.config.data.features):
            # Stored as a plain tensor attribute (not a parameter or buffer),
            # so it has to be placed on the target device explicitly.
            self.filterbanks = torch.from_numpy(
                make_mel_filterbanks(self.config.data.features)).to(self.device)

        # Normalises the 3-channel input fed to the backbone.
        self.input_norm = nn.BatchNorm2d(3)

        backbone_name = self.config.network.backbone
        if backbone_name == "resnet18":
            self.backbone = resnet18(pretrained=None)
        elif backbone_name == "resnet34":
            self.backbone = resnet34(pretrained=None)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.backbone`` a few lines further down.
            raise ValueError(
                "Unsupported backbone {!r}; expected 'resnet18' or "
                "'resnet34'".format(backbone_name))

        self.global_maxpool = nn.AdaptiveMaxPool2d(1)

        # Width of the pooled feature vector, taken from the backbone's own
        # classification layer.
        total_depth = self.backbone.last_linear.in_features

        self.output_transform = nn.Sequential(
            nn.BatchNorm1d(total_depth),
            nn.Linear(total_depth, total_depth),
            nn.BatchNorm1d(total_depth),
            nn.PReLU(total_depth),
            nn.Dropout(p=self.config.network.output_dropout),
            nn.Linear(total_depth, self.config.data._n_classes)
        )

        self.to(self.device)
# Example no. 3
# 0
    def __init__(self, experiment, device="cuda"):
        """Build the convolutional stack, optional GRU heads and output head.

        Config fields read here: ``data.features`` and, under ``network``,
        ``num_conv_blocks``, ``growth_rate``, ``conv_base_depth``,
        ``start_deep_supervision_on``, ``aggregation_type`` (``"max"`` or
        ``"rnn"``) and ``output_dropout``.

        Args:
            experiment: Object exposing a ``config`` attribute; it is also
                kept on the module as ``self.experiment``.
            device: Device the mel filterbanks and the module itself are
                moved to.
        """
        super().__init__()

        self.device = device

        self.experiment = experiment
        self.config = experiment.config

        if is_mel(self.config.data.features):
            mel_matrix = make_mel_filterbanks(self.config.data.features)
            self.filterbanks = torch.from_numpy(mel_matrix).to(self.device)

        self.conv_modules = torch.nn.ModuleList()
        self.rnns = torch.nn.ModuleList()

        net_cfg = self.config.network

        # Hidden size of each GRU direction; a bidirectional GRU therefore
        # contributes twice this many features.
        rnn_hidden = 128

        # Width of the concatenated feature vector fed to the output head.
        total_depth = 0

        # The first block consumes a 2-channel input; every later block
        # consumes the previous block's output channels.
        in_channels = 2

        for block_idx in range(net_cfg.num_conv_blocks):
            out_channels = int(net_cfg.growth_rate**block_idx *
                               net_cfg.conv_base_depth)

            # Blocks from ``start_deep_supervision_on`` onwards each
            # contribute one feature vector to the output head.
            if block_idx >= net_cfg.start_deep_supervision_on:
                if net_cfg.aggregation_type == "max":
                    total_depth += out_channels
                elif net_cfg.aggregation_type == "rnn":
                    total_depth += rnn_hidden * 2
                    sequence_norm = nn.LayerNorm((out_channels, ))
                    gru = nn.GRU(out_channels,
                                 rnn_hidden,
                                 batch_first=True,
                                 bidirectional=True)
                    self.rnns.append(nn.Sequential(sequence_norm, gru))

            self.conv_modules.append(
                nn.Sequential(
                    nn.BatchNorm2d(in_channels),
                    nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              padding=1),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                    nn.BatchNorm2d(out_channels),
                    nn.PReLU(out_channels),
                    ResnetBlock2d(out_channels)))

            in_channels = out_channels

        self.global_maxpool = nn.AdaptiveMaxPool2d(1)

        # Output head ending in a 2-way output layer.
        self.output_transform = nn.Sequential(
            nn.BatchNorm1d(total_depth),
            nn.Linear(total_depth, total_depth),
            nn.BatchNorm1d(total_depth),
            nn.PReLU(total_depth),
            nn.Dropout(p=net_cfg.output_dropout),
            nn.Linear(total_depth, 2),
        )

        self.to(self.device)
    def forward(self, signal):
        """Compute class logits from features gathered across the conv blocks.

        Depending on ``self.config.data.features`` the input is first converted
        by ``compute_torch_stft`` and log-compressed (STFT features), or
        additionally projected through ``self.filterbanks`` before the log
        (mel features). Any other feature type is passed through unchanged.
        The result then runs through ``self.conv_modules``. Every block from
        index ``config.network.start_deep_supervision_on`` onwards contributes
        one feature vector, and the concatenation of those vectors feeds
        ``self.output_transform``.

        Args:
            signal: Input tensor. For STFT/mel features its last dimension is
                squeezed before the transform (presumably raw audio shaped
                (batch, samples, 1) -- TODO confirm against the data loader).

        Returns:
            dict with the single key ``"class_logits"`` holding the output of
            ``self.output_transform``.
        """
        feature_cfg = self.config.data.features
        wants_stft = is_stft(feature_cfg)
        wants_mel = is_mel(feature_cfg)

        # Both spectral feature types start from the same STFT computation.
        if wants_stft or wants_mel:
            signal = compute_torch_stft(signal.squeeze(-1), feature_cfg)

        if wants_stft:
            # Small offset avoids log(0).
            signal = torch.log(signal + 1e-4)

        if wants_mel:
            # A conv1d with a trailing kernel dimension of size 1; presumably
            # a per-frame projection onto the mel bands -- TODO confirm the
            # filterbank layout.
            mel_kernel = self.filterbanks.unsqueeze(-1)
            signal = torch.log(nn.functional.conv1d(signal, mel_kernel) + 1e-4)

        # Insert a channel axis, then let the helper add its encoding.
        h = self._add_frequency_encoding(signal.unsqueeze(1))

        collected = []
        for block_idx, block in enumerate(self.conv_modules):
            h = block(h)

            supervision_start = self.config.network.start_deep_supervision_on
            if block_idx < supervision_start:
                continue

            aggregation = self.config.network.aggregation_type
            if aggregation == "max":
                # Drop the two trailing singleton dimensions left by pooling.
                collected.append(
                    self.global_maxpool(h).squeeze(-1).squeeze(-1))
            elif aggregation == "rnn":
                # Average over dim 2 (presumably the frequency axis -- TODO
                # confirm) and move channels last for the batch-first GRU.
                sequence = h.mean(dim=2).permute(0, 2, 1)
                # self.rnns only holds entries for supervised blocks, so its
                # index is offset by the first supervised block.
                _, final_state = self.rnns[block_idx - supervision_start](
                    sequence)
                # Put batch first and flatten the remaining dimensions of
                # the final hidden state into one vector per example.
                batch_size = sequence.size(0)
                collected.append(
                    final_state.permute(1, 0, 2).contiguous().view(
                        batch_size, -1))

        merged = torch.cat(collected, dim=-1)

        return {"class_logits": self.output_transform(merged)}