def forward(self, inputs):
        """Run one recurrent step and evaluate the critic and actor heads.

        Args:
            inputs: tuple ``(ob, info, (convhx, convcx, hx, cx, frames))``
                holding the observation, its info, and the recurrent/frame
                state carried between calls.

        Returns:
            ``(critic_out, actor_out, actor_out2, state)`` where ``actor_out``
            is passed through softsign and ``state`` is the updated
            ``(convhx, convcx, hx, cx, frames)`` tuple.
        """
        ob, info, (convhx, convcx, hx, cx, frames) = inputs

        # Get the grid state from vectorized input
        x = self.senc_nngrid((ob, info))

        # Stack it
        x, frames = self.frame_stack((x, frames))

        # Resize to correct dims for convnet
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)

        # The last entry of convhx is flattened to a single row for the LSTM
        x = convhx[-1].view(1, -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        # Evaluate every head exactly once and return those results; the
        # previous code computed them, discarded them and ran each head again.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)

        return critic_out, actor_out, actor_out2, (convhx, convcx, hx, cx, frames)
Exemple #2
0
 def forward(self, x, state):
     """Blend the previous state with a candidate state through a gate.

     The input and previous state are concatenated along dim 1 and fed to
     two projections: one yields a gate rescaled from softsign's (-1, 1)
     range into (0, 1), the other yields the candidate state.
     """
     joined = th.cat((x, state), 1)
     # Softsign output shifted and halved so the gate lies in (0, 1)
     gate = (F.softsign(self.weight_zx(joined)) + 1) / 2
     candidate = F.softsign(self.weight_hx(joined))
     return (1 - gate) * state + gate * candidate
 def forward(self, audio, aux, do=False, last=False):
     """Run the stacked DCRNN blocks and project their summed outputs.

     ``do`` enables the dropout paths when ``self.do_prob > 0``; ``last``
     is accepted but not used in this method.
     """
     use_dropout = self.do_prob > 0 and do

     # Input features, dropping the first step of the last dim  (B x C x T)
     feats = self.upsampling(self.conv_aux(self.scale_in(aux)))[:, :, 1:]
     if use_dropout:
         feats = self.aux_drop(feats)
     if self.audio_in_flag:
         feats = torch.cat((feats, audio), 1)  # B x C x T

     # Initial hidden units, optionally from a convolved waveform
     causal_in = self.wav_conv(audio) if self.wav_conv_flag else audio
     h = F.softsign(self.causal(causal_in))  # B x C x T

     # DCRNN blocks: the first block seeds the running sum
     sum_out, h = self._dcrnn_forward(feats, h, self.in_x[0],
                                      self.dil_h[0], self.out_skip[0])
     for l in range(1, len(self.dil_facts)):
         # The dropout variant is only used on the last layer of each
         # dilation cycle, and only when dropout is enabled
         if use_dropout and (l + 1) % self.dilation_depth == 0:
             block = self._dcrnn_forward_drop
         else:
             block = self._dcrnn_forward
         out, h = block(feats, h, self.in_x[l], self.dil_h[l],
                        self.out_skip[l])
         sum_out += out

     # Output
     projected = self.out_2(F.relu(self.out_1(F.relu(sum_out))))
     return projected.transpose(1, 2)
Exemple #4
0
    def _forward(self, x, style_embed, speaker_embed, is_incremental):
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        if is_incremental:
            splitdim = -1
            x = self.conv.incremental_forward(x)
        else:
            splitdim = 1
            x = self.conv(x)
            # remove future time steps
            x = x[:, :, :residual.size(-1)] if self.causal else x

        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
        if self.style_proj is not None:
            softsign = F.softsign(self.style_proj(style_embed))
            # Since conv layer assumes BCT, we need to transpose
            softsign = softsign if is_incremental else softsign.transpose(1, 2)
            a = a + softsign

        if self.speaker_proj is not None:
            softsign = F.softsign(self.speaker_proj(speaker_embed))
            # Since conv layer assumes BCT, we need to transpose
            softsign = softsign if is_incremental else softsign.transpose(1, 2)
            a = a + softsign

        x = a * F.sigmoid(b)
        return (x + residual) * math.sqrt(0.5) if self.residual else x
Exemple #5
0
 def forward(self, x):
     """Feed ``x`` through fc1..fc5 and the output layer.

     The first three layers use softsign, the next two use ReLU, and the
     output layer is squashed with tanh.
     """
     for layer in (self.fc1, self.fc2, self.fc3):
         x = F.softsign(layer(x))
     for layer in (self.fc4, self.fc5):
         x = F.relu(layer(x))
     return torch.tanh(self.out(x))
 def forward(self, x):
     """Residual block using softsign after bn1 and after conv2.

     The skip path is ``self.shortcut`` applied to the activated input when
     the module has such an attribute, otherwise the raw input ``x``.
     """
     activated = F.softsign(self.bn1(x))
     if hasattr(self, 'shortcut'):
         shortcut = self.shortcut(activated)
     else:
         shortcut = x
     hidden = self.conv1(activated)
     out = F.softsign(self.conv2(self.bn2(hidden)))
     # In-place add, as in the original
     out += shortcut
     return out
Exemple #7
0
 def forward(self, x):
     """Two conv/softsign/max-pool stages followed by two linear layers.

     The pooled features are reshaped to rows of 4 * 4 * 32 values and the
     result is returned as log-softmax over dim 1.
     """
     for conv in (self.conv1, self.conv2):
         x = F.max_pool2d(F.softsign(conv(x)), 2, 2)
     flat = x.view(-1, 4 * 4 * 32)
     logits = self.fc2(F.softsign(self.fc1(flat)))
     return F.log_softmax(logits, dim=1)
Exemple #8
0
    def forward(self, inputs):
        """Evaluate the critic and the two actor heads on ``inputs``.

        Returns:
            ``(value, mean, spread)`` where ``mean`` is the softsign of the
            first actor head scaled by ``self.action_space.high[0]`` and
            ``spread`` is the second actor head mapped into (0, 1) plus 1e-5.
        """
        x = inputs

        x = self.lrelu1(self.fc1(x))
        x = self.lrelu2(self.fc2(x))
        x = self.lrelu3(self.fc3(x))
        x = self.lrelu4(self.fc4(x))

        value = self.critic_linear(x)
        mean = F.softsign(self.actor_linear(x))
        # torch.Tensor([...]) always lives on the CPU, which breaks the
        # multiplication once the module runs on another device. Build the
        # scale on the same device instead, keeping the default dtype that
        # torch.Tensor would have used.
        scale = torch.tensor([float(self.action_space.high[0])],
                             dtype=torch.get_default_dtype(),
                             device=mean.device)
        spread = 0.5 * (F.softsign(self.actor_linear2(x)) + 1.0) + 1e-5
        return value, scale * mean, spread
Exemple #9
0
 def forward(self, x):
     """Residual block whose activations go through ``hermite`` then softsign.

     ``NUM_POL`` is read from the enclosing module and forwarded to both
     ``hermite`` calls as ``num_pol``.
     """
     first = self.actib1.hermite(self.bn1(x), self.actib1_wts,
                                 num_pol=NUM_POL)
     out = F.softsign(first)
     if hasattr(self, 'shortcut'):
         shortcut = self.shortcut(out)
     else:
         shortcut = x
     out = self.conv1(out)
     # V2L Architecture: softsign is applied outside conv2 here
     second = self.actib2.hermite(self.bn2(out), self.actib2_wts,
                                  num_pol=NUM_POL)
     out = F.softsign(self.conv2(second))
     out += shortcut
     return out
Exemple #10
0
    def initialize_decoder_states(self, memory, spk_embeds, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory

        The rnn hidden/cell states come from dense projections of
        ``spk_embeds`` squashed with softsign; the attention buffers are
        zero tensors created from ``memory``.

        PARAMS
        ------
        memory: Encoder outputs
        spk_embeds: input to the two state-initialising dense layers
        mask: Mask for padded data if training, expects None for inference
        """
        def split_state(dense):
            # Projection columns are viewed as (row, feature, 2) and the
            # trailing pair is moved to the front: index 0 -> hidden,
            # index 1 -> cell.
            packed = F.softsign(dense(spk_embeds))
            packed = packed.view(-1, packed.size(1) // 2, 2)
            packed = packed.permute(2, 0, 1)
            return packed[0], packed[1]

        def zeros(*shape):
            return Variable(memory.data.new(*shape).zero_())

        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden, self.attention_cell = split_state(
            self.dense_init_attention_lstm)
        self.decoder_hidden, self.decoder_cell = split_state(
            self.dense_init_decoder_lstm)

        self.attention_weights = zeros(B, MAX_TIME)
        self.attention_weights_cum = zeros(B, MAX_TIME)
        self.attention_context = zeros(B, self.encoder_embedding_dim)

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask
 def adpW(self, x):
     """Map each row of ``x`` to a diagonal matrix of embedded values.

     ``x`` is detached first, passed through the two metric-embedding
     layers (softsign after each), and every resulting row becomes the
     diagonal of one matrix in the stacked output.
     """
     embedded = x.detach()
     for layer in (self.adp_metric_embedding1, self.adp_metric_embedding2):
         embedded = F.softsign(layer(embedded))
     # One diagonal matrix per row, stacked along a new leading dim
     return torch.stack([torch.diag(row) for row in embedded])
Exemple #12
0
def generate_lesion(preds, interval, num_classes=1, coef=None):
    """Turn per-step predictions into clamped class-probability volumes.

    ``preds`` is walked from its last element to its first. Each element is
    a ``(b, c, h, w)`` tensor viewed as ``(b, num_classes, -1, h, w)``;
    channel 0 is read as ``mu``, channel 1 as ``logvar`` and, when ``coef``
    is given, channel 2 as an offset inside the gating term.

    Args:
        preds: sequence of 4-D prediction tensors.
        interval: 2-D tensor; offsets are taken relative to column
            ``-t - 1`` for the ``t``-th prediction from the back.
        num_classes: size of the class dim used when viewing a prediction.
        coef: ``None`` (no gating) or one of ``"tanh"``, ``"softsign"``,
            ``"sigmoid"``.

    Returns:
        A list (back to front) of tensors laid out as ``b, k, t, h, w``
        with ``k = num_classes + 1``; index 0 along ``k`` holds one minus
        the sum of the others, and all values are clamped to >= 0.001.

    Raises:
        ValueError: if ``coef`` is not one of the supported values.
    """
    output = []
    for t in range(1, len(preds) + 1):
        b, _, h, w = preds[-t].size()
        cur_pred = preds[-t].view(b, num_classes, -1, h, w)
        x = (interval - interval[:, -t - 1:-t]).view(b, 1, -1, 1, 1)

        mu = cur_pred[:, :, 0:1]
        logvar = cur_pred[:, :, 1:2]
        if coef is None:
            phi = 1
        elif coef in ("tanh", "softsign", "sigmoid"):
            # The gate argument is shared by all three squashing choices.
            # It is only built here because channel 2 need not exist when
            # coef is None.
            gate_arg = torch.exp(-logvar) * x + cur_pred[:, :, 2:3]
            if coef == "sigmoid":
                phi = torch.sigmoid(gate_arg)
            else:
                squash = torch.tanh if coef == "tanh" else F.softsign
                # Rescale from (-1, 1) to (0, 1)
                phi = (squash(gate_arg) + 1) / 2
        else:
            raise ValueError(f"Unsupported coeffcient type {coef}")
        p = torch.exp(-torch.exp(-2 * logvar) * (x - mu)**2) * phi
        p = torch.cat((1 - p.sum((1), keepdim=True), p),
                      dim=1).clamp_min(0.001)
        output.append(p)
    # b, k, t, h, w, from back to front
    return output
Exemple #13
0
    def forward(self, inputs):
        """Conv stack, LSTM step, then critic/actor and optional aux heads.

        Args:
            inputs: tuple ``(x, (hx, cx))``.

        Returns:
            ``(critic, softsign(actor), actor2, (hx, cx),
            terminal_prediction, reward_prediction)``; each of the last two
            is ``None`` when its auxiliary head is ``None``.
        """
        x, (hx, cx) = inputs

        conv_stack = (
            (self.conv1, self.lrelu1),
            (self.conv2, self.lrelu2),
            (self.conv3, self.lrelu3),
            (self.conv4, self.lrelu4),
        )
        for conv, activation in conv_stack:
            x = activation(conv(x))

        # Flatten everything after the first dim before the LSTM
        hx, cx = self.lstm(x.view(x.size(0), -1), (hx, cx))
        features = hx

        # Auxiliary tasks are optional
        terminal_prediction = None
        if self.terminal_aux_head is not None:
            terminal_prediction = self.terminal_aux_head(features)

        reward_prediction = None
        if self.reward_aux_head is not None:
            reward_prediction = self.reward_aux_head(features)

        value = self.critic_linear(features)
        actor_mean = F.softsign(self.actor_linear(features))
        actor_second = self.actor_linear2(features)
        return (value, actor_mean, actor_second, (hx, cx),
                terminal_prediction, reward_prediction)
Exemple #14
0
    def forward(self, inputs):
        """Encode the observation to a grid, run the convnet and decode heads.

        Args:
            inputs: tuple ``(ob, info, frames)``.

        Returns:
            ``(critic_out, actor_out, actor_out2, frames)``; the critic
            output is averaged over its last dim after decoding.

        Side effect: ``self.input_size`` is refreshed from the encoder's
        observation-space shape on every call.
        """
        ob, info, frames = inputs

        # Get the grid state from vectorized input
        grid, anchor = self.senc_nngrid((ob, info))
        self.input_size = self.senc_nngrid.observation_space.shape

        # Stack it
        grid, frames = self.frame_stack((grid, frames, anchor))

        # Resize to correct dims for convnet
        channels = self.frame_stack.n_frames * self.input_size[0]
        grid = grid.view(grid.size(0), channels,
                         self.input_size[1], self.input_size[2])
        features = self._convforward(grid)

        # Compute action mean, action var and value grid
        value_grid = self.critic_linear(features)
        mean_grid = F.softsign(self.actor_linear(features))
        second_grid = self.actor_linear2(features)

        # Extract motor-specific values from action grid
        critic_out = self.adec_nngrid((value_grid, info))
        critic_out = critic_out.mean(-1, keepdim=True)
        actor_out = self.adec_nngrid((mean_grid, info))
        actor_out2 = self.adec_nngrid((second_grid, info))
        return critic_out, actor_out, actor_out2, frames
Exemple #15
0
    def forward(self, x1):
        """Two-level encoder/decoder with skip concatenations.

        Two pooled encoder stages feed a bottleneck; each decoder stage
        upsamples, concatenates the matching encoder output along dim 1,
        and refines it. The final output is batch-normalised and squashed
        with softsign.
        """
        def refine(t, *stages):
            # Each stage is (conv, bn); ReLU follows every batch norm
            for conv, bn in stages:
                t = F.relu(bn(conv(t)))
            return t

        x1 = refine(x1,
                    (self.conv1_1, self.bn1_1),
                    (self.conv1_2, self.bn1_2))
        x2 = refine(self.maxpool(x1),
                    (self.conv2_1, self.bn2_1),
                    (self.conv2_2, self.bn2_2))

        xup = refine(self.maxpool(x2),
                     (self.conv4_1, self.bn4_1),
                     (self.conv4_2, self.bn4_2))
        xup = self.bn4(self.upconv4(self.upsample(xup)))
        xup = self.bn4_out(torch.cat((x2, xup), 1))

        xup = refine(xup,
                     (self.conv7_1, self.bn7_1),
                     (self.conv7_2, self.bn7_2))
        xup = self.bn7(self.upconv7(self.upsample(xup)))
        xup = self.bn7_out(torch.cat((x1, xup), 1))

        xup = refine(xup,
                     (self.conv9_1, self.bn9_1),
                     (self.conv9_2, self.bn9_2),
                     (self.conv9_3, self.bn9_3))

        return F.softsign(self.bn9(xup))
Exemple #16
0
    def forward(self, inp):
        """Residual block that also records its two activations.

        Args:
            inp: tuple ``(x, layeracts)``; flattened clones of both
                softsign outputs are appended to ``layeracts`` in order.

        Returns:
            ``(out, layeracts)`` with the same list object that was passed in.
        """
        x, layeracts = inp

        def record(activation):
            # Store a detached-from-storage copy, one row per leading index
            layeracts.append(activation.clone().view(activation.size(0), -1))

        first = self.actib1.hermite(self.bn1(x), self.actib1_wts,
                                    num_pol=num_pol)
        out = F.softsign(first)
        record(out)

        if hasattr(self, 'shortcut'):
            shortcut = self.shortcut(out)
        else:
            shortcut = x

        out = self.conv1(out)
        second = self.actib2.hermite(self.bn2(out), self.actib2_wts,
                                     num_pol=num_pol)
        out = F.softsign(self.conv2(second))
        # Recorded before the skip connection is added
        record(out)

        out += shortcut
        return (out, layeracts)
 def forward(self, x):
     """Encode per-sample features and pool them with attention weights.

     ``batch_size`` and ``N_samples`` are read from the enclosing module.
     The prenet output is folded so that the conv sees one sequence per
     (batch, sample) pair, unfolded again, averaged over dim 2, and finally
     combined with attention-derived weights via a batched matmul.
     """
     x = self.prenet(x)
     x = x.view(batch_size * N_samples, x.size(2),
                x.size(3)).transpose(1, 2)
     x = self.conv(x)
     # Tensor.contiguous() returns a tensor instead of modifying its
     # receiver; the original discarded the result, making the call a
     # no-op. Keep the contiguous tensor so the following view operates on
     # the layout the call was meant to guarantee.
     x = x.transpose(1, 2).contiguous()
     x = x.view(batch_size, N_samples, x.size(1), x.size(2))
     x = x.mean(dim=2)
     conv_out = self.residual_conv(x)
     x = self.attention(x.contiguous())
     x = self.prohead(x)
     # NOTE(review): squeeze() without a dim also drops a leading dim of
     # size 1 — confirm this is safe when batch_size == 1.
     x = torch.squeeze(x)
     x = F.softsign(x)
     x = self.bn(x)
     x = torch.unsqueeze(x, dim=2)
     x = torch.bmm(x.transpose(1, 2), conv_out)
     x = torch.squeeze(x)
     return x
Exemple #18
0
 def forward(self, x):
     """Three conv/bn/leaky-ReLU stages, a dense layer and a scaled head.

     The head output is squashed with softsign and multiplied by 15.
     """
     stages = (
         (self.conv1, self.bn1),
         (self.conv2, self.bn2),
         (self.conv3, self.bn3),
     )
     for conv, bn in stages:
         x = F.leaky_relu(bn(conv(x)))
     flat = x.view(x.size(0), -1)
     hidden = F.leaky_relu(self.fc1(flat))
     return F.softsign(self.head(hidden)) * 15.
    def forward(self, src, trg, src_mask, trg_mask, real_flag=None):
        """Encode ``src`` and decode with the configured decoder type.

        NOTE: ``real_flag`` is only read when ``self.hp.dev_mode`` is set.

        Returns:
            ``(outputs, ctc_outputs, attn_enc_enc, attn_dec_dec,
            attn_dec_enc)``. ``ctc_outputs`` comes from ``self.out_ctc`` when
            ``self.use_ctc`` is set and is ``None`` otherwise; for the
            ``'ctc'`` decoder type the remaining decoder values are ``None``.
        """
        if self.hp.dev_mode:
            real_tts_emb = F.softsign(self.emb_real_tts(real_flag))
            src = src + real_tts_emb.unsqueeze(1)

        if self.frame_stacking:
            src = self.embedder(src)
        else:
            src, src_mask = self.cnn_encoder(src, src_mask)

        e_outputs, attn_enc_enc = self.encoder(src, src_mask)

        decoder_kind = self.decoder_type.lower()
        if decoder_kind == 'ctc':
            # NOTE(review): this projection is overwritten unconditionally by
            # the use_ctc handling below, so it never reaches the caller —
            # confirm whether it should be the fallback when use_ctc is off.
            ctc_outputs = self.out(e_outputs)
            outputs = attn_dec_dec = attn_dec_enc = None
        elif decoder_kind == 'transducer':
            outputs, attn_dec_dec, attn_dec_enc = self.decoder(trg, e_outputs)
        else:
            d_output, attn_dec_dec, attn_dec_enc = self.decoder(
                trg, e_outputs, src_mask, trg_mask)
            # Only the transformer decoder output goes through self.out
            if decoder_kind == 'transformer':
                outputs = self.out(d_output)
            else:
                outputs = d_output

        ctc_outputs = self.out_ctc(e_outputs) if self.use_ctc else None
        return outputs, ctc_outputs, attn_enc_enc, attn_dec_dec, attn_dec_enc
Exemple #20
0
    def forward(self, x):
        """Return ``(mu, sigma)`` from two leaky-ReLU layers and two heads.

        ``mu`` is the softsign of the mean head; ``sigma`` is the sigmoid of
        the sigma head scaled by ``self.action_std_init``.
        """
        hidden = F.leaky_relu(self.fc1(x))
        hidden = F.leaky_relu(self.fc2(hidden))

        mu = F.softsign(self.mu_head(hidden))
        sigma = self.action_std_init * torch.sigmoid(self.sigma_head(hidden))
        return mu, sigma
Exemple #21
0
    def forward(self, psi_x, psi_y):
        """Roll the controller out for ``n_steps`` and accumulate the loss.

        At each step the networks propose a control ``alpha`` (bounded by
        ``force_mag``), the state ``(x, y)`` is advanced ``n_substeps`` times
        with ``Heun_complex``, and the loss collects a discounted infidelity
        term plus a penalty on the squared control. ``device`` and
        ``generate_drive`` come from the enclosing module.

        Returns:
            ``(psi_x, psi_y, loss, fidelity_store, last_action_store)`` where
            ``loss`` is the mean over the ``n_par`` entries.
        """
        # reshape to broadcast in matmul
        x = psi_x.view(self.n_par, self.dim, 1)
        y = psi_y.view(self.n_par, self.dim, 1)
        alpha = torch.zeros(self.n_par, 1, 2 * self.nsite, device=device)

        loss = torch.zeros(self.n_par, device=device)
        fidelity_store = torch.zeros(self.n_steps, self.n_par, device=device)
        last_action_store = torch.zeros(2,
                                        self.n_steps,
                                        self.n_par,
                                        self.nsite,
                                        device=device)

        for step in range(self.n_steps):
            state_in = torch.cat((x, y), 1).transpose(1, 2)
            from_state = self.net_state(state_in)
            from_action = self.net_action(alpha / self.force_mag)

            combined = self.net_combine(from_state + from_action)
            alpha = self.force_mag * F.softsign(combined)
            alpha = torch.clamp(alpha, min=-self.force_mag, max=self.force_mag)

            # First nsite entries drive the x part, the rest the y part
            alphax = alpha[:, 0, :self.nsite]
            alphay = alpha[:, 0, self.nsite:]

            for _ in range(self.n_substeps):
                # H_Re and H_Im have dimensions (n_par, dim, dim)
                H_Re = self.H_0_dt + generate_drive(alphax, self.H_1_dt)
                H_Im = generate_drive(alphay, self.H_2_dt)

                x, y = self.Heun_complex(x, y, H_Re, H_Im)

            overlap_x = torch.matmul(self.target_x, x)
            overlap_y = torch.matmul(self.target_x, y)
            fidelity = (overlap_x**2 + overlap_y**2).squeeze()

            # add state infidelity
            loss += self.C1 * self.gamma**step * (1 - fidelity)
            # punish large actions
            abs_alpha = torch.mean(alpha**2, dim=2).squeeze()
            loss += self.C2 * abs_alpha

            # feed storage
            fidelity_store[step] = fidelity
            last_action_store[0, step] = alphax
            last_action_store[1, step] = alphay

        psi_x = x.view(self.n_par, self.dim)
        psi_y = y.view(self.n_par, self.dim)

        # Extra weight on the infidelity of the final step
        loss += self.C3 * (1 - fidelity_store[-1])
        loss = loss.mean()
        return psi_x, psi_y, loss, fidelity_store, last_action_store
Exemple #22
0
    def forward(self, inputs):
        """Dense layer, one LSTM step, then critic and actor heads.

        Args:
            inputs: tuple ``(x, (hx, cx))``.

        Returns:
            ``(critic, softsign(actor), actor2, (hx, cx))``.
        """
        x, (hx, cx) = inputs
        # The dense output is laid out as a single row of 128 values
        row = self.fc(x).view(1, 128)
        hx, cx = self.lstm(row, (hx, cx))

        value = self.critic_linear(hx)
        actor_mean = F.softsign(self.actor_linear(hx))
        actor_second = self.actor_linear2(hx)
        return value, actor_mean, actor_second, (hx, cx)
Exemple #23
0
    def forward(self, x, input_lengths, spk_embeds):
        """Convolve, run a packed LSTM seeded from ``spk_embeds``, add a bias.

        The LSTM's initial hidden and cell states are derived from a dense
        projection of ``spk_embeds``; a second projection of ``spk_embeds``
        is broadcast over time and added to the LSTM outputs.

        Args:
            x: input consumed by ``self.convolutions``; dims 1 and 2 are
                swapped before packing with ``batch_first=True``.
            input_lengths: tensor of sequence lengths.
            spk_embeds: input to the two dense projections.

        Returns:
            The padded LSTM outputs plus the time-repeated projection.
        """
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)  # (B, max_len, D)

        # pytorch tensor are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        packed = nn.utils.rnn.pack_padded_sequence(x,
                                                   input_lengths,
                                                   batch_first=True)

        self.lstm.flatten_parameters()

        # prepare initial state for lstm cell: the projection is viewed with
        # two trailing axes of size 2 which are moved to the front, so that
        # index 0 selects the hidden state and index 1 the cell state
        init_state = F.softsign(self.dense_init_lstm(spk_embeds))
        init_state = init_state.view(-1, init_state.size(1) // 4, 2, 2)
        init_state = init_state.permute(3, 2, 0, 1)
        h0 = init_state[0].contiguous()
        c0 = init_state[1].contiguous()

        outputs, _ = self.lstm(packed, (h0, c0))
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs,
                                                      batch_first=True)

        # expand the projected spk_embeds over every time step and add it
        after_lstm = F.softsign(self.dense_spk_embeds(spk_embeds))
        after_lstm = after_lstm.unsqueeze(1).repeat(1, outputs.size(1), 1)
        return after_lstm + outputs
    def forward(self,
                text_sequences,
                text_positions=None,
                lengths=None,
                speaker_embed=None):
        """Encode text into attention ``keys`` and ``values``.

        ``text_positions`` and ``lengths`` are accepted but not read here.
        A speaker embedding is required unless ``self.n_speakers == 1``.

        Returns:
            ``(keys, values)`` where ``values`` is the sum of ``keys`` and
            the input embedding scaled by sqrt(0.5).
        """
        assert self.n_speakers == 1 or speaker_embed is not None

        # embed text_sequences
        embedded = self.embed_tokens(text_sequences.long())
        embedded = F.dropout(embedded, p=self.dropout, training=self.training)

        # expand speaker embedding for all time steps
        speaker_embed_btc = expand_speaker_embed(embedded, speaker_embed)
        has_speaker = speaker_embed_btc is not None
        if has_speaker:
            speaker_embed_btc = F.dropout(speaker_embed_btc,
                                          p=self.dropout,
                                          training=self.training)
            embedded = embedded + F.softsign(
                self.speaker_fc1(speaker_embed_btc))

        input_embedding = embedded

        # B x T x C -> B x C x T
        hidden = embedded.transpose(1, 2)

        # 1D conv blocks; only the gated blocks take the speaker embedding
        for layer in self.convolutions:
            if isinstance(layer, Conv1dGLU):
                hidden = layer(hidden, speaker_embed_btc)
            else:
                hidden = layer(hidden)

        # Back to B x T x C
        keys = hidden.transpose(1, 2)

        if has_speaker:
            keys = keys + F.softsign(self.speaker_fc2(speaker_embed_btc))

        # scale gradients (this only affects backward, not forward)
        if self.apply_grad_scaling and self.num_attention_layers is not None:
            keys = GradMultiply.apply(keys,
                                      1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        values = (keys + input_embedding) * math.sqrt(0.5)

        return keys, values
    def forward(self, feature):
        """Return ``(value, mu, sigma)`` for ``feature``.

        ``mu`` is squashed with softsign unless ``self.head_name`` contains
        ``'discrete'``, in which case the raw actor output is returned.
        """
        value = self.critic_linear(feature)
        mu = self.actor_linear(feature)
        if 'discrete' not in self.head_name:
            mu = F.softsign(mu)
        sigma = self.actor_linear2(feature)
        return value, mu, sigma
 def forward(self, x, test=False):
     """Compute the action distribution parameters and sample an action.

     In continuous mode ``mu`` is the softsign of the first actor head and
     ``sigma`` the raw second head; otherwise ``mu`` is the raw first head
     and ``sigma`` is ``0``. Sampling is delegated to ``sample_action``.

     Returns:
         ``(action, entropy, log_prob)`` as produced by ``sample_action``.
     """
     mu = self.actor_linear(x)
     if self.continuous:
         mu = F.softsign(mu)
         sigma = self.actor_linear2(x)
     else:
         sigma = 0
     return sample_action(self.head_name, mu, sigma, test)
    def forward(self, x):
        """Return ``(mu, std)`` from layers looked up by ``self.name`` prefix.

        Layers are attributes named ``<name>_l1``, ``<name>_l2``,
        ``<name>_l3_mu`` and ``<name>_l3_std``. ``mu`` is softsign-squashed
        and clamped to ``[-max_action, max_action]``; ``std`` is softplus
        plus 1e-5.
        """
        def layer(suffix):
            return getattr(self, self.name + suffix)

        hidden = F.relu(layer("_l1")(x))
        hidden = F.relu(layer("_l2")(hidden))

        mu = torch.clamp(F.softsign(layer("_l3_mu")(hidden)),
                         -self.max_action, +self.max_action)
        std = F.softplus(layer("_l3_std")(hidden)) + 1e-5

        return mu, std
    def forward(self, x, test=False):
        """Compute the action distribution parameters and sample an action.

        In continuous mode ``mu`` is the softsign of the first actor head and
        ``sigma`` the raw second head; otherwise ``mu`` is the raw first head
        and ``sigma`` is a tensor of ones shaped like ``mu``. Sampling is
        delegated to ``sample_action``.

        Returns:
            ``(action, entropy, log_prob)`` as produced by ``sample_action``.
        """
        mu = self.actor_linear(x)
        if self.continuous:
            mu = F.softsign(mu)
            sigma = self.actor_linear2(x)
        else:
            sigma = torch.ones_like(mu)

        return sample_action(self.continuous, mu, sigma, self.device, test)
Exemple #29
0
 def forward(self, x):
     """Conv features -> fixed-size sequence -> LSTM -> dense output.

     The flattened conv output is reshaped to (16, 8192, 1) and its axes
     reordered to (1, 16, 8192) before the LSTM; only index -1 along dim 1
     of the softsign-squashed LSTM output is fed to the final layer.
     """
     feats = torch.flatten(self.conv(x), start_dim=1)
     # Same axis order as swapping dims (0, 2) and then (1, 2)
     seq = feats.reshape(16, 8192, 1).permute(2, 0, 1)
     out, (_h_n, _h_c) = self.lstm(seq, None)
     last = F.softsign(out)[:, -1, :]
     return self.fc(last)
Exemple #30
0
    def forward(self, inputs):
        """Four dense/leaky-ReLU stages followed by critic and actor heads.

        Args:
            inputs: tuple ``(x, (hx, cx))``; ``hx`` and ``cx`` are returned
                unchanged.

        Returns:
            ``(critic, softsign(actor), actor2, (hx, cx))``.
        """
        x, (hx, cx) = inputs

        stages = (
            (self.fc1, self.lrelu1),
            (self.fc2, self.lrelu2),
            (self.fc3, self.lrelu3),
            (self.fc4, self.lrelu4),
        )
        for dense, activation in stages:
            x = activation(dense(x))

        # Features are laid out as a single row of self.m1 values
        row = x.view(1, self.m1)

        value = self.critic_linear(row)
        actor_mean = F.softsign(self.actor_linear(row))
        actor_second = self.actor_linear2(row)
        return value, actor_mean, actor_second, (hx, cx)
Exemple #31
0
    def forward(self, text_sequences, speaker_embed=None):
        """Encode text into attention ``keys`` and ``values``.

        A speaker embedding is required unless ``self.n_speakers == 1``.

        Returns:
            ``(keys, values)`` where ``values`` is the sum of ``keys`` and
            the input embedding scaled by sqrt(0.5).
        """
        assert self.n_speakers == 1 or speaker_embed is not None

        # Text embedding
        embedded = self.text_embed(text_sequences.long())
        embedded = F.dropout(embedded, p=self.dropout, training=self.training)

        # expand speaker embedding for all time steps
        speaker_embed_btc = expand_speaker_embed(embedded, speaker_embed)
        has_speaker = speaker_embed_btc is not None
        if has_speaker:
            speaker_embed_btc = F.dropout(speaker_embed_btc,
                                          p=self.dropout,
                                          training=self.training)
            embedded = embedded + F.softsign(
                self.speaker_fc1(speaker_embed_btc))

        input_embedding = embedded

        # [B, T_max, channels] -> [B, channels, T_max]
        hidden = embedded.transpose(1, 2).contiguous()

        # 1D conv blocks; only the gated blocks take the speaker embedding
        for layer in self.convolutions:
            if isinstance(layer, Conv1DGLU):
                hidden = layer(hidden, speaker_embed_btc)
            else:
                hidden = layer(hidden)

        # [B, channels, T_max] -> [B, T_max, channels]
        keys = hidden.transpose(1, 2).contiguous()

        if has_speaker:
            keys = keys + F.softsign(self.speaker_fc2(speaker_embed_btc))

        # scale gradients (this only affects backward, not forward)
        if self.apply_grad_scaling and self.num_attention_layers is not None:
            keys = GradMultiply.apply(keys,
                                      1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        values = (keys + input_embedding) * math.sqrt(0.5)

        return keys, values
    def forward(self, x1):
        """Four-level encoder/decoder with centre-cropped skip connections.

        Each decoder level upsamples, crops the matching encoder output to
        the upsampled size (the size of dim 2 is used for both spatial
        dims), concatenates along dim 1 and refines. The final output is
        batch-normalised and squashed with softsign.
        """
        def double_conv(t, first, second):
            return second(F.relu(first(t)))

        def crop_like(skip, ref):
            # Centre crop of the two trailing dims to ref's dim-2 size
            side = ref.size(2)
            start = (skip.size(2) - side) // 2
            stop = start + side
            return skip[:, :, start:stop, start:stop]

        # Encoder
        x1 = F.relu(self.bn1(double_conv(x1, self.conv1_1, self.conv1_2)))
        x2 = F.relu(self.bn2(
            double_conv(self.maxpool(x1), self.conv2_1, self.conv2_2)))
        x3 = F.relu(self.bn3(
            double_conv(self.maxpool(x2), self.conv3_1, self.conv3_2)))
        x4 = F.relu(self.bn4(
            double_conv(self.maxpool(x3), self.conv4_1, self.conv4_2)))
        # x5: bottleneck, no batch norm
        xup = F.relu(
            double_conv(self.maxpool(x4), self.conv5_1, self.conv5_2))

        # x6: merge with x4
        xup = self.bn5(self.upconv5(self.upsample(xup)))
        x4 = crop_like(x4, xup)
        xup = self.bn5_out(torch.cat((x4, xup), 1))
        xup = F.relu(double_conv(xup, self.conv6_1, self.conv6_2))

        # x7: merge with x3
        xup = self.bn6(self.upconv6(self.upsample(xup)))
        x3 = crop_like(x3, xup)
        xup = self.bn6_out(torch.cat((x3, xup), 1))
        xup = F.relu(double_conv(xup, self.conv7_1, self.conv7_2))

        # x8: merge with x2
        xup = self.bn7(self.upconv7(self.upsample(xup)))
        x2 = crop_like(x2, xup)
        xup = self.bn7_out(torch.cat((x2, xup), 1))
        xup = F.relu(double_conv(xup, self.conv8_1, self.conv8_2))

        # x9: merge with x1, then three convs
        xup = self.bn8(self.upconv8(self.upsample(xup)))
        x1 = crop_like(x1, xup)
        xup = self.bn8_out(torch.cat((x1, xup), 1))
        xup = F.relu(self.conv9_1(xup))
        xup = F.relu(self.conv9_2(xup))
        xup = F.relu(self.conv9_3(xup))

        return F.softsign(self.bn9(xup))
    def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
            epochs=1000, restart=0, gpu=False):

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(log_dir=log_dir,
                                      restart=restart > 0)

        batch_generator = SpeechSegmentGenerator(
            feature_extraction,
            per_label=self.per_label, per_fold=self.per_fold,
            duration=self.duration, parallel=self.parallel)
        batches = batch_generator(protocol, subset=subset)
        batch = next(batches)

        batches_per_epoch = batch_generator.batches_per_epoch

        if restart > 0:
            weights_pt = checkpoint.WEIGHTS_PT.format(
                log_dir=log_dir, epoch=restart)
            model.load_state_dict(torch.load(weights_pt))

        if gpu:
            model = model.cuda()

        model.internal = False

        parameters = list(model.parameters())

        if self.variant in [2, 3, 4, 5, 6, 7, 8]:

            # norm batch-normalization
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=True)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        if self.variant in [9]:
            # norm batch-normalization
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        if self.variant in [5, 6, 7]:
            self.positive_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            self.negative_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.positive_bn = self.positive_bn.cuda()
                self.negative_bn = self.negative_bn.cuda()
            parameters += list(self.positive_bn.parameters())
            parameters += list(self.negative_bn.parameters())

        if self.variant in [8, 9]:

            self.delta_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.delta_bn = self.delta_bn.cuda()
            parameters += list(self.delta_bn.parameters())

        optimizer = Adam(parameters)
        if restart > 0:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(
                log_dir=log_dir, epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        epoch = restart if restart > 0 else -1
        while True:
            epoch += 1
            if epoch > epochs:
                break

            loss_avg, tloss_avg, closs_avg = 0., 0., 0.

            if epoch % 5 == 0:
                log_positive = []
                log_negative = []
                log_delta = []
                log_norm = []

            desc = 'Epoch #{0}'.format(epoch)
            for i in tqdm(range(batches_per_epoch), desc=desc):

                model.zero_grad()

                batch = next(batches)

                X = batch['X']
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))

                if gpu:
                    X = X.cuda()

                fX = model(X)

                # pre-compute pairwise distances
                distances = self.pdist(fX)

                # sample triplets
                triplets = getattr(self, 'batch_{0}'.format(self.sampling))
                anchors, positives, negatives = triplets(batch['y'], distances)

                # compute triplet loss
                tlosses, deltas, pos_index, neg_index  = self.triplet_loss(
                    distances, anchors, positives, negatives,
                    return_delta=True)

                tloss = torch.mean(tlosses)

                if self.variant == 1:

                    closses = F.sigmoid(
                        F.softsign(deltas) * torch.norm(fX[anchors], 2, 1, keepdim=True))

                    # if d(a, p) < d(a, n) (i.e. good case)
                    #   --> sign(delta) < 0
                    #   --> loss decreases when norm increases.
                    #       i.e. encourages longer anchor

                    # if d(a, p) > d(a, n) (i.e. bad case)
                    #   --> sign(delta) > 0
                    #   --> loss increases when norm increases
                    #       i.e. encourages shorter anchor

                elif self.variant == 2:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))

                    confidence = (norms_[anchors] + norms_[positives] + norms_[negatives]) / 3
                    # if |x| is average
                    #    --> normalized |x| = 0
                    #    --> confidence = 0.5

                    # if |x| is bigger than average
                    #    --> normalized |x| >> 0
                    #    --> confidence = 1

                    # if |x| is smaller than average
                    #    --> normalized |x| << 0
                    #    --> confidence = 0

                    correctness = F.sigmoid(-deltas / np.pi * 6)
                    # if d(a, p) = d(a, n) (i.e. uncertain case)
                    #    --> correctness = 0.5

                    # if d(a, p) - d(a, n) = -𝛑 (i.e. best possible case)
                    #    --> correctness = 1

                    # if d(a, p) - d(a, n) = +𝛑 (i.e. worst possible case)
                    #    --> correctness = 0

                    closses = torch.abs(confidence - correctness)
                    # small if (and only if) confidence & correctness agree

                elif self.variant == 3:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness -> 0 for delta > 0 (about 0 at delta == +pi)

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 4:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) ** 1/3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness -> 0 for delta > 0 (about 0 at delta == +pi)

                    # delta = pos - neg ... should be < 0

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 5:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = .5 * (torch.abs(confidence_pos - correctness_pos) \
                                  + torch.abs(confidence_neg - correctness_neg))

                elif self.variant == 6:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    closses = torch.abs(confidence_pos - correctness_pos)

                elif self.variant == 7:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = torch.abs(confidence_neg - correctness_neg)

                elif self.variant in [8, 9]:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-self.delta_bn(deltas))
                    closses = torch.abs(confidence - correctness)

                closs = torch.mean(closses)

                if epoch % 5 == 0:

                    if gpu:
                        fX_npy = fX.data.cpu().numpy()
                        pdist_npy = distances.data.cpu().numpy()
                        delta_npy = deltas.data.cpu().numpy()
                    else:
                        fX_npy = fX.data.numpy()
                        pdist_npy = distances.data.numpy()
                        delta_npy = deltas.data.numpy()

                    log_norm.append(np.linalg.norm(fX_npy, axis=1))

                    same_speaker = pdist(batch['y'].reshape((-1, 1)), metric='chebyshev') < 1
                    log_positive.append(pdist_npy[np.where(same_speaker)])
                    log_negative.append(pdist_npy[np.where(~same_speaker)])

                    log_delta.append(delta_npy)

                # log loss
                if gpu:
                    tloss_ = float(tloss.data.cpu().numpy())
                    closs_ = float(closs.data.cpu().numpy())
                else:
                    tloss_ = float(tloss.data.numpy())
                    closs_ = float(closs.data.numpy())
                tloss_avg += tloss_
                closs_avg += closs_
                loss_avg += tloss_ + closs_

                loss = tloss + closs
                loss.backward()
                optimizer.step()

            tloss_avg /= batches_per_epoch
            writer.add_scalar('tloss', tloss_avg, global_step=epoch)

            closs_avg /= batches_per_epoch
            writer.add_scalar('closs', closs_avg, global_step=epoch)

            loss_avg /= batches_per_epoch
            writer.add_scalar('loss', loss_avg, global_step=epoch)

            if epoch % 5 == 0:

                log_positive = np.hstack(log_positive)
                writer.add_histogram(
                    'embedding/pairwise_distance/positive', log_positive,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))
                log_negative = np.hstack(log_negative)

                writer.add_histogram(
                    'embedding/pairwise_distance/negative', log_negative,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))

                _, _, _, eer = det_curve(
                    np.hstack([np.ones(len(log_positive)), np.zeros(len(log_negative))]),
                    np.hstack([log_positive, log_negative]), distances=True)
                writer.add_scalar('eer', eer, global_step=epoch)

                log_norm = np.hstack(log_norm)
                writer.add_histogram(
                    'norm', log_norm,
                    global_step=epoch, bins='doane')

                log_delta = np.vstack(log_delta)
                writer.add_histogram(
                    'delta', log_delta,
                    global_step=epoch, bins='doane')

            checkpoint.on_epoch_end(epoch, model, optimizer)

            if hasattr(self, 'norm_bn'):
                confidence_pt = self.CONFIDENCE_PT.format(
                    log_dir=log_dir, epoch=epoch)
                torch.save(self.norm_bn.state_dict(), confidence_pt)