def forward(self, inputs):
    ob, info, (convhx, convcx, hx, cx, frames) = inputs
    # Get the grid state from vectorized input
    x = self.senc_nngrid((ob, info))
    # Stack it
    x, frames = self.frame_stack((x, frames))
    # Resize to correct dims for convnet
    batch_size = x.size(0)
    x = x.view(batch_size, self.frame_stack.n_frames * self.input_size[0],
               self.input_size[1], self.input_size[2])
    convhx, convcx = self._convlstmforward(x, convhx, convcx)
    x = convhx[-1]
    x = x.view(1, -1)
    hx, cx = self.lstm(x, (hx, cx))
    x = hx
    # Compute each head once and return the results
    critic_out = self.critic_linear(x)
    actor_out = F.softsign(self.actor_linear(x))
    actor_out2 = self.actor_linear2(x)
    return critic_out, actor_out, actor_out2, (convhx, convcx, hx, cx, frames)
def forward(self, x, state):
    u = th.cat((x, state), 1)  # concatenation of input and previous state
    za = F.softsign(self.weight_zx(u))
    z = (za + 1) / 2  # rescale softsign output from (-1, 1) to (0, 1): an update gate
    g = F.softsign(self.weight_hx(u))  # candidate cell state
    h = (1 - z) * state + z * g
    return h
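# Usage sketch (an assumption, not from the source): treating the cell above as
# a minimal gated recurrent unit where weight_zx and weight_hx are nn.Linear
# layers mapping (input_size + hidden_size) -> hidden_size, rolled over a
# sequence one step at a time.
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class SoftsignGatedCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.weight_zx = nn.Linear(input_size + hidden_size, hidden_size)
        self.weight_hx = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, state):
        u = th.cat((x, state), 1)
        z = (F.softsign(self.weight_zx(u)) + 1) / 2  # update gate in (0, 1)
        g = F.softsign(self.weight_hx(u))            # candidate state in (-1, 1)
        return (1 - z) * state + z * g

cell = SoftsignGatedCell(input_size=8, hidden_size=16)
h = th.zeros(4, 16)             # batch of 4
for x_t in th.randn(10, 4, 8):  # 10 time steps
    h = cell(x_t, h)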
def forward(self, audio, aux, do=False, last=False):
    # Input features
    x = self.upsampling(self.conv_aux(self.scale_in(aux)))[:, :, 1:]  # B x C x T
    if self.do_prob > 0 and do:
        x = self.aux_drop(x)
    if self.audio_in_flag:
        x = torch.cat((x, audio), 1)  # B x C x T
    # Initial hidden units
    if not self.wav_conv_flag:
        h = F.softsign(self.causal(audio))  # B x C x T
    else:
        h = F.softsign(self.causal(self.wav_conv(audio)))  # B x C x T
    # DCRNN blocks
    sum_out, h = self._dcrnn_forward(x, h, self.in_x[0], self.dil_h[0],
                                     self.out_skip[0])
    if self.do_prob > 0 and do:
        for l in range(1, len(self.dil_facts)):
            if (l + 1) % self.dilation_depth == 0:
                out, h = self._dcrnn_forward_drop(x, h, self.in_x[l],
                                                  self.dil_h[l], self.out_skip[l])
            else:
                out, h = self._dcrnn_forward(x, h, self.in_x[l],
                                             self.dil_h[l], self.out_skip[l])
            sum_out += out
    else:
        for l in range(1, len(self.dil_facts)):
            out, h = self._dcrnn_forward(x, h, self.in_x[l], self.dil_h[l],
                                         self.out_skip[l])
            sum_out += out
    # Output
    return self.out_2(F.relu(self.out_1(F.relu(sum_out)))).transpose(1, 2)
def _forward(self, x, style_embed, speaker_embed, is_incremental):
    residual = x
    x = F.dropout(x, p=self.dropout, training=self.training)
    if is_incremental:
        splitdim = -1
        x = self.conv.incremental_forward(x)
    else:
        splitdim = 1
        x = self.conv(x)
        # remove future time steps
        x = x[:, :, :residual.size(-1)] if self.causal else x

    a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
    if self.style_proj is not None:
        softsign = F.softsign(self.style_proj(style_embed))
        # Since the conv layer assumes BCT, we need to transpose
        softsign = softsign if is_incremental else softsign.transpose(1, 2)
        a = a + softsign
    if self.speaker_proj is not None:
        softsign = F.softsign(self.speaker_proj(speaker_embed))
        # Since the conv layer assumes BCT, we need to transpose
        softsign = softsign if is_incremental else softsign.transpose(1, 2)
        a = a + softsign
    x = a * torch.sigmoid(b)  # GLU gate (torch.sigmoid; F.sigmoid is deprecated)
    return (x + residual) * math.sqrt(0.5) if self.residual else x
def forward(self, x):
    x = F.softsign(self.fc1(x))
    x = F.softsign(self.fc2(x))
    x = F.softsign(self.fc3(x))
    x = F.relu(self.fc4(x))
    x = F.relu(self.fc5(x))
    x = torch.tanh(self.out(x))
    return x
def forward(self, x):
    # out = F.softsign(self.actib1.hermite(self.bn1(x), self.actib1_wts, num_pol=num_pol))
    out = F.softsign(self.bn1(x))
    shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
    out = self.conv1(out)
    # out = F.softsign(self.conv2(self.actib2.hermite(self.bn2(out), self.actib2_wts, num_pol=num_pol)))
    out = F.softsign(self.conv2(self.bn2(out)))
    out += shortcut
    return out
def forward(self, x):
    x = F.softsign(self.conv1(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.softsign(self.conv2(x))
    x = F.max_pool2d(x, 2, 2)
    x = x.view(-1, 4 * 4 * 32)
    x = F.softsign(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
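# Shape check, a sketch under assumptions not stated in the source: a 1x28x28
# input, 5x5 kernels, and conv2 with 32 output channels (the classic
# LeNet-style MNIST setup) make the 4*4*32 flatten work out:
# 28 -> conv5 -> 24 -> pool/2 -> 12 -> conv5 -> 8 -> pool/2 -> 4
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftsignLeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(4 * 4 * 32, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.softsign(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.softsign(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 32)
        x = F.softsign(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

logits = SoftsignLeNet()(torch.randn(2, 1, 28, 28))  # -> torch.Size([2, 10])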
def forward(self, inputs):
    x = inputs
    x = self.lrelu1(self.fc1(x))
    x = self.lrelu2(self.fc2(x))
    x = self.lrelu3(self.fc3(x))
    x = self.lrelu4(self.fc4(x))
    value = self.critic_linear(x)
    # mean scaled to the action bound; std shifted into (1e-5, 1 + 1e-5)
    mu = torch.Tensor([self.action_space.high[0]]) * F.softsign(self.actor_linear(x))
    sigma = 0.5 * (F.softsign(self.actor_linear2(x)) + 1.0) + 1e-5
    return value, mu, sigma
def forward(self, x):
    out = F.softsign(
        self.actib1.hermite(self.bn1(x), self.actib1_wts, num_pol=NUM_POL))
    shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
    out = self.conv1(out)
    # V2L architecture: pull softsign out here
    out = F.softsign(
        self.conv2(
            self.actib2.hermite(self.bn2(out), self.actib2_wts,
                                num_pol=NUM_POL)))
    out += shortcut
    return out
def initialize_decoder_states(self, memory, spk_embeds, mask):
    """Initializes attention rnn states, decoder rnn states, attention
    weights, attention cumulative weights, and attention context; stores
    memory and processed memory.

    PARAMS
    ------
    memory: Encoder outputs
    mask: Mask for padded data if training, expects None for inference
    """
    B = memory.size(0)
    MAX_TIME = memory.size(1)

    attention_init_state = F.softsign(
        self.dense_init_attention_lstm(spk_embeds))  # (B, 1024*2)
    attention_init_state = attention_init_state.view(
        -1, attention_init_state.size(1) // 2, 2)  # (B, 1024, 2)
    attention_init_state = attention_init_state.permute(2, 0, 1)  # (2, B, 1024)
    self.attention_hidden = attention_init_state[0]
    self.attention_cell = attention_init_state[1]
    # self.attention_hidden = Variable(memory.data.new(
    #     B, self.attention_rnn_dim).zero_())
    # self.attention_cell = Variable(memory.data.new(
    #     B, self.attention_rnn_dim).zero_())

    decoder_init_state = F.softsign(
        self.dense_init_decoder_lstm(spk_embeds))  # (B, 1024*2)
    decoder_init_state = decoder_init_state.view(
        -1, decoder_init_state.size(1) // 2, 2)  # (B, 1024, 2)
    decoder_init_state = decoder_init_state.permute(2, 0, 1)  # (2, B, 1024)
    self.decoder_hidden = decoder_init_state[0]
    self.decoder_cell = decoder_init_state[1]
    # self.decoder_hidden = Variable(memory.data.new(
    #     B, self.decoder_rnn_dim).zero_())
    # self.decoder_cell = Variable(memory.data.new(
    #     B, self.decoder_rnn_dim).zero_())

    self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
    self.attention_weights_cum = Variable(
        memory.data.new(B, MAX_TIME).zero_())
    self.attention_context = Variable(
        memory.data.new(B, self.encoder_embedding_dim).zero_())

    self.memory = memory
    self.processed_memory = self.attention_layer.memory_layer(memory)
    self.mask = mask
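# Shape check for the view/permute trick above (illustrative only): one dense
# output of size 2*H per example is split into an LSTM-cell (hidden, cell)
# pair. Names B, H and dense_out are made up for the demonstration.
import torch

B, H = 3, 1024
dense_out = torch.randn(B, 2 * H)
state = dense_out.view(-1, H, 2).permute(2, 0, 1)  # (2, B, H)
hidden, cell = state[0], state[1]
print(hidden.shape, cell.shape)  # torch.Size([3, 1024]) torch.Size([3, 1024])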
def adpW(self, x):
    # x = F.normalize(x)
    x = x.detach()
    x = self.adp_metric_embedding1(x)
    # x = self.adp_metric_embedding1_bn(x)
    x = F.softsign(x)
    x = self.adp_metric_embedding2(x)
    x = F.softsign(x)
    # x = self.adp_metric_embedding2_bn(x)
    # build a batch of diagonal matrices from the embedding vectors
    diag_matrix = []
    for i in range(x.size(0)):
        diag_matrix.append(torch.diag(x[i, :]))
    x = torch.stack(diag_matrix)
    # W = torch.matmul(self.transform_matrix, torch.matmul(x, self.transform_matrix))
    return x
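# Side note (a sketch, not part of the original code): torch.diag_embed builds
# the same batch of diagonal matrices without the Python loop.
import torch

x = torch.randn(4, 6)
assert torch.equal(
    torch.diag_embed(x),
    torch.stack([torch.diag(x[i, :]) for i in range(x.size(0))]))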
def generate_lesion(preds, interval, num_classes=1, coef=None):
    # preds: list of tensors shaped (b, c, h, w); processed from back to front
    output = []
    for t in range(1, len(preds) + 1):
        b, c, h, w = preds[-t].size()
        cur_pred = preds[-t].view(b, num_classes, -1, h, w)
        x = (interval - interval[:, -t - 1:-t]).view(b, 1, -1, 1, 1)
        mu = cur_pred[:, :, 0:1]
        logvar = cur_pred[:, :, 1:2]
        if coef is None:
            phi = 1
        elif coef == "tanh":
            phi = (torch.tanh(torch.exp(-logvar) * x + cur_pred[:, :, 2:3]) + 1) / 2
        elif coef == "softsign":
            phi = (F.softsign(torch.exp(-logvar) * x + cur_pred[:, :, 2:3]) + 1) / 2
        elif coef == "sigmoid":
            phi = torch.sigmoid(torch.exp(-logvar) * x + cur_pred[:, :, 2:3])
        else:
            raise ValueError(f"Unsupported coefficient type {coef}")
        p = torch.exp(-torch.exp(-2 * logvar) * (x - mu) ** 2) * phi
        p = torch.cat((1 - p.sum(1, keepdim=True), p), dim=1).clamp_min(0.001)
        output.append(p)  # b, k, t, h, w, from back to front
    return output
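# Usage sketch for generate_lesion above. The shapes are assumptions inferred
# from the indexing: each prediction carries 3 channels per class, read as
# (mu, logvar, shift), and interval holds one timestamp per prediction plus a
# reference time.
import torch

b, num_classes, h, w = 2, 1, 8, 8
n_steps = 3
preds = [torch.randn(b, 3 * num_classes, h, w) for _ in range(n_steps)]
interval = torch.linspace(0, 1, n_steps + 1).repeat(b, 1)  # (b, n_steps + 1)

probs = generate_lesion(preds, interval, num_classes=num_classes, coef="softsign")
print(probs[0].shape)  # torch.Size([2, 2, 4, 8, 8]) = (b, num_classes + 1, T, h, w)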
def forward(self, inputs):
    x, (hx, cx) = inputs
    x = self.lrelu1(self.conv1(x))
    x = self.lrelu2(self.conv2(x))
    x = self.lrelu3(self.conv3(x))
    x = self.lrelu4(self.conv4(x))
    x = x.view(x.size(0), -1)
    hx, cx = self.lstm(x, (hx, cx))
    x = hx
    terminal_prediction = (None if self.terminal_aux_head is None
                           else self.terminal_aux_head(x))
    reward_prediction = (None if self.reward_aux_head is None
                         else self.reward_aux_head(x))
    # last two outputs are auxiliary task predictions
    return (self.critic_linear(x), F.softsign(self.actor_linear(x)),
            self.actor_linear2(x), (hx, cx),
            terminal_prediction, reward_prediction)
def forward(self, inputs):
    ob, info, frames = inputs
    # Get the grid state from vectorized input
    x, anchor = self.senc_nngrid((ob, info))
    self.input_size = self.senc_nngrid.observation_space.shape
    # Stack it
    x, frames = self.frame_stack((x, frames, anchor))
    # Resize to correct dims for convnet
    batch_size = x.size(0)
    x = x.view(batch_size, self.frame_stack.n_frames * self.input_size[0],
               self.input_size[1], self.input_size[2])
    x = self._convforward(x)
    # Compute action mean, action var and value grid
    critic_out = self.critic_linear(x)
    actor_out = F.softsign(self.actor_linear(x))
    actor_out2 = self.actor_linear2(x)
    # Extract motor-specific values from action grid
    critic_out = self.adec_nngrid((critic_out, info)).mean(-1, keepdim=True)
    actor_out = self.adec_nngrid((actor_out, info))
    actor_out2 = self.adec_nngrid((actor_out2, info))
    return critic_out, actor_out, actor_out2, frames
def forward(self, x1):
    x1 = F.relu(self.bn1_2(self.conv1_2(F.relu(self.bn1_1(self.conv1_1(x1))))))
    x2 = F.relu(self.bn2_2(self.conv2_2(
        F.relu(self.bn2_1(self.conv2_1(self.maxpool(x1)))))))
    xup = F.relu(self.bn4_2(self.conv4_2(
        F.relu(self.bn4_1(self.conv4_1(self.maxpool(x2)))))))
    xup = self.bn4(self.upconv4(self.upsample(xup)))
    xup = self.bn4_out(torch.cat((x2, xup), 1))
    xup = F.relu(self.bn7_2(self.conv7_2(F.relu(self.bn7_1(self.conv7_1(xup))))))
    xup = self.bn7(self.upconv7(self.upsample(xup)))
    xup = self.bn7_out(torch.cat((x1, xup), 1))
    xup = F.relu(self.bn9_3(self.conv9_3(
        F.relu(self.bn9_2(self.conv9_2(F.relu(self.bn9_1(self.conv9_1(xup)))))))))
    return F.softsign(self.bn9(xup))
def forward(self, inp):
    x, layeracts = inp
    out = F.softsign(
        self.actib1.hermite(self.bn1(x), self.actib1_wts, num_pol=num_pol))
    layeracts.append(out.clone().view(out.size(0), -1))
    shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
    out = self.conv1(out)
    out = F.softsign(
        self.conv2(
            self.actib2.hermite(self.bn2(out), self.actib2_wts,
                                num_pol=num_pol)))
    layeracts.append(out.clone().view(out.size(0), -1))
    out += shortcut
    return (out, layeracts)
def forward(self, x):
    # x: (batch_size, N_samples, C, T), assumed from the view below; the
    # original referenced undefined globals batch_size and N_samples, so
    # derive them from the input instead
    batch_size, N_samples = x.size(0), x.size(1)
    x = self.prenet(x)
    x = x.view(batch_size * N_samples, x.size(2), x.size(3)).transpose(1, 2)
    x = self.conv(x)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, N_samples, x.size(1), x.size(2))
    # x = librosa.decompose.hpss(x)[0]
    x = x.mean(dim=2)
    conv_out = self.residual_conv(x)
    x = self.attention(x)
    x = self.prohead(x)
    x = torch.squeeze(x)
    x = F.softsign(x)
    x = self.bn(x)
    x = torch.unsqueeze(x, dim=2)
    x = torch.bmm(x.transpose(1, 2), conv_out)
    x = torch.squeeze(x)
    return x
def forward(self, x):
    x = F.leaky_relu(self.bn1(self.conv1(x)))
    x = F.leaky_relu(self.bn2(self.conv2(x)))
    x = F.leaky_relu(self.bn3(self.conv3(x)))
    x = F.leaky_relu(self.fc1(x.view(x.size(0), -1)))
    x = F.softsign(self.head(x)) * 15.  # bound outputs to (-15, 15)
    return x
def forward(self, src, trg, src_mask, trg_mask, real_flag=None):
    # NOTE: real_flag is dev_mode
    if self.hp.dev_mode:
        real_tts_emb = F.softsign(self.emb_real_tts(real_flag))
        src = src + real_tts_emb.unsqueeze(1)
    if not self.frame_stacking:
        src, src_mask = self.cnn_encoder(src, src_mask)
    else:
        src = self.embedder(src)
    e_outputs, attn_enc_enc = self.encoder(src, src_mask)

    if self.decoder_type.lower() == 'transformer':
        d_output, attn_dec_dec, attn_dec_enc = self.decoder(
            trg, e_outputs, src_mask, trg_mask)
        outputs = self.out(d_output)
    elif self.decoder_type.lower() == 'ctc':
        ctc_outputs = self.out(e_outputs)
        outputs, attn_dec_dec, attn_dec_enc = None, None, None
    elif self.decoder_type.lower() == 'transducer':
        outputs, attn_dec_dec, attn_dec_enc = self.decoder(trg, e_outputs)
    else:
        d_output, attn_dec_dec, attn_dec_enc = self.decoder(
            trg, e_outputs, src_mask, trg_mask)
        outputs = d_output

    # auxiliary CTC head; keep the CTC-decoder outputs instead of
    # overwriting them with None
    if self.use_ctc:
        ctc_outputs = self.out_ctc(e_outputs)
    elif self.decoder_type.lower() != 'ctc':
        ctc_outputs = None
    return outputs, ctc_outputs, attn_enc_enc, attn_dec_dec, attn_dec_enc
def forward(self, x):
    x = F.leaky_relu(self.fc1(x))
    x = F.leaky_relu(self.fc2(x))
    mu = F.softsign(self.mu_head(x))
    sigma = self.action_std_init * torch.sigmoid(self.sigma_head(x))
    return mu, sigma
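# Usage sketch (layer sizes and names are illustrative, not from the source):
# the head above parameterizes a diagonal Gaussian policy. Softsign bounds the
# mean in (-1, 1) and the scaled sigmoid bounds the std in (0, action_std_init).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

class GaussianHead(nn.Module):
    def __init__(self, obs_dim=8, act_dim=2, action_std_init=0.5):
        super().__init__()
        self.action_std_init = action_std_init
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.mu_head = nn.Linear(64, act_dim)
        self.sigma_head = nn.Linear(64, act_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        mu = F.softsign(self.mu_head(x))
        sigma = self.action_std_init * torch.sigmoid(self.sigma_head(x))
        return mu, sigma

mu, sigma = GaussianHead()(torch.randn(4, 8))
dist = Normal(mu, sigma)
action = dist.sample()                    # stochastic action with bounded mean
log_prob = dist.log_prob(action).sum(-1)  # per-sample log-density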
def forward(self, psi_x, psi_y):
    # reshape to broadcast in matmul
    x, y = psi_x.view(self.n_par, self.dim, 1), psi_y.view(self.n_par, self.dim, 1)
    alpha = torch.zeros(self.n_par, 1, 2 * self.nsite, device=device)
    loss = torch.zeros(self.n_par, device=device)
    fidelity_store = torch.zeros(self.n_steps, self.n_par, device=device)
    last_action_store = torch.zeros(2, self.n_steps, self.n_par, self.nsite,
                                    device=device)

    for j in range(self.n_steps):
        net_input = torch.cat((x, y), 1).transpose(1, 2)
        dalpha1 = self.net_state(net_input)
        dalpha2 = self.net_action(alpha / self.force_mag)  # + alpha/self.force_mag
        dalpha = self.net_combine(dalpha1 + dalpha2)
        # softsign bounds the drive amplitudes to (-force_mag, force_mag);
        # the clamp is a redundant safeguard
        alpha = self.force_mag * F.softsign(dalpha)
        alpha = torch.clamp(alpha, min=-self.force_mag, max=self.force_mag)
        alphax = alpha[:, 0, :self.nsite]  # dimension (batchsize, nsite)
        alphay = alpha[:, 0, self.nsite:]  # dimension (batchsize, nsite)

        for _ in range(self.n_substeps):
            # H_Re and H_Im have dimensions (n_par, dim, dim)
            H_Re = self.H_0_dt + generate_drive(alphax, self.H_1_dt)
            H_Im = generate_drive(alphay, self.H_2_dt)
            x, y = self.Heun_complex(x, y, H_Re, H_Im)

        fidelity = (torch.matmul(self.target_x, x)**2 +
                    torch.matmul(self.target_x, y)**2).squeeze()
        loss += self.C1 * self.gamma**j * (1 - fidelity)  # add state infidelity

        # punish large actions
        abs_alpha = torch.mean(alpha**2, dim=2).squeeze()
        loss += self.C2 * abs_alpha

        # feed storage
        fidelity_store[j] = fidelity
        last_action_store[0, j] = alphax
        last_action_store[1, j] = alphay

    psi_x, psi_y = x.view(self.n_par, self.dim), y.view(self.n_par, self.dim)
    loss += self.C3 * (1 - fidelity_store[-1])
    loss = loss.mean()  # / self.n_steps
    return psi_x, psi_y, loss, fidelity_store, last_action_store
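# Quick check of the bounding claim above (illustrative): force_mag *
# softsign(d) already lies strictly inside (-force_mag, force_mag), so the
# subsequent clamp never changes any value.
import torch
import torch.nn.functional as F

d = torch.randn(5) * 100          # arbitrarily large pre-activations
a = 3.0 * F.softsign(d)           # force_mag = 3.0 for this sketch
assert torch.equal(a, torch.clamp(a, min=-3.0, max=3.0))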
def forward(self, inputs):
    x, (hx, cx) = inputs
    x = self.fc(x)
    x = x.view(1, 128)
    hx, cx = self.lstm(x, (hx, cx))
    x = hx
    return (self.critic_linear(x), F.softsign(self.actor_linear(x)),
            self.actor_linear2(x), (hx, cx))
def forward(self, x, input_lengths, spk_embeds):
    for conv in self.convolutions:
        x = F.dropout(F.relu(conv(x)), 0.5, self.training)

    x = x.transpose(1, 2)  # (B, max_len, D)

    # pass spk_embeds to dense net and expand spk_embeds to 3 dims
    # before_lstm = F.softsign(self.dense_spk_embeds(
    #     spk_embeds))  # (B, 16) -> (B, 512)
    # before_lstm = before_lstm.unsqueeze(1)  # (B, 512) -> (B, 1, 512)
    # before_lstm = before_lstm.repeat(
    #     1, x.size(1), 1)  # (B, max_len, 512)
    #
    # # add after conv input and spk_embeds
    # x = before_lstm + x  # (B, max_len, 512)

    # pytorch tensors are not reversible, hence the conversion
    input_lengths = input_lengths.cpu().numpy()
    x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

    self.lstm.flatten_parameters()
    # prepare initial state for the lstm
    encoder_init_state = F.softsign(
        self.dense_init_lstm(spk_embeds))  # (B, 512*2)
    encoder_init_state = encoder_init_state.view(
        -1, encoder_init_state.size(1) // 4, 2, 2)  # (B, 256, 2, 2)
    encoder_init_state = encoder_init_state.permute(3, 2, 0, 1)  # (2, 2, B, 256)
    outputs, _ = self.lstm(x, (encoder_init_state[0].contiguous(),
                               encoder_init_state[1].contiguous()))

    outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

    after_lstm = F.softsign(self.dense_spk_embeds(spk_embeds))  # (B, 16) -> (B, 512)
    after_lstm = after_lstm.unsqueeze(1)  # (B, 512) -> (B, 1, 512)
    after_lstm = after_lstm.repeat(1, outputs.size(1), 1)  # (B, max_len, 512)

    # add the speaker embedding to the lstm outputs
    outputs = after_lstm + outputs  # (B, max_len, 512)
    return outputs
def forward(self, text_sequences, text_positions=None, lengths=None,
            speaker_embed=None):
    assert self.n_speakers == 1 or speaker_embed is not None

    # embed text_sequences
    x = self.embed_tokens(text_sequences.long())
    x = F.dropout(x, p=self.dropout, training=self.training)

    # expand speaker embedding for all time steps
    speaker_embed_btc = expand_speaker_embed(x, speaker_embed)
    if speaker_embed_btc is not None:
        speaker_embed_btc = F.dropout(speaker_embed_btc, p=self.dropout,
                                      training=self.training)
        x = x + F.softsign(self.speaker_fc1(speaker_embed_btc))

    input_embedding = x

    # B x T x C -> B x C x T
    x = x.transpose(1, 2)

    # 1D conv blocks
    for f in self.convolutions:
        x = f(x, speaker_embed_btc) if isinstance(f, Conv1dGLU) else f(x)

    # Back to B x T x C
    keys = x.transpose(1, 2)

    if speaker_embed_btc is not None:
        keys = keys + F.softsign(self.speaker_fc2(speaker_embed_btc))

    # scale gradients (this only affects backward, not forward)
    if self.apply_grad_scaling and self.num_attention_layers is not None:
        keys = GradMultiply.apply(keys, 1.0 / (2.0 * self.num_attention_layers))

    # add output to input embedding for attention
    values = (keys + input_embedding) * math.sqrt(0.5)

    return keys, values
def forward(self, feature):
    value = self.critic_linear(feature)
    if 'discrete' in self.head_name:
        mu = self.actor_linear(feature)
    else:
        mu = F.softsign(self.actor_linear(feature))
    sigma = self.actor_linear2(feature)
    return value, mu, sigma
def forward(self, x, test=False):
    if self.continuous:
        mu = F.softsign(self.actor_linear(x))
        sigma = self.actor_linear2(x)
    else:
        mu = self.actor_linear(x)
        sigma = 0
    action, entropy, log_prob = sample_action(self.head_name, mu, sigma, test)
    return action, entropy, log_prob
def forward(self, x):
    x = F.relu(getattr(self, self.name + "_l1")(x))
    x = F.relu(getattr(self, self.name + "_l2")(x))
    mu = F.softsign(getattr(self, self.name + "_l3_mu")(x))
    mu = torch.clamp(mu, -self.max_action, +self.max_action)
    std = getattr(self, self.name + "_l3_std")(x)
    std = F.softplus(std) + 1e-5
    return mu, std
def forward(self, x, test=False):
    if self.continuous:
        mu = F.softsign(self.actor_linear(x))
        sigma = self.actor_linear2(x)
    else:
        mu = self.actor_linear(x)
        sigma = torch.ones_like(mu)
    action, entropy, log_prob = sample_action(self.continuous, mu, sigma,
                                              self.device, test)
    return action, entropy, log_prob
def forward(self, x):
    x = self.conv(x)
    x = torch.flatten(x, start_dim=1)  # assumes a fixed batch of 16 frames with 8192 features each
    x = x.reshape(16, 8192, 1)
    x = torch.transpose(x, 0, 2)       # (1, 8192, 16)
    x = torch.transpose(x, 1, 2)       # (1, 16, 8192): the 16 frames form one length-16 sequence
    out, (h_n, h_c) = self.lstm(x, None)
    out = F.softsign(out)
    out = out[:, -1, :]                # keep the last time step
    out = self.fc(out)
    return out
def forward(self, inputs):
    x, (hx, cx) = inputs
    x = self.lrelu1(self.fc1(x))
    x = self.lrelu2(self.fc2(x))
    x = self.lrelu3(self.fc3(x))
    x = self.lrelu4(self.fc4(x))
    x = x.view(1, self.m1)
    return (self.critic_linear(x), F.softsign(self.actor_linear(x)),
            self.actor_linear2(x), (hx, cx))
def forward(self, text_sequences, speaker_embed=None):
    """Forward pass"""
    assert self.n_speakers == 1 or speaker_embed is not None

    # Text embedding
    x = self.text_embed(text_sequences.long())
    x = F.dropout(x, p=self.dropout, training=self.training)

    # expand speaker embedding for all time steps
    speaker_embed_btc = expand_speaker_embed(x, speaker_embed)
    if speaker_embed_btc is not None:
        speaker_embed_btc = F.dropout(speaker_embed_btc, p=self.dropout,
                                      training=self.training)
        x = x + F.softsign(self.speaker_fc1(speaker_embed_btc))

    input_embedding = x

    # [B, T_max, channels] -> [B, channels, T_max]
    x = x.transpose(1, 2).contiguous()

    # 1D conv blocks
    for f in self.convolutions:
        x = f(x, speaker_embed_btc) if isinstance(f, Conv1DGLU) else f(x)

    # [B, channels, T_max] -> [B, T_max, channels]
    keys = x.transpose(1, 2).contiguous()

    if speaker_embed_btc is not None:
        keys = keys + F.softsign(self.speaker_fc2(speaker_embed_btc))

    # scale gradients (this only affects backward, not forward)
    if self.apply_grad_scaling and self.num_attention_layers is not None:
        keys = GradMultiply.apply(keys, 1.0 / (2.0 * self.num_attention_layers))

    # add output to input embedding for attention
    values = (keys + input_embedding) * math.sqrt(0.5)

    return keys, values
def forward(self, x1):
    x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1)))))
    x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1))))))
    x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2))))))
    x4 = F.relu(self.bn4(self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3))))))
    xup = F.relu(self.conv5_2(F.relu(self.conv5_1(self.maxpool(x4)))))  # x5

    xup = self.bn5(self.upconv5(self.upsample(xup)))  # x6in
    cropidx = (x4.size(2) - xup.size(2)) // 2
    x4 = x4[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
    xup = self.bn5_out(torch.cat((x4, xup), 1))  # x6 cat x4
    xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup))))  # x6out

    xup = self.bn6(self.upconv6(self.upsample(xup)))  # x7in
    cropidx = (x3.size(2) - xup.size(2)) // 2
    x3 = x3[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
    xup = self.bn6_out(torch.cat((x3, xup), 1))  # x7 cat x3
    xup = F.relu(self.conv7_2(F.relu(self.conv7_1(xup))))  # x7out

    xup = self.bn7(self.upconv7(self.upsample(xup)))  # x8in
    cropidx = (x2.size(2) - xup.size(2)) // 2
    x2 = x2[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
    xup = self.bn7_out(torch.cat((x2, xup), 1))  # x8 cat x2
    xup = F.relu(self.conv8_2(F.relu(self.conv8_1(xup))))  # x8out

    xup = self.bn8(self.upconv8(self.upsample(xup)))  # x9in
    cropidx = (x1.size(2) - xup.size(2)) // 2
    x1 = x1[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)]
    xup = self.bn8_out(torch.cat((x1, xup), 1))  # x9 cat x1
    xup = F.relu(self.conv9_3(F.relu(self.conv9_2(F.relu(self.conv9_1(xup))))))  # x9out

    return F.softsign(self.bn9(xup))
def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
        epochs=1000, restart=0, gpu=False):

    import tensorboardX
    writer = tensorboardX.SummaryWriter(log_dir=log_dir)

    checkpoint = Checkpoint(log_dir=log_dir, restart=restart > 0)

    batch_generator = SpeechSegmentGenerator(
        feature_extraction, per_label=self.per_label,
        per_fold=self.per_fold, duration=self.duration,
        parallel=self.parallel)
    batches = batch_generator(protocol, subset=subset)
    batch = next(batches)

    batches_per_epoch = batch_generator.batches_per_epoch

    if restart > 0:
        weights_pt = checkpoint.WEIGHTS_PT.format(
            log_dir=log_dir, epoch=restart)
        model.load_state_dict(torch.load(weights_pt))

    if gpu:
        model = model.cuda()
    model.internal = False

    parameters = list(model.parameters())

    if self.variant in [2, 3, 4, 5, 6, 7, 8]:
        # norm batch-normalization
        self.norm_bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1, affine=True)
        if gpu:
            self.norm_bn = self.norm_bn.cuda()
        parameters += list(self.norm_bn.parameters())

    if self.variant in [9]:
        # norm batch-normalization
        self.norm_bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1, affine=False)
        if gpu:
            self.norm_bn = self.norm_bn.cuda()
        parameters += list(self.norm_bn.parameters())

    if self.variant in [5, 6, 7]:
        self.positive_bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1,
                                          affine=False)
        self.negative_bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1,
                                          affine=False)
        if gpu:
            self.positive_bn = self.positive_bn.cuda()
            self.negative_bn = self.negative_bn.cuda()
        parameters += list(self.positive_bn.parameters())
        parameters += list(self.negative_bn.parameters())

    if self.variant in [8, 9]:
        self.delta_bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1,
                                       affine=False)
        if gpu:
            self.delta_bn = self.delta_bn.cuda()
        parameters += list(self.delta_bn.parameters())

    optimizer = Adam(parameters)
    if restart > 0:
        optimizer_pt = checkpoint.OPTIMIZER_PT.format(
            log_dir=log_dir, epoch=restart)
        optimizer.load_state_dict(torch.load(optimizer_pt))
        if gpu:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

    epoch = restart if restart > 0 else -1
    while True:
        epoch += 1
        if epoch > epochs:
            break

        loss_avg, tloss_avg, closs_avg = 0., 0., 0.

        if epoch % 5 == 0:
            log_positive = []
            log_negative = []
            log_delta = []
            log_norm = []

        desc = 'Epoch #{0}'.format(epoch)
        for i in tqdm(range(batches_per_epoch), desc=desc):

            model.zero_grad()

            batch = next(batches)

            X = batch['X']
            if not getattr(model, 'batch_first', True):
                X = np.rollaxis(X, 0, 2)
            X = np.array(X, dtype=np.float32)
            X = Variable(torch.from_numpy(X))
            if gpu:
                X = X.cuda()
            fX = model(X)

            # pre-compute pairwise distances
            distances = self.pdist(fX)

            # sample triplets
            triplets = getattr(self, 'batch_{0}'.format(self.sampling))
            anchors, positives, negatives = triplets(batch['y'], distances)

            # compute triplet loss
            tlosses, deltas, pos_index, neg_index = self.triplet_loss(
                distances, anchors, positives, negatives,
                return_delta=True)
            tloss = torch.mean(tlosses)

            if self.variant == 1:
                closses = F.sigmoid(
                    F.softsign(deltas) * torch.norm(fX[anchors], 2, 1,
                                                    keepdim=True))
                # if d(a, p) < d(a, n) (i.e. good case)
                # --> sign(delta) < 0
                # --> loss decreases when norm increases,
                #     i.e. encourages longer anchor
                # if d(a, p) > d(a, n) (i.e. bad case)
                # --> sign(delta) > 0
                # --> loss increases when norm increases,
                #     i.e. encourages shorter anchor

            elif self.variant == 2:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                norms_ = F.sigmoid(self.norm_bn(norms_))
                confidence = (norms_[anchors] + norms_[positives]
                              + norms_[negatives]) / 3
                # if |x| is average --> normalized |x| = 0 --> confidence = 0.5
                # if |x| is bigger than average --> normalized |x| >> 0
                #     --> confidence = 1
                # if |x| is smaller than average --> normalized |x| << 0
                #     --> confidence = 0

                correctness = F.sigmoid(-deltas / np.pi * 6)
                # if d(a, p) = d(a, n) (i.e. uncertain case) --> correctness = 0.5
                # if d(a, p) - d(a, n) = -π (i.e. best possible case)
                #     --> correctness = 1
                # if d(a, p) - d(a, n) = +π (i.e. worst possible case)
                #     --> correctness = 0

                closses = torch.abs(confidence - correctness)
                # small if (and only if) confidence & correctness agree

            elif self.variant == 3:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                norms_ = F.sigmoid(self.norm_bn(norms_))
                confidence = (norms_[anchors] * norms_[positives]
                              * norms_[negatives]) / 3

                correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                # correctness = 0.5 at delta == -pi/4
                # correctness = 1 for delta == -pi
                # correctness = 0 for delta > 0

                closses = torch.abs(confidence - correctness)

            elif self.variant == 4:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                norms_ = F.sigmoid(self.norm_bn(norms_))
                # geometric mean of the three norms (exponent parenthesized so
                # it really is 1/3, not (x ** 1) / 3)
                confidence = (norms_[anchors] * norms_[positives]
                              * norms_[negatives]) ** (1. / 3)

                correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                # correctness = 0.5 at delta == -pi/4
                # correctness = 1 for delta == -pi
                # correctness = 0 for delta > 0
                # delta = pos - neg ... should be < 0

                closses = torch.abs(confidence - correctness)

            elif self.variant == 5:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                confidence = F.sigmoid(self.norm_bn(norms_))

                confidence_pos = .5 * (confidence[anchors]
                                       + confidence[positives])
                # low positive distance == high correctness
                correctness_pos = F.sigmoid(
                    -self.positive_bn(distances[pos_index].view(-1, 1)))

                confidence_neg = .5 * (confidence[anchors]
                                       + confidence[negatives])
                # high negative distance == high correctness
                correctness_neg = F.sigmoid(
                    self.negative_bn(distances[neg_index].view(-1, 1)))

                closses = .5 * (torch.abs(confidence_pos - correctness_pos)
                                + torch.abs(confidence_neg - correctness_neg))

            elif self.variant == 6:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                confidence = F.sigmoid(self.norm_bn(norms_))

                confidence_pos = .5 * (confidence[anchors]
                                       + confidence[positives])
                # low positive distance == high correctness
                correctness_pos = F.sigmoid(
                    -self.positive_bn(distances[pos_index].view(-1, 1)))

                closses = torch.abs(confidence_pos - correctness_pos)

            elif self.variant == 7:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                confidence = F.sigmoid(self.norm_bn(norms_))

                confidence_neg = .5 * (confidence[anchors]
                                       + confidence[negatives])
                # high negative distance == high correctness
                correctness_neg = F.sigmoid(
                    self.negative_bn(distances[neg_index].view(-1, 1)))

                closses = torch.abs(confidence_neg - correctness_neg)

            elif self.variant in [8, 9]:
                norms_ = torch.norm(fX, 2, 1, keepdim=True)
                norms_ = F.sigmoid(self.norm_bn(norms_))
                confidence = (norms_[anchors] * norms_[positives]
                              * norms_[negatives]) / 3

                correctness = F.sigmoid(-self.delta_bn(deltas))

                closses = torch.abs(confidence - correctness)

            closs = torch.mean(closses)

            if epoch % 5 == 0:
                if gpu:
                    fX_npy = fX.data.cpu().numpy()
                    pdist_npy = distances.data.cpu().numpy()
                    delta_npy = deltas.data.cpu().numpy()
                else:
                    fX_npy = fX.data.numpy()
                    pdist_npy = distances.data.numpy()
                    delta_npy = deltas.data.numpy()
                log_norm.append(np.linalg.norm(fX_npy, axis=1))
                same_speaker = pdist(batch['y'].reshape((-1, 1)),
                                     metric='chebyshev') < 1
                log_positive.append(pdist_npy[np.where(same_speaker)])
                log_negative.append(pdist_npy[np.where(~same_speaker)])
                log_delta.append(delta_npy)

            # log loss
            if gpu:
                tloss_ = float(tloss.data.cpu().numpy())
                closs_ = float(closs.data.cpu().numpy())
            else:
                tloss_ = float(tloss.data.numpy())
                closs_ = float(closs.data.numpy())
            tloss_avg += tloss_
            closs_avg += closs_
            loss_avg += tloss_ + closs_

            loss = tloss + closs
            loss.backward()
            optimizer.step()

        tloss_avg /= batches_per_epoch
        writer.add_scalar('tloss', tloss_avg, global_step=epoch)
        closs_avg /= batches_per_epoch
        writer.add_scalar('closs', closs_avg, global_step=epoch)
        loss_avg /= batches_per_epoch
        writer.add_scalar('loss', loss_avg, global_step=epoch)

        if epoch % 5 == 0:
            log_positive = np.hstack(log_positive)
            writer.add_histogram(
                'embedding/pairwise_distance/positive', log_positive,
                global_step=epoch, bins=np.linspace(0, np.pi, 50))
            log_negative = np.hstack(log_negative)
            writer.add_histogram(
                'embedding/pairwise_distance/negative', log_negative,
                global_step=epoch, bins=np.linspace(0, np.pi, 50))
            _, _, _, eer = det_curve(
                np.hstack([np.ones(len(log_positive)),
                           np.zeros(len(log_negative))]),
                np.hstack([log_positive, log_negative]), distances=True)
            writer.add_scalar('eer', eer, global_step=epoch)
            log_norm = np.hstack(log_norm)
            writer.add_histogram('norm', log_norm, global_step=epoch,
                                 bins='doane')
            log_delta = np.vstack(log_delta)
            writer.add_histogram('delta', log_delta, global_step=epoch,
                                 bins='doane')

        checkpoint.on_epoch_end(epoch, model, optimizer)

        if hasattr(self, 'norm_bn'):
            confidence_pt = self.CONFIDENCE_PT.format(
                log_dir=log_dir, epoch=epoch)
            torch.save(self.norm_bn.state_dict(), confidence_pt)
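# Worked check of the variant-2 "correctness" mapping above (illustrative
# only): correctness = sigmoid(-delta / pi * 6), where delta = d(a, p) - d(a, n).
import numpy as np
from scipy.special import expit  # logistic sigmoid

for delta in (-np.pi, 0.0, np.pi):
    print(delta, expit(-delta / np.pi * 6))
# -pi -> ~0.998 (best case, correctness ≈ 1)
#  0  -> 0.5    (uncertain case)
# +pi -> ~0.002 (worst case, correctness ≈ 0)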