Example #1
def batch_select(obj, idx):
    if isinstance(obj, torch.autograd.Variable):
        res = obj.index_select(0,
                               Variable(torchauto(obj.data).LongTensor(idx)))
    elif isinstance(obj, torch.tensor._TensorBase):
        res = obj.index_select(0, torchauto(obj).LongTensor(idx))
    elif isinstance(obj, list):
        res = [obj[ii] for ii in idx]
    else:
        raise ValueError("obj type is not supported")
    return res
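A minimal usage sketch (not from the original repository; assumes PyTorch >= 0.4, where Variable is merged into Tensor, so index_select works on tensors directly):

import torch

def batch_select_modern(obj, idx):
    # same dispatch as batch_select above, for post-0.4 PyTorch
    if torch.is_tensor(obj):
        return obj.index_select(
            0, torch.tensor(idx, dtype=torch.long, device=obj.device))
    if isinstance(obj, list):
        return [obj[ii] for ii in idx]
    raise ValueError("obj type is not supported")

x = torch.arange(12.).view(4, 3)
print(batch_select_modern(x, [2, 0]))              # rows 2 and 0
print(batch_select_modern(list("abcd"), [1, 3]))   # ['b', 'd']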
    def fn_batch_ce(feat_mat, feat_len, speaker_list, train_step=True):
        feat_mat = Variable(feat_mat)
        feat_mask = Variable(
            generate_seq_mask([x for x in feat_len], opts['gpu']))
        speaker_list_id = [map_spk2id[x] for x in speaker_list]
        speaker_list_id = Variable(
            torchauto(model).LongTensor(speaker_list_id))

        model.reset()
        model.train(train_step)
        batch, dec_len, _ = feat_mat.size()

        pred_emb = model(feat_mat, feat_len)
        pred_softmax = model.forward_softmax(pred_emb)

        loss = criterion_ce(pred_softmax, speaker_list_id) * opts['coeff_ce']
        acc = torch.max(pred_softmax, 1)[1].data.eq(
            speaker_list_id.data).sum() / batch

        if train_step:
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          opts['grad_clip'])
            opt.step()
        return loss.data.sum(), acc
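The accuracy line compares the argmax over classes with the labels; a standalone sketch (shapes are hypothetical), casting to float before averaging to avoid the integer-division pitfall of summing a byte tensor:

import torch

pred_softmax = torch.randn(4, 10)        # batch x n_speakers
speaker_id = torch.tensor([3, 1, 7, 3])

acc = torch.max(pred_softmax, 1)[1].eq(speaker_id).float().mean()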
    def fn_batch(feat_mat, feat_len, text_mat, text_len, train_step=True):
        # TODO
        feat_mat = Variable(feat_mat)
        text_input = Variable(text_mat[:, 0:-1])
        text_output = Variable(text_mat[:, 1:])
        model.reset()
        model.train(train_step)
        model.encode(feat_mat, feat_len)
        batch, dec_len = text_input.size()
        list_pre_softmax = []
        for ii in range(dec_len):
            _pre_softmax_ii, _ = model.decode(text_input[:, ii])
            list_pre_softmax.append(_pre_softmax_ii)
            pass
        pre_softmax = torch.stack(list_pre_softmax, 1)
        denominator = Variable(torchauto(model).FloatTensor(text_len) - 1)
        # average loss based on individual length #
        loss = criterion(pre_softmax.contiguous().view(batch * dec_len, -1),
                         text_output.contiguous().view(batch * dec_len)).view(
                             batch, dec_len).sum(dim=1) / denominator
        loss = loss.mean()

        acc = torch.max(
            pre_softmax, 2)[1].data.eq(text_output.data).masked_select(
                text_output.ne(constant.PAD).data).sum() / denominator.sum()
        if train_step:
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          opts['grad_clip'])
            opt.step()

        return loss.data.sum(), acc.data.sum()
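The per-sequence length normalization above can be reproduced with a per-token criterion; a self-contained sketch (PAD index and shapes are hypothetical, and the criterion is assumed to zero out PAD positions via ignore_index):

import torch
import torch.nn as nn

PAD = 0                                         # hypothetical padding index
torch.manual_seed(0)
logits = torch.randn(2, 4, 5)                   # batch x dec_len x vocab
targets = torch.tensor([[3, 1, 4, 2],
                        [2, 4, PAD, PAD]])      # second sequence has length 2
lengths = torch.tensor([4., 2.])                # the "denominator" above

ce = nn.CrossEntropyLoss(ignore_index=PAD, reduction='none')
per_token = ce(logits.reshape(-1, 5), targets.reshape(-1)).view(2, 4)
loss = (per_token.sum(dim=1) / lengths).mean()  # per-sequence average, then batch mean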
Example #4
    def encode(self, input, src_len=None):
        """
        input : (batch x max_src_len x in_size)
        mask : (batch x max_src_len)
        """
        batch, max_src_len, in_size = input.size()

        # apply raw -> feat #
        res = self.encode_raw(input, src_len)

        for ii in range(len(self.enc_rnn)):
            res = pack(res, src_len, batch_first=True)
            res = self.enc_rnn[ii](res)[0]  # get h only #
            res, _ = unpack(res, batch_first=True)
            res = F.dropout(res, self.enc_rnn_do[ii], self.training)
            if self.downsampling[ii]:
                res = res[:, 1::2]
                src_len = [x // 2 for x in src_len]
                pass
        ctx = res
        # create mask if required #
        if src_len is not None:
            ctx_mask = torchauto(self).FloatTensor(batch, ctx.size(1)).zero_()
            for ii in range(batch):
                ctx_mask[ii, 0:src_len[ii]] = 1.0
            ctx_mask = Variable(ctx_mask)
        else:
            ctx_mask = None
        self.dec.set_ctx(ctx, ctx_mask)
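The pack/unpack-plus-downsampling loop implements a pyramidal encoder: after each RNN layer, every other frame is dropped and the lengths are halved. A standalone sketch with one layer (dimensions are made up):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

rnn = nn.LSTM(8, 16, batch_first=True)
x = torch.randn(3, 10, 8)
src_len = [10, 8, 6]                   # sorted descending, as pack() expects

h = pack(x, src_len, batch_first=True)
h = rnn(h)[0]                          # keep outputs, drop (h_n, c_n)
h, _ = unpack(h, batch_first=True)
h = h[:, 1::2]                         # drop every other frame, as in the loop above
src_len = [n // 2 for n in src_len]
print(h.shape, src_len)                # torch.Size([3, 5, 16]) [5, 4, 3]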
Example #5
    def forward(self, input, seq_len=None):
        batch, max_seq_len, _ = input.size()
        if seq_len is None:
            seq_len = [max_seq_len] * batch
        res = input
        res = pack(res, seq_len, batch_first=True)
        res = self.layers(res)[0]  # get h only #
        res, _ = unpack(res, batch_first=True)

        seq_len_var = Variable(
            torchauto(self).FloatTensor(seq_len).unsqueeze(1).expand(
                batch, res.size(2)))

        if self.summary == 'mean':
            res = torch.sum(res, 1).squeeze(1) / seq_len_var
        elif self.summary == 'last':
            _res = []
            for ii in range(batch):
                if self.layers.bidirectional:
                    _last_fwd = res[ii, seq_len[ii] - 1, 0:self.hid_size]
                    _last_bwd = res[ii, 0, self.hid_size:]  # backward output at t = 0 #
                    _res.append(torch.cat([_last_fwd, _last_bwd]))
                else:
                    _res.append(res[ii, seq_len[ii] - 1, 0:self.hid_size])

            res = torch.stack(_res)
        return self.regression(res)
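For the 'last' summary with a bidirectional RNN, the forward direction is read at the last valid step and the backward direction at step 0. A standalone sketch (hid as the per-direction hidden size is an assumption, consistent with the 0:self.hid_size slice above):

import torch
import torch.nn as nn

hid = 4
rnn = nn.LSTM(3, hid, batch_first=True, bidirectional=True)
x = torch.randn(2, 6, 3)
seq_len = [6, 4]
out, _ = rnn(x)                              # batch x time x 2*hid

last = []
for ii in range(2):
    fwd = out[ii, seq_len[ii] - 1, :hid]     # forward state at the last valid frame
    bwd = out[ii, 0, hid:]                   # backward state at t = 0
    last.append(torch.cat([fwd, bwd]))
last = torch.stack(last)                     # batch x 2*hid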
Example #6
    def get_speaker_emb(self, speaker_list):
        speaker_id_list = [self.map_spk2id[x] for x in speaker_list]
        speaker_list_var = Variable(
            torchauto(self).LongTensor(speaker_id_list))
        # get embedding for each speaker #
        speaker_emb_var = self.spk_module_lyr.emb_lyr(speaker_list_var)
        return speaker_emb_var
    def tts_loss_spk_emb(feat_mat,
                         feat_len,
                         target_emb,
                         size_average=True,
                         loss_cfg=opts['tts_loss_spk_emb']):
        assert isinstance(feat_mat, Variable), \
            "feat must be a Variable generated from the TTS model"
        if loss_cfg is None:
            return Variable(torchauto(opts['gpu']).FloatTensor([0]))
        assert loss_cfg['type'] in ['huber', 'L1', 'L2', 'cosine']
        model_spkrec.reset()
        model_spkrec.eval()  # set eval mode, no gradient update for model_spkrec #
        pred_emb = model_spkrec(feat_mat, feat_len)
        if loss_cfg['type'] == 'huber':
            loss = nn.SmoothL1Loss(size_average=size_average)(pred_emb,
                                                              target_emb)
        elif loss_cfg['type'] == 'L2':
            loss = nn.MSELoss(size_average=size_average)(pred_emb, target_emb)
        elif loss_cfg['type'] == 'L1':
            loss = nn.L1Loss(size_average=size_average)(pred_emb, target_emb)
        elif loss_cfg['type'] == 'cosine':
            loss = (1 - nn.CosineSimilarity()(pred_emb, target_emb)).sum() / (
                feat_mat.shape[0] if size_average else 1)
        else:
            raise NotImplementedError()
        return loss_cfg['coeff'] * loss
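For the 'cosine' branch with size_average=True, the expression equals PyTorch's built-in CosineEmbeddingLoss with all-ones targets; a quick check:

import torch
import torch.nn as nn

pred, target = torch.randn(4, 16), torch.randn(4, 16)
manual = (1 - nn.CosineSimilarity()(pred, target)).sum() / pred.shape[0]
builtin = nn.CosineEmbeddingLoss()(pred, target, torch.ones(4))
print(torch.allclose(manual, builtin))   # True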
Example #8
    def combine_state_and_input(device, beams, prev_state):
        # take the latest state_idx from each beam #
        state = vars_index_select(prev_state, 0,
                                  Variable(torchauto(device).LongTensor(
                                      [x.state[-1] for x in beams])))
        input = torch.cat([x.output[-1] for x in beams])
        return state, input
    def criterion_gan(input, target, mask, size_average=True):
        # convert to 2d #
        batch, seq_len = input.size()
        input_1d = input.contiguous().view(-1)
        # TODO : use mask ? #
        if opts['gan_type'] == 'none':
            return None
        elif opts['gan_type'] == 'gan':
            # normal gan loss #
            target_1d = Variable(torchauto(opts['gpu']).FloatTensor(
                input_1d.size()).fill_(target))
            return F.binary_cross_entropy_with_logits(input_1d, target_1d)
        elif opts['gan_type'] == 'wgan':
            return torch.mean(input_1d * (-1 if target else 1))
        elif opts['gan_type'] == 'lsgan':
            target_1d = Variable(torchauto(opts['gpu']).FloatTensor(
                input_1d.size()).fill_(target))
            return torch.mean((input_1d - target_1d)**2)
        else:
            raise NotImplementedError
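The three objectives correspond to the standard GAN, WGAN, and LSGAN discriminator losses; a standalone sketch on dummy logits (target = 1 marks "real"):

import torch
import torch.nn.functional as F

logits = torch.randn(6)                                    # flattened discriminator outputs
real = torch.ones_like(logits)

gan = F.binary_cross_entropy_with_logits(logits, real)     # 'gan'
wgan = torch.mean(logits * -1)                             # 'wgan', target = 1
lsgan = torch.mean((logits - real) ** 2)                   # 'lsgan'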
    def encode(self, input, src_len=None):
        """
        input : (batch x max_src_len x in_size)
        mask : (batch x max_src_len)
        """
        batch, max_src_len, in_size = input.size()

        if src_len is None:
            src_len = [max_src_len] * batch
        res = input.view(batch * max_src_len, in_size)
        enc_fnn_act = getattr(F, self.enc_fnn_act)
        for ii in range(len(self.enc_fnn)):
            res = F.dropout(enc_fnn_act(self.enc_fnn[ii](res)),
                            self.enc_fnn_do[ii], self.training)
            pass
        # res = batch * max_src_len x ndim #
        res = res.view(batch, max_src_len,
                       res.size(1)).transpose(1, 2).unsqueeze(3)
        # res = batch x ndim x src_len x 1 #
        enc_cnn_act = getattr(F, self.enc_cnn_act)
        for ii in range(len(self.enc_cnn)):
            if self.use_pad1[ii]:
                res = F.pad(res, (0, 0, 0, 1))
            res = self.enc_cnn[ii](res)
            res = enc_cnn_act(res)
            src_len = [x // self.enc_cnn_strides[ii] for x in src_len]
            pass
        res = res.squeeze(3).transpose(1, 2)  # batch x src_len x ndim #
        # add position embedding #
        _pos_arr = np.arange(0, res.size(1)).astype('float32')  # src_len #
        _pos_arr = np.repeat(_pos_arr[np.newaxis, :], batch,
                             0)  # batch x src_len #
        _pos_arr /= np.array(
            src_len)[:, np.newaxis]  # divide for relative position #
        _pos_arr = tensorauto(self, torch.from_numpy(_pos_arr))
        _pos_var = Variable(_pos_arr.view(batch * _pos_arr.size(1), 1))
        # TODO : absolute or relative position #
        res_pos = self.pos_emb(_pos_var)
        res_pos = res_pos.view(batch, _pos_arr.size(1), -1)
        ctx = res + res_pos  # TODO : sum or concat ? #
        # create mask if required #
        if src_len is not None:
            ctx_mask = torchauto(self).FloatTensor(batch, ctx.size(1)).zero_()
            for ii in range(batch):
                ctx_mask[ii, 0:src_len[ii]] = 1.0
            ctx_mask = Variable(ctx_mask)
        else:
            ctx_mask = None
        self.dec.set_ctx(ctx, ctx_mask)
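The position-embedding step feeds each frame's relative position t / src_len through pos_emb; a standalone sketch (modeling pos_emb as a Linear(1, ndim) is an assumption about the layer's signature):

import numpy as np
import torch
import torch.nn as nn

batch, max_src_len, ndim = 2, 5, 8
pos_emb = nn.Linear(1, ndim)                       # hypothetical stand-in for self.pos_emb
src_len = [5, 3]

pos = np.arange(max_src_len, dtype='float32')[np.newaxis, :].repeat(batch, 0)
pos /= np.array(src_len, dtype='float32')[:, np.newaxis]   # relative position in [0, 1)
pos = torch.from_numpy(pos).view(batch * max_src_len, 1)
res_pos = pos_emb(pos).view(batch, max_src_len, ndim)      # batch x src_len x ndim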
    def fn_batch_asr(model,
                     feat_mat,
                     feat_len,
                     text_mat,
                     text_len,
                     train_step=True,
                     coeff_loss=1):
        # refit data #
        if max(feat_len) != feat_mat.shape[1]:
            feat_mat = feat_mat[:, 0:max(feat_len)]
        if max(text_len) != text_mat.shape[1]:
            text_mat = text_mat[:, 0:max(text_len)]

        if not isinstance(text_mat, Variable):
            text_mat = Variable(text_mat)
        if not isinstance(feat_mat, Variable):
            feat_mat = Variable(feat_mat)
        text_input = text_mat[:, 0:-1]
        text_output = text_mat[:, 1:]
        model.reset()
        model.train(train_step)
        model.encode(feat_mat, feat_len)
        batch, dec_len = text_input.size()
        list_pre_softmax = []
        for ii in range(dec_len):
            _pre_softmax_ii, _ = model.decode(text_input[:, ii])
            list_pre_softmax.append(_pre_softmax_ii)
            pass
        pre_softmax = torch.stack(list_pre_softmax, 1)
        denominator = Variable(torchauto(model).FloatTensor(text_len) - 1)
        # average loss based on individual length #
        loss = asr_loss_ce(
            pre_softmax.contiguous().view(batch * dec_len, -1),
            text_output.contiguous().view(batch * dec_len)).view(
                batch, dec_len).sum(dim=1) / denominator
        loss = loss.mean() * coeff_loss

        acc = torch.max(
            pre_softmax, 2)[1].data.eq(text_output.data).masked_select(
                text_output.ne(constant.PAD).data).sum() / denominator.sum()
        if train_step:
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          opts['asr_grad_clip'])
            asr_opt.step()
        return loss.data.sum(), acc.data.sum()
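torch.nn.utils.clip_grad_norm, used throughout these snippets, was renamed to clip_grad_norm_ (trailing underscore) in PyTorch 0.4; a sketch of the same training step in current PyTorch:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()
model.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
opt.step()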
    def tts_loss_freq(input,
                      target,
                      mask,
                      size_average=True,
                      loss_cfg=opts['tts_loss_freq_cfg']):
        """
        aux loss that prioritizes optimizing the loss on lower frequencies
        """
        if loss_cfg is None:
            return Variable(torchauto(opts['gpu']).FloatTensor([0]))
        assert 0 < loss_cfg['topn'] <= 1
        assert loss_cfg['coeff'] > 0
        ndim = int(input.size(-1) * loss_cfg['topn'])
        loss = tts_loss(input[:, :, 0:ndim],
                        target[:, :, 0:ndim],
                        mask,
                        size_average=size_average)
        return loss_cfg['coeff'] * loss
Example #13
def greedy_search_torch(model, src_mat, src_len, map_text2idx, max_target):
    if not isinstance(src_mat, Variable):
        src_mat = Variable(src_mat)

    batch = src_mat.size(0)
    model.eval()
    model.reset()
    model.encode(src_mat, src_len)

    prev_label = Variable(
        torchauto(model).LongTensor(
            [map_text2idx[constant.BOS_WORD] for _ in range(batch)]))
    transcription = []
    transcription_len = [-1 for _ in range(batch)]
    att_mat = []
    for tt in range(max_target):
        pre_softmax, dec_res = model.decode(prev_label)
        max_id = pre_softmax.max(1)[1]
        transcription.append(max_id)
        att_mat.append(dec_res['att_output']['p_ctx'])
        for ii in range(batch):
            if transcription_len[ii] == -1:
                if max_id.data[ii] == map_text2idx[constant.EOS_WORD]:
                    transcription_len[ii] = tt + 1
        if all([ii != -1 for ii in transcription_len]):
            # finish #
            break
        prev_label = max_id
        pass

    # concat across all timestep #
    transcription = torch.stack(transcription, 1)  # batch x seq_len #
    att_mat = torch.stack(att_mat, 1)  # batch x seq_len x enc_len #

    return transcription, transcription_len, att_mat
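The decoding loop feeds the argmax back as the next input and records, per batch entry, the first step that emits EOS; a self-contained sketch with a dummy decoder step (the BOS/EOS ids and the step function are hypothetical):

import torch

BOS, EOS, vocab, batch = 0, 1, 5, 3

def dummy_decode(prev_label):                    # stands in for model.decode
    return torch.randn(prev_label.size(0), vocab)

torch.manual_seed(0)
prev_label = torch.full((batch,), BOS, dtype=torch.long)
transcription, transcription_len = [], [-1] * batch
for tt in range(20):
    max_id = dummy_decode(prev_label).max(1)[1]
    transcription.append(max_id)
    for ii in range(batch):
        if transcription_len[ii] == -1 and max_id[ii].item() == EOS:
            transcription_len[ii] = tt + 1
    if all(l != -1 for l in transcription_len):
        break
    prev_label = max_id
transcription = torch.stack(transcription, 1)    # batch x seq_len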
Example #14
def beam_search_torch(model,
                      src_mat,
                      src_len,
                      map_text2idx,
                      max_target,
                      kbeam=5,
                      coeff_lp=1.0,
                      nbest=None):
    """
    http://opennmt.net/OpenNMT/translation/beam_search/

    nbest :
        if None -> take only the top-1
            returns transcription (batch x words) (list of words)
        if n >= 1 -> take the top-n
            returns transcription (batch x nbest x words) (list of list of words)
    """
    if not isinstance(src_mat, Variable):
        src_mat = Variable(src_mat)
    batch, max_src_len = src_mat.size()[0:2]
    # run encoder to get context and mask #
    model.eval()
    model.reset()
    model.encode(src_mat, src_len)
    enc_ctx = model.dec_att_lyr.ctx
    enc_ctx_mask = model.dec_att_lyr.ctx_mask
    # expand encoder & encoder mask for each beam #
    max_src_len = enc_ctx.size(1)

    pool_beams = []
    curr_input = Variable(
        torchauto(model).LongTensor([constant.BOS for _ in range(batch)]))

    # start by adding 1 beam per batch #
    for bb in range(batch):
        if nbest is None:
            pool_beam = PoolBeam(topk=kbeam, nbest=1)
        else:
            pool_beam = PoolBeam(topk=kbeam, nbest=nbest)

        pool_beam.add_beam(
            Beam([],
                 output=[],
                 log_prob=Variable(torchauto(model).FloatTensor([0])),
                 score=Variable(torchauto(model).FloatTensor([0])),
                 coeff_lp=coeff_lp))
        pool_beams.append(pool_beam)

    active_batch = [1 for x in range(batch)]

    def local2global_encoder(active_batch):
        """
        function to convert the number of active beams for each batch entry into a list of encoder-side indices
        """
        return sum([[bb for _ in range(active_batch[bb])]
                    for bb in range(batch)], [])

    global_state_index = Variable(
        torchauto(model).LongTensor(local2global_encoder(active_batch)))
    list_pre_softmax = []
    list_att_mat = []
    for tt in range(max_target):
        pre_softmax, dec_output = model.decode(curr_input)
        list_pre_softmax.append(pre_softmax)
        list_att_mat.append(dec_output['att_output']['p_ctx'])
        log_prob = F.log_softmax(pre_softmax, dim=-1)
        new_beams = []
        start, end = [], []
        for bb in range(batch):
            start.append(0 if bb == 0 else end[bb - 1])
            end.append(start[-1] + active_batch[bb])
        for bb in range(batch):
            if pool_beams[bb].is_finished() or active_batch[bb] == 0:
                continue
            # distribute softmax calculation to each beam #
            log_prob_bb = log_prob[start[bb]:end[bb]]
            pool_beams[bb].step(list(range(start[bb], end[bb])), log_prob_bb)
            # renew active_batch and finished_batch information #
            _tmp_active_beam = pool_beams[bb].get_active_beam()
            active_batch[bb] = len(
                _tmp_active_beam) if not pool_beams[bb].is_finished() else 0
            if not pool_beams[bb].is_finished():
                # if this pool has not finished, keep adding new beams #
                new_beams.extend(_tmp_active_beam)

        if all(x.is_finished() for x in pool_beams):
            # all batch finished #
            break
        # clear up and get new input + state #
        curr_state, curr_input = PoolBeam.combine_state_and_input(
            model, new_beams, model.state)

        global_state_index = Variable(
            torchauto(model).LongTensor(local2global_encoder(active_batch)))
        # model ctx & ctx mask #
        model.dec_att_lyr.ctx = enc_ctx.index_select(0, global_state_index)
        model.dec_att_lyr.ctx_mask = enc_ctx_mask.index_select(
            0, global_state_index)
        # model state #
        model.state = curr_state
        pass
    # gather all top of stack #
    if nbest is None:
        transcription = []
        transcription_len = []
        att_mat = []
        map_idx2text = dict([y, x] for (x, y) in map_text2idx.items())
        for bb in range(batch):
            best_beam = pool_beams[bb].stack[0]
            # gather transcription #
            transcription.append(torch.cat(best_beam.output))
            transcription_len.append(
                len(best_beam.output) if pool_beams[bb].is_finished() else -1)
            # gather attention matrix #
            att_mat_bb = []
            for ii in range(len(best_beam.state)):
                att_mat_bb.append(list_att_mat[ii][best_beam.state[ii]])
            att_mat_bb = torch.stack(att_mat_bb, 0)
            att_mat.append(att_mat_bb)
        att_mat = pad_sequence(att_mat, batch_first=True)
        return transcription, transcription_len, att_mat
    else:
        transcription = []
        transcription_len = []
        att_mat = []
        map_idx2text = dict([y, x] for (x, y) in map_text2idx.items())
        for bb in range(batch):
            transcription.append([])
            transcription_len.append([])
            att_mat.append([])
            for nn in range(nbest):
                best_beam = pool_beams[bb].stack[nn]
                # gather transcription #
                transcription[bb].append(torch.cat(best_beam.output))
                transcription_len[bb].append(
                    len(best_beam.output) if pool_beams[bb].is_finished(
                    ) else -1)
                # gather attention matrix #
                att_mat_bb_nn = []
                for ii in range(len(best_beam.state)):
                    att_mat_bb_nn.append(list_att_mat[ii][best_beam.state[ii]])
                att_mat_bb_nn = torch.stack(att_mat_bb_nn, 0)
                att_mat[bb].append(att_mat_bb_nn)
        return transcription, transcription_len, att_mat
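coeff_lp presumably scales a length penalty when ranking beams; the OpenNMT page linked in the docstring describes the GNMT-style normalization. A sketch of that formulation, as an assumption about what PoolBeam does with the coefficient:

def gnmt_length_penalty(length, alpha):
    # GNMT length normalization: lp(Y) = ((5 + |Y|) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

log_prob = -12.7                                 # summed log-probability of a candidate
score = log_prob / gnmt_length_penalty(9, alpha=1.0)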
    def fn_batch_tts(model,
                     text_mat,
                     text_len,
                     feat_mat,
                     feat_len,
                     aux_info=None,
                     train_step=True,
                     coeff_loss=1):
        # refit data #
        if max(feat_len) != feat_mat.shape[1]:
            feat_mat = feat_mat[:, 0:max(feat_len)]
        if max(text_len) != text_mat.shape[1]:
            text_mat = text_mat[:, 0:max(text_len)]
        batch_size = text_mat.shape[0]
        if not isinstance(text_mat, Variable):
            text_mat = Variable(text_mat)
        if not isinstance(feat_mat, Variable):
            feat_mat = Variable(feat_mat)
        feat_mat_input = feat_mat[:, 0:-1]
        feat_mat_output = feat_mat[:, 1:]

        feat_mask = Variable(
            generate_seq_mask([x - 1 for x in feat_len], opts['gpu']))

        feat_label_end = Variable(
            1. -
            generate_seq_mask([x - 1 - opts['tts_pad_sil'] for x in feat_len],
                              opts['gpu'],
                              max_len=feat_mask.size(1)))
        model.reset()
        model.train(train_step)
        model.encode(text_mat, text_len)

        # additional input condition
        if model.TYPE == TacotronType.MULTI_SPEAKER:
            aux_info['speaker_vector'] = Variable(
                tensorauto(
                    opts['gpu'],
                    torch.from_numpy(
                        np.stack(
                            aux_info['speaker_vector']).astype('float32'))))
            model.set_aux_info(aux_info)

        batch, dec_len, _ = feat_mat_input.size()
        list_dec_core = []
        list_dec_core_bernoulli_end = []
        list_dec_att = []
        for ii in range(dec_len):
            _dec_core_ii, _dec_att_ii, _dec_core_bernoulli_end = model.decode(
                feat_mat_input[:, ii],
                feat_mask[:, ii] if opts['tts_mask_dec'] else None)
            list_dec_core.append(_dec_core_ii)
            list_dec_core_bernoulli_end.append(_dec_core_bernoulli_end)
            list_dec_att.append(_dec_att_ii['att_output']['p_ctx'])
            pass

        dec_core = torch.stack(list_dec_core, 1)
        dec_core_bernoulli_end = torch.cat(list_dec_core_bernoulli_end, 1)
        dec_att = torch.stack(list_dec_att, dim=1)

        # main : loss mel spectrogram #
        loss_core = tts_loss(dec_core, feat_mat_output, feat_mask)

        # optional : aux loss for lower frequency #
        loss_core_freq = 1 * tts_loss_freq(dec_core, feat_mat_output,
                                           feat_mask)

        loss_feat = loss_core + loss_core_freq

        # optional : aux loss for speaker embedding reconstruction #
        if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
            loss_spk_emb = tts_loss_spk_emb(
                dec_core.view(batch_size, -1, NDIM_FEAT),
                [x * opts['tts_group'] for x in feat_len],
                aux_info['speaker_vector'])
        else:
            loss_spk_emb = Variable(torchauto(opts['gpu']).FloatTensor([0.0]))

        # main : frame ending prediction #
        loss_core_bernoulli_end = F.binary_cross_entropy_with_logits(
            dec_core_bernoulli_end, feat_label_end) * opts['tts_coeff_bern']
        acc_core_bernoulli_end = ((dec_core_bernoulli_end > 0.0) == (
            feat_label_end > 0.5)).float().mean()

        # combine all loss #
        loss = loss_feat + loss_core_bernoulli_end + loss_spk_emb
        loss = loss * coeff_loss

        if train_step:
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          opts['tts_grad_clip'])
            tts_opt.step()

        return loss.data.sum(), loss_feat.data.sum(), loss_core_bernoulli_end.data.sum(), \
                loss_spk_emb.data.sum(), acc_core_bernoulli_end.data.sum()
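The frame-ending ("bernoulli end") head is a per-frame stop-token classifier trained with BCE-with-logits, and its accuracy is measured by thresholding the logits at 0; a standalone sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 6)                       # per-frame end-of-utterance logits
labels = torch.tensor([[0., 0., 0., 0., 1., 1.],
                       [0., 0., 1., 1., 1., 1.]])

loss = F.binary_cross_entropy_with_logits(logits, labels)
acc = ((logits > 0.0) == (labels > 0.5)).float().mean()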