def __build_model(self):
    """Create ``self.transformer`` from the matching ``self.hparams`` fields.

    Each keyword passed to ``Transformer`` is read from the hyper-parameter
    attribute of the same name, in the order listed below.
    """
    transformer_cls = Transformer
    hyper_params = self.hparams
    option_names = (
        'num_time_mask',
        'num_freq_mask',
        'freq_mask_length',
        'time_mask_length',
        'feature_dim',
        'model_size',
        'feed_forward_size',
        'hidden_size',
        'dropout',
        'num_head',
        'num_encoder_layer',
        'num_decoder_layer',
        'vocab_path',
        'max_feature_length',
        'max_token_length',
        'enable_spec_augment',
        'share_weight',
        'smoothing',
    )
    options = {name: getattr(hyper_params, name) for name in option_names}
    self.transformer = transformer_cls(**options)
Ejemplo n.º 2
0
 def __build_model(
         self,
         checkpoint_path='exp/lightning_logs/version_1011/checkpoints/'
                         '_ckpt_epoch_66.ckpt'):
     """Build ``self.transformer`` from ``self.hparams`` and warm-start it.

     Args:
         checkpoint_path: Lightning checkpoint whose ``'state_dict'`` entry is
             loaded into this module after the transformer is constructed.
             Defaults to the previously hard-coded path; pass ``None`` or
             ``''`` to skip loading and keep the fresh initialization.
     """
     self.transformer = Transformer(
         num_time_mask=self.hparams.num_time_mask,
         num_freq_mask=self.hparams.num_freq_mask,
         freq_mask_length=self.hparams.freq_mask_length,
         time_mask_length=self.hparams.time_mask_length,
         feature_dim=self.hparams.feature_dim,
         model_size=self.hparams.model_size,
         feed_forward_size=self.hparams.feed_forward_size,
         hidden_size=self.hparams.hidden_size,
         dropout=self.hparams.dropout,
         num_head=self.hparams.num_head,
         num_encoder_layer=self.hparams.num_encoder_layer,
         num_decoder_layer=self.hparams.num_decoder_layer,
         vocab_path=self.hparams.vocab_path,
         max_feature_length=self.hparams.max_feature_length,
         max_token_length=self.hparams.max_token_length,
         enable_spec_augment=self.hparams.enable_spec_augment,
         share_weight=self.hparams.share_weight,
         smoothing=self.hparams.smoothing,
     )
     if checkpoint_path:
         # map_location='cpu' lets the checkpoint load even when the device it
         # was saved from is unavailable; load_state_dict then copies the
         # values into this module's existing parameters.
         checkpoint = t.load(checkpoint_path, map_location='cpu')
         self.load_state_dict(checkpoint['state_dict'])
Ejemplo n.º 3
0
def make_transformer(config):
    """Assemble an encoder-decoder ``Transformer`` from a flat config dict.

    Expected keys: ``src_vocab_size``, ``trg_vocab_size``, ``d_model``,
    ``num_positions``, ``num_heads``, ``num_layers``, ``dropout``.

    Each encoder/decoder layer receives its own deep copy of the attention and
    feed-forward prototypes. The same (non-learnable) positional embedding
    module object is passed to both encoder and decoder.
    Modules are constructed in a fixed order (embeddings first), which matters
    for reproducible random initialization.
    """
    src_vocab_size = config['src_vocab_size']
    d_model = config['d_model']
    src_embedding = nn.Embedding(num_embeddings=src_vocab_size,
                                 embedding_dim=d_model)

    trg_vocab_size = config['trg_vocab_size']
    trg_embedding = nn.Embedding(num_embeddings=trg_vocab_size,
                                 embedding_dim=d_model)

    num_positions = config['num_positions']
    positional_embedding = PositionalEmbedding(num_embeddings=num_positions,
                                               embedding_dim=d_model,
                                               learnable=False)

    num_heads = config['num_heads']
    head_size = d_model // num_heads  # per-head key and value width
    attention_proto = MultiHeadAttention(
        attention=ScaledDotAttention(dropout=0),
        num_heads=num_heads,
        hidden_size=d_model,
        key_size=head_size,
        value_size=head_size)
    feed_forward_proto = FeedForward(input_size=d_model,
                                     feed_forward_size=4 * d_model,
                                     output_size=d_model)

    dropout = config['dropout']
    encoder = TransformerEncoder(
        embedding=src_embedding,
        positional_embedding=positional_embedding,
        layer=TransformerEncoderLayer(
            hidden_size=d_model,
            attention=deepcopy(attention_proto),
            feed_forward=deepcopy(feed_forward_proto),
            dropout=dropout),
        num_layers=config['num_layers'],
        dropout=dropout)

    decoder = TransformerDecoder(
        embedding=trg_embedding,
        positional_embedding=positional_embedding,
        layer=TransformerDecoderLayer(
            hidden_size=d_model,
            self_attention=deepcopy(attention_proto),
            src_attention=deepcopy(attention_proto),
            feed_forward=deepcopy(feed_forward_proto),
            dropout=dropout),
        num_layers=config['num_layers'],
        dropout=dropout)

    return Transformer(encoder=encoder, decoder=decoder)
Ejemplo n.º 4
0
class LightningModel(pl.LightningModule):
    """PyTorch Lightning wrapper that trains and evaluates the ``Transformer``.

    The model is built from ``hparams`` (see ``add_model_specific_args``) and,
    unless ``hparams.checkpoint_path`` is empty, warm-started from a saved
    Lightning checkpoint. Training uses a joint CE/CTC/switch loss and a
    Noam-style warm-up learning-rate schedule set in ``optimizer_step``.
    """

    # Checkpoint used when ``hparams`` has no ``checkpoint_path`` attribute.
    _DEFAULT_CHECKPOINT = (
        'exp/lightning_logs/version_1020/checkpoints/_ckpt_epoch_37.ckpt')

    # Tokens dropped from hypotheses and references before the error rate is
    # computed (presumably non-lexical markers — TODO confirm with the vocab).
    _IGNORED_TOKENS = frozenset(['[B]', '[S]', '[N]', '[T]', '[P]'])

    def __init__(self, hparams):
        super(LightningModel, self).__init__()
        self.hparams = hparams
        self.__build_model()
        # Most recent learning rate; written by optimizer_step and logged.
        self.lr = 0

    def __build_model(self):
        """Create ``self.transformer`` and optionally load checkpoint weights."""
        self.transformer = Transformer(
            num_time_mask=self.hparams.num_time_mask,
            num_freq_mask=self.hparams.num_freq_mask,
            freq_mask_length=self.hparams.freq_mask_length,
            time_mask_length=self.hparams.time_mask_length,
            feature_dim=self.hparams.feature_dim,
            model_size=self.hparams.model_size,
            feed_forward_size=self.hparams.feed_forward_size,
            hidden_size=self.hparams.hidden_size,
            dropout=self.hparams.dropout,
            num_head=self.hparams.num_head,
            num_encoder_layer=self.hparams.num_encoder_layer,
            num_decoder_layer=self.hparams.num_decoder_layer,
            vocab_path=self.hparams.vocab_path,
            max_feature_length=self.hparams.max_feature_length,
            max_token_length=self.hparams.max_token_length,
            enable_spec_augment=self.hparams.enable_spec_augment,
            share_weight=self.hparams.share_weight,
            smoothing=self.hparams.smoothing,
        )
        checkpoint_path = getattr(self.hparams, 'checkpoint_path',
                                  self._DEFAULT_CHECKPOINT)
        if checkpoint_path:
            # map_location='cpu' lets the checkpoint load even when the device
            # it was saved from is unavailable; load_state_dict copies the
            # values into this module's existing parameters.
            checkpoint = t.load(checkpoint_path, map_location='cpu')
            self.load_state_dict(checkpoint['state_dict'])

    def forward(self,
                feature,
                feature_length,
                target,
                target_length,
                cal_ce_loss=True):
        """Run the transformer and return its eight outputs as a tuple.

        Returns ``(output, output_token, spec_output, feature_length,
        ori_token, ori_token_length, ce_loss, switch_loss)``.
        """
        output, output_token, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss = self.transformer.forward(
            feature, feature_length, target, target_length, cal_ce_loss)

        return output, output_token, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss

    def decode(self, feature, feature_length, decode_type='greedy'):
        """Decode ``feature`` with the transformer ('greedy' or 'beam')."""
        assert decode_type in ['greedy', 'beam']
        output = self.transformer.inference(feature,
                                            feature_length,
                                            decode_type=decode_type)
        return output

    def _joint_loss(self, ce_loss, ctc_loss, switch_loss):
        """Return ``lambda * CE + (1 - lambda) * CTC + switch / 2``."""
        loss_lambda = self.hparams.loss_lambda
        return loss_lambda * ce_loss + (
            1 - loss_lambda) * ctc_loss + switch_loss / 2

    def training_step(self, batch, batch_nb):
        """Compute the joint loss for one batch and report it to Lightning."""
        feature, feature_length, target, target_length = batch[0], batch[
            1], batch[2], batch[3]
        _, _, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss = self.forward(
            feature, feature_length, target, target_length, True)
        ctc_loss = self.transformer.cal_ctc_loss(spec_output, feature_length,
                                                 ori_token, ori_token_length)
        loss = self._joint_loss(ce_loss, ctc_loss, switch_loss)
        tqdm_dict = {
            'loss': loss,
            'ce': ce_loss,
            'switch': switch_loss,
            'lr': self.lr
        }
        output = OrderedDict({
            'loss': loss,
            'ce': ce_loss,
            'switch': switch_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def validation_step(self, batch, batch_nb):
        """Compute the joint loss and the mean per-utterance error rate."""
        feature, feature_length, target, target_length = batch[0], batch[
            1], batch[2], batch[3]
        _, output_token, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss = self.forward(
            feature, feature_length, target, target_length, True)
        ignored = self._IGNORED_TOKENS
        result_string_list = [
            ' '.join([j for j in tokenize(i) if j not in ignored])
            for i in self.transformer.inference(feature, feature_length)
        ]
        target_string_list = [
            ' '.join([
                j
                for j in tokenize(self.transformer.vocab.id2string(i.tolist()))
                if j not in ignored
            ]) for i in output_token
        ]
        # Show one hypothesis/reference pair per batch for eyeballing.
        if result_string_list and target_string_list:
            print(result_string_list[0])
            print(target_string_list[0])
        mers = [
            cal_wer(reference, hypothesis)
            for reference, hypothesis in zip(target_string_list,
                                             result_string_list)
        ]
        mer = np.mean(mers)
        ctc_loss = self.transformer.cal_ctc_loss(spec_output, feature_length,
                                                 ori_token, ori_token_length)
        loss = self._joint_loss(ce_loss, ctc_loss, switch_loss)
        tqdm_dict = {
            'loss': loss,
            'ce': ce_loss,
            'switch': switch_loss,
            'mer': mer,
            'lr': self.lr
        }
        output = OrderedDict({
            'loss': loss,
            'ce': ce_loss,
            'switch': switch_loss,
            'mer': mer,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def validation_end(self, outputs):
        """Average the per-batch validation metrics and print them."""
        val_loss = t.stack([i['loss'] for i in outputs]).mean()
        ce_loss = t.stack([i['ce'] for i in outputs]).mean()
        switch_loss = t.stack([i['switch'] for i in outputs]).mean()
        mer = np.mean([i['mer'] for i in outputs])
        print('val_loss', val_loss.item())
        print('ce', ce_loss.item())
        print('switch_loss', switch_loss.item())
        print('mer', mer)
        return {
            'val_loss': val_loss,
            'val_ce_loss': ce_loss,
            'val_mer': mer,
            'log': {
                'val_loss': val_loss,
                'val_ce_loss': ce_loss,
                'val_mer': mer
            }
        }

    @pl.data_loader
    def train_dataloader(self):
        """Raw-audio training loader with speed perturbation enabled."""
        dataloader = build_raw_data_loader(
            [
                'data/filterd_manifest/ce_200.csv',
                'data/manifest/libri_train_short.csv',
                'data/filterd_manifest/c_500_train.csv',
                # 'data/filterd_manifest/aidatatang_200zh_train.csv',
                # 'data/filterd_manifest/data_aishell_train.csv',
                # 'data/filterd_manifest/AISHELL-2.csv',
                # 'data/filterd_manifest/magic_data_train.csv',
                # 'data/manifest/libri_100.csv',
                # 'data/manifest/libri_360.csv',
                # 'data/manifest/libri_500.csv'
            ],
            vocab_path=self.hparams.vocab_path,
            batch_size=self.hparams.train_batch_size,
            num_workers=self.hparams.train_loader_num_workers,
            speed_perturb=True,
            max_duration=10)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        """Raw-audio validation loader using the ``val_*`` loader settings."""
        dataloader = build_raw_data_loader(
            [
                'data/manifest/libri_test_short.csv',
                'data/manifest/ce_test.csv',
                # 'data/manifest/ce_20_dev.csv',
                'data/filterd_manifest/c_500_test.csv',
                # 'data/manifest/ce_20_dev_small.csv',
                # 'aishell2_testing/manifest1.csv',
                # 'data/filterd_manifest/data_aishell_test.csv'
            ],
            vocab_path=self.hparams.vocab_path,
            batch_size=self.hparams.val_batch_size,
            num_workers=self.hparams.val_loader_num_workers,
            speed_perturb=False,
            max_duration=10)
        return dataloader

    def optimizer_step(self,
                       epoch_nb,
                       batch_nb,
                       optimizer,
                       optimizer_i,
                       second_order_closure=None):
        """Set the warm-up LR on every param group, then step and zero grads.

        lr = factor * model_size**-0.5 * min(step**-0.5,
                                             step * warm_up_step**-1.5)
        with ``step = global_step + 1``. ``second_order_closure`` is not
        forwarded to ``optimizer.step``.
        """
        step = self.global_step + 1
        lr = self.hparams.factor * ((self.hparams.model_size**-0.5) * min(
            step**-0.5, step * (self.hparams.warm_up_step**-1.5)))
        self.lr = lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        optimizer.step()
        optimizer.zero_grad()

    def configure_optimizers(self):
        """AdamW wrapped in Lookahead; the LR is overridden every step."""
        optimizer = AdamW(self.parameters(),
                          lr=self.hparams.lr,
                          betas=(0.9, 0.997))
        optimizer = Lookahead(optimizer)
        return optimizer

    @staticmethod
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn any non-empty string (including 'False')
        into True. Raising ValueError makes argparse report an invalid value.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise ValueError('expected a boolean, got {!r}'.format(value))

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend ``parent_parser`` with this model's hyper-parameters."""
        str2bool = LightningModel._str2bool
        parser = HyperOptArgumentParser(parents=[parent_parser])
        parser.add_argument('--num_freq_mask', default=2, type=int)
        parser.add_argument('--num_time_mask', default=2, type=int)
        parser.add_argument('--freq_mask_length', default=30, type=int)
        parser.add_argument('--time_mask_length', default=20, type=int)
        parser.add_argument('--feature_dim', default=400, type=int)
        parser.add_argument('--model_size', default=512, type=int)
        parser.add_argument('--feed_forward_size', default=2048, type=int)
        parser.add_argument('--hidden_size', default=64, type=int)
        parser.add_argument('--dropout', default=0.1, type=float)
        parser.add_argument('--num_head', default=8, type=int)
        parser.add_argument('--num_encoder_layer', default=6, type=int)
        parser.add_argument('--num_decoder_layer', default=6, type=int)
        parser.add_argument('--vocab_path',
                            default='testing_vocab_2.model',
                            type=str)
        parser.add_argument('--max_feature_length', default=1024, type=int)
        parser.add_argument('--max_token_length', default=50, type=int)
        parser.add_argument('--share_weight', default=True, type=str2bool)
        parser.add_argument('--loss_lambda', default=0.9, type=float)
        parser.add_argument('--smoothing', default=0.1, type=float)

        parser.add_argument('--lr', default=3e-4, type=float)
        parser.add_argument('--warm_up_step', default=16000, type=int)
        # float so fractional LR scale factors are accepted.
        parser.add_argument('--factor', default=1, type=float)
        parser.add_argument('--enable_spec_augment', default=True,
                            type=str2bool)
        # Empty string disables warm-starting from a checkpoint.
        parser.add_argument('--checkpoint_path',
                            default=LightningModel._DEFAULT_CHECKPOINT,
                            type=str)

        parser.add_argument('--train_batch_size', default=64, type=int)
        parser.add_argument('--train_loader_num_workers', default=16, type=int)
        parser.add_argument('--val_batch_size', default=64, type=int)
        parser.add_argument('--val_loader_num_workers', default=16, type=int)

        return parser
Ejemplo n.º 5
0
def make_transformer(config):
    """Build an encoder-decoder ``Transformer`` from a nested experiment config.

    Reads:
      * ``config['data_process']['base_path']``, resolved with ``parse_path``.
        ``path['log']['data_log']`` names a YAML file that provides
        ``vocab_size`` (shared vocabulary) or ``src_vocab_size`` and
        ``trg_vocab_size``.
      * ``config['model']['share_src_trg_vocab']``: when truthy, encoder and
        decoder share a single embedding module.
      * ``config['model'][config['model']['type']]``: architecture settings
        ``d_model``, ``num_positions``, ``num_heads``, ``num_layers``,
        ``dropout``.

    Returns:
        The assembled ``Transformer``.
    """
    path = parse_path(config['data_process']['base_path'])
    # The data log only needs plain YAML data, so use safe_load.
    # yaml.load without an explicit Loader is unsafe and raises TypeError on
    # PyYAML >= 6. The ``with`` block closes the file deterministically.
    with open(path['log']['data_log'], encoding='utf-8') as log_file:
        data_log = yaml.safe_load(log_file)
    share_src_trg_vocab = config['model']['share_src_trg_vocab']
    model_config = config['model'][config['model']['type']]
    d_model = model_config['d_model']
    if share_src_trg_vocab:
        src_embedding = nn.Embedding(
            num_embeddings=data_log['vocab_size'],
            embedding_dim=d_model
        )
        trg_embedding = src_embedding
    else:
        src_embedding = nn.Embedding(
            num_embeddings=data_log['src_vocab_size'],
            embedding_dim=d_model
        )
        trg_embedding = nn.Embedding(
            num_embeddings=data_log['trg_vocab_size'],
            embedding_dim=d_model
        )
    # One positional embedding module object is given to both encoder and
    # decoder.
    positional_embedding = PositionalEmbedding(
        num_embeddings=model_config['num_positions'],
        embedding_dim=d_model,
        learnable=False
    )
    scaled_dot_attention = ScaledDotAttention(dropout=0)
    head_size = d_model // model_config['num_heads']
    multi_head_attention = MultiHeadAttention(
        attention=scaled_dot_attention,
        num_heads=model_config['num_heads'],
        hidden_size=d_model,
        key_size=head_size,
        value_size=head_size
    )
    feed_forward = FeedForward(
        input_size=d_model,
        feed_forward_size=4 * d_model,
        output_size=d_model
    )
    # Each layer receives its own deep copy of the attention/FFN prototypes.
    encoder_layer = TransformerEncoderLayer(
        hidden_size=d_model,
        attention=deepcopy(multi_head_attention),
        feed_forward=deepcopy(feed_forward),
        dropout=model_config['dropout']
    )
    encoder = TransformerEncoder(
        embedding=src_embedding,
        positional_embedding=positional_embedding,
        layer=encoder_layer,
        num_layers=model_config['num_layers'],
        dropout=model_config['dropout']
    )
    decoder_layer = TransformerDecoderLayer(
        hidden_size=d_model,
        self_attention=deepcopy(multi_head_attention),
        src_attention=deepcopy(multi_head_attention),
        feed_forward=deepcopy(feed_forward),
        dropout=model_config['dropout']
    )
    decoder = TransformerDecoder(
        embedding=trg_embedding,
        positional_embedding=positional_embedding,
        layer=decoder_layer,
        num_layers=model_config['num_layers'],
        dropout=model_config['dropout']
    )
    transformer = Transformer(
        encoder=encoder,
        decoder=decoder
    )
    return transformer