Code Example #1
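This excerpt runs the test phase of a speech-enhancement front end: it loads the test split, wraps an attention model in DataParallel, restores a checkpoint if one exists, then passes noisy features through the network and saves the enhanced features as .mat files.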
# Imports assumed by this excerpt; `args`, `config`, and the project-local
# helpers (load_dataset, load_noisy_dataset, AttentionModel, verbose) are
# defined elsewhere in the original script.
import os

import scipy.io as scio
import torch
import torch.optim as optim
from tqdm import tqdm


def main():
    _, tt_set, feat_dim, vocab_size, tokenizer, msg = load_dataset(
        args.njobs, args.gpu, args.pin_memory,
        config['hparas']['curriculum'] > 0, **config['data'])
    verbose(msg)

    # model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(120,
                       hidden_size=args.hidden_size,
                       dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder,
                       attn_len=args.attn_len))
    # net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use)
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Checkpoint load
    print('Trying Checkpoint Load\n')
    ckpt_dir = args.ck_dir
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print('checkpoint is loaded !')
        except RuntimeError as e:
            print('wrong checkpoint:', e)
    else:
        print('checkpoint does not exist!')

    # test phase
    net.eval()
    with torch.no_grad():
        sr1 = 16000
        window_size = 25  # int, window size for FFT (ms)
        stride = 10
        ws = int(sr1 * 0.001 * window_size)  # window length in samples (unused below)
        st = int(sr1 * 0.001 * stride)  # frame shift in samples (unused below)
        name_sum = 0
        for input in tqdm(tt_set):
            tt_noisy_set, feat_dim, vocab_size, tokenizer = load_noisy_dataset(
                "test", input[0], args.njobs, args.gpu, args.pin_memory,
                config['hparas']['curriculum'] > 0, **config['data_noisy'])
            for input_noisy in tt_noisy_set:
                # test_clean_feat = input[1].to(device='cuda')
                test_noisy_feat = input_noisy[1].to(device='cuda')
                # feed data
                test_mixed_feat, attn_weight = net(test_noisy_feat)

                for i in range(len(test_mixed_feat)):
                    name = args.out_path + input_noisy[0][i] + '.mat'
                    feat = test_mixed_feat[i].to(device='cpu').numpy()
                    scio.savemat(name, {'feat': feat})
Code Example #2
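A Trainer-style load_data method: it loads the training and dev splits, stores the tokenizer and feature/vocabulary dimensions, records the dev-split names, and initializes per-split best-WER bookkeeping for the attention and CTC decoders.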
    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      self.curriculum>0,
                                      **self.config['data'])
        self.verbose(msg)

        # Dev set names
        self.dv_names = []
        if type(self.dv_set) is list:
            for ds in self.config['data']['corpus']['dev_split']:
                self.dv_names.append(ds[0])
        else:
            self.dv_names = self.config['data']['corpus']['dev_split'][0]

        # Logger settings
        if type(self.dv_names) is str:
            self.best_wer = {
                'att': {
                    self.dv_names: 3.0
                },
                'ctc': {
                    self.dv_names: 3.0
                }
            }
        else:
            self.best_wer = {'att': {}, 'ctc': {}}
            for name in self.dv_names:
                self.best_wer['att'][name] = 3.0
                self.best_wer['ctc'][name] = 3.0
Code Example #3
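A minimal load_data variant that just loads the splits and logs the dataset summary.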
 def load_data(self):
     ''' Load data for training/validation, store tokenizer and input/output shape'''
     self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                      self.curriculum > 0, **self.config['data'])
     ## from src.data import load_dataset
     ## curriculum > 0 is passed as the `ascending` flag: sort utterances by ascending length
     self.verbose(msg)
Code Example #4
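A load_data variant that temporarily overrides the configured batch size before calling load_dataset and restores the previous value afterwards; see the sketch after the code.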
 def load_data(self, batch_size=7):
     ''' Load data for training/validation, store tokenizer and input/output shape'''
     prev_batch_size = self.config['data']['corpus']['batch_size']
     self.config['data']['corpus']['batch_size'] = batch_size
     self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
         load_dataset(self.paras.njobs, self.paras.gpu,
                      self.paras.pin_memory, False, **self.config['data'])
     self.config['data']['corpus']['batch_size'] = prev_batch_size
     self.verbose(msg)
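As a side note, the save/override/restore pattern above is fragile if load_dataset raises: the batch size would never be restored. A minimal sketch of a safer variant using a context manager (override_batch_size is a hypothetical helper, not part of the original project):

import contextlib

@contextlib.contextmanager
def override_batch_size(config, batch_size):
    # Temporarily swap the configured batch size; restore it on exit,
    # even if the body raises.
    corpus = config['data']['corpus']
    prev = corpus['batch_size']
    corpus['batch_size'] = batch_size
    try:
        yield
    finally:
        corpus['batch_size'] = prev

# Usage inside load_data:
#     with override_batch_size(self.config, batch_size):
#         ... = load_dataset(..., **self.config['data'])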
Code Example #5
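A unit test for a load_dataset helper that reads a .mat file and checks that the returned features and labels are NumPy arrays with the expected shapes.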
    def test_load_dataset(self):
        features, labels = load_dataset(MAT_FILE)

        # check features
        assert isinstance(features, np.ndarray)
        assert features.shape == (2, 2, 40)

        # check labels
        assert isinstance(labels, np.ndarray)
        assert labels.shape == (2, 2)
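For context, a minimal sketch of a load_dataset implementation consistent with this test; the 'features' and 'labels' key names inside the .mat file are assumptions, not taken from the original project:

import numpy as np
import scipy.io as scio

def load_dataset(mat_file):
    # Hypothetical helper matching the test above: read a .mat file and
    # return its feature and label entries as NumPy arrays.
    data = scio.loadmat(mat_file)
    return np.asarray(data['features']), np.asarray(data['labels'])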
Code Example #6
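A load_data method for a semi-supervised setup: it loads unpaired, paired, dev, and test sets, builds iterators over them, and records feature dimensions, vocabulary size, and the number of speakers from a speaker-map JSON file.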
 def load_data(self):
     ''' Load data for training/validation, store tokenizer and input/output shape'''
     self.verbose(['Loading data... a large corpus may take a while.'])
     self.unpair_set, self.pair_set, self.dev_set, self.test_set, self.audio_converter, self.tokenizer, data_msg = \
             load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, **self.config['data'])
     self.pair_iter = iter(self.pair_set)
     self.unpair_iter = iter(self.unpair_set)
     self.dev_iter = iter(self.dev_set)
     # Feature statistics
     self.n_mels, self.linear_dim = self.audio_converter.feat_dim
     self.vocab_size = self.tokenizer.vocab_size
     with open(self.config['data']['corpus']['spkr_map']) as f:
         self.n_spkr = len(json.load(f))
     self.verbose(data_msg)
Code Example #7
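A load_data method that optionally extracts features with a pretrained S3PRL upstream model loaded via torch.hub, falling back to precomputed features otherwise; the dev-set name and best-WER bookkeeping matches Example #2.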
    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        if self.paras.upstream is not None:
            print(f'[Solver] - using S3PRL {self.paras.upstream}')
            self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
                            load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                        self.curriculum>0,
                                        **self.config['data'])
            self.upstream = torch.hub.load(
                's3prl/s3prl',
                self.paras.upstream,
                feature_selection=self.paras.upstream_feature_selection,
                refresh=self.paras.upstream_refresh,
                ckpt=self.paras.upstream_ckpt,
                force_reload=True,
            )
            self.feat_dim = self.upstream.get_output_dim()
            self.specaug = Augment()
        else:
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      self.curriculum>0,
                                      **self.config['data'])
        self.verbose(msg)

        # Dev set names
        self.dv_names = []
        if type(self.dv_set) is list:
            for ds in self.config['data']['corpus']['dev_split']:
                self.dv_names.append(ds[0])
        else:
            self.dv_names = self.config['data']['corpus']['dev_split'][0]

        # Logger settings
        if type(self.dv_names) is str:
            self.best_wer = {
                'att': {
                    self.dv_names: 3.0
                },
                'ctc': {
                    self.dv_names: 3.0
                }
            }
        else:
            self.best_wer = {'att': {}, 'ctc': {}}
            for name in self.dv_names:
                self.best_wer['att'][name] = 3.0
                self.best_wer['ctc'][name] = 3.0
Code Example #8
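An inference-stage counterpart of Example #6 that additionally records the utterance file list of every split.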
    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.unpair_set, self.pair_set, self.dev_set, self.test_set, self.audio_converter, self.tokenizer, data_msg = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, inference_stage=True,
                             **self.config['data'])
        self.pair_iter = iter(self.pair_set)
        self.unpair_iter = iter(self.unpair_set)
        self.dev_iter = iter(self.dev_set)
        self.test_iter = iter(self.test_set)
        # Feature statistics
        self.n_mels, self.linear_dim = self.audio_converter.feat_dim
        self.vocab_size = self.tokenizer.vocab_size
        with open(self.config['data']['corpus']['spkr_map']) as f:
            self.n_spkr = len(json.load(f))

        self.filelist = {'pair': [], 'unpair': [], 'dev': [], 'test': []}
        self.filelist['pair'] = self.pair_set.dataset.table.index.tolist()
        self.filelist['unpair'] = self.unpair_set.dataset.table.index.tolist()
        self.filelist['dev'] = self.dev_set.dataset.table.index.tolist()
        self.filelist['test'] = self.test_set.dataset.table.index.tolist()
Code Example #9
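An evaluation-time load_data method that loads the dev and test splits, again with an optional S3PRL upstream feature extractor.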
 def load_data(self):
     ''' Load data for training/validation, store tokenizer and input/output shape'''
     if self.paras.upstream is not None:
         print(f'[Solver] - using S3PRL {self.paras.upstream}')
         self.dv_set, self.tt_set, self.vocab_size, self.tokenizer, msg = \
                         load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, 
                                          False, **self.config['data'])
         self.upstream = torch.hub.load(
             's3prl/s3prl',
              self.paras.upstream,
              feature_selection=self.paras.upstream_feature_selection,
              refresh=self.paras.upstream_refresh,
              ckpt=self.paras.upstream_ckpt,
              force_reload=True,
         )
         self.feat_dim = self.upstream.get_output_dim()
     else:
         self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
             load_dataset(self.paras.njobs, self.paras.gpu,
                          self.paras.pin_memory, False, **self.config['data'])
     self.verbose(msg)
Code Example #10
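The main training loop of a joint CTC/attention end-to-end ASR system: it renews the dataloader once curriculum learning ends, computes the CTC, attention, and optional embedding-regularization losses, backpropagates, and periodically logs statistics and runs validation.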
    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('loss', {
                        'tr_ctc': ctc_loss,
                        'tr_att': att_loss
                    })
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log(
                        'wer', {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt),
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                        })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()
Code Example #11
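An evaluation-time load_data method that loads only the dev and test splits.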
 def load_data(self):
     ''' Load data for training/validation, store tokenizer and input/output shape'''
     self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
         load_dataset(self.paras.njobs, self.paras.gpu,
                      self.paras.pin_memory, False, **self.config['data'])
     self.verbose(msg)
Code Example #12
File: train_asr.py  Project: ttaoREtw/Multi-CTC
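A CTC-only variant of the training loop in Example #10: the model returns only CTC outputs, and both PER and WER are logged from greedy CTC decoding.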
    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss = None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step: update tf_rate/lr_rate and zero the gradients
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)

                total_loss = ctc_loss

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1
                # Logger

                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    #self.write_log('wer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                    ctc_output = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log(
                        'per', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='per',
                                   ctc=True)
                        })
                    self.write_log(
                        'wer', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='wer',
                                   ctc=True)
                        })
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
Code Example #13
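A complete training script for a speech-enhancement front end: it trains an attention-based enhancement network with an MSE loss on clean/noisy feature pairs, decays the learning rate when the training loss stops improving, evaluates on the dev split every epoch, and checkpoints the best model.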
# Imports assumed by this excerpt (`SummaryWriter` may come from tensorboardX
# in the original project); `args`, `config`, and the project-local helpers
# (load_dataset, load_noisy_dataset, make_long_model, verbose) are defined
# elsewhere in the script.
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def main():
    summary = SummaryWriter('./log')
    tr_set, dv_set, feat_dim, vocab_size, tokenizer, msg = load_dataset(args.njobs, args.gpu, args.pin_memory,
                                                                        config['hparas']['curriculum'] > 0,
                                                                        **config['data'])

    verbose(msg)
    # model select
    print('Model initializing\n')

    net = torch.nn.DataParallel(
        make_long_model(n=args.ENCODER_NUM, d_model=120, d_ff=2048, h=config.HEAD, dropout=0.1))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    scheduler = ExponentialLR(optimizer, 0.5)

    # Checkpoint load

    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    best_loss = 200000.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            print('wrong checkpoint:', e)
    else:
        print('checkpoint does not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    # train
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        n = 0
        avg_loss = 0
        net.train()
        for input in tqdm(tr_set):
            tr_noisy_set, feat_dim, vocab_size, tokenizer = load_noisy_dataset("train", input[0], args.njobs,
                                                                               args.gpu,
                                                                               args.pin_memory,
                                                                               config['hparas']['curriculum'] > 0,
                                                                               **config['data_noisy'])
            for input_noisy in tr_noisy_set:
                train_clean_feat = input[1].to(device='cuda')
                train_noisy_feat = input_noisy[1].to(device='cuda')

                iteration += 1

                # feed data
                train_mixed_feat, attn_weight = net(train_noisy_feat)
                if train_mixed_feat.shape == train_clean_feat.shape:
                    loss = F.mse_loss(train_mixed_feat, train_clean_feat)  # 'mean' reduction is the default

                    if torch.any(torch.isnan(loss)):
                        torch.save(
                            {'clean_mag': train_clean_feat, 'noisy_mag': train_noisy_feat, 'out_mag': train_mixed_feat},
                            'nan_mag')
                        raise ValueError('loss is NaN')
                    avg_loss += loss.item()

                    n += 1
                    # gradient optimizer
                    optimizer.zero_grad()

                    loss.backward()

                    # update weight
                    optimizer.step()

        avg_loss /= n
        print('result:')
        print('[epoch: {}, iteration: {}] avg_loss : {:.4f}'.format(epoch, iteration, avg_loss))

        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # test phase
        n = 0
        avg_test_loss = 0
        net.eval()
        with torch.no_grad():
            for input in tqdm(dv_set):
                dv_noisy_set, feat_dim, vocab_size, tokenizer = load_noisy_dataset("dev", input[0], args.njobs,
                                                                                   args.gpu,
                                                                                   args.pin_memory,
                                                                                   config['hparas']['curriculum'] > 0,
                                                                                   **config['data_noisy'])
                for input_noisy in dv_noisy_set:
                    test_clean_feat = input[1].to(device='cuda')
                    test_noisy_feat = input_noisy[1].to(device='cuda')

                    test_mixed_feat, logits_attn_weight = net(test_noisy_feat)
                    if test_mixed_feat.shape == test_clean_feat.shape:
                        test_loss = F.mse_loss(test_mixed_feat, test_clean_feat)

                        avg_test_loss += test_loss.item()
                        n += 1

            avg_test_loss /= n

            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)

            print('[epoch: {}, iteration: {}] test loss : {:.4f} '.format(epoch, iteration, avg_test_loss))
            if avg_test_loss < best_loss:
                best_loss = avg_test_loss
                # Note: the optimizer also has state; don't forget to save it as well.
                ckpt = {'model': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_loss': best_loss}
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')
Code Example #14
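A training loop extended with transfer learning (layer freezing), optional early stopping of the CTC branch, per-epoch bookkeeping, and a hand-rolled learning-rate decay schedule.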
    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        self.n_epochs = 0
        self.timer.set()
        # Early stopping for CTC
        self.early_stopping = self.config['hparas']['early_stopping']
        stop_epoch = 10
        batch_size = self.config['data']['corpus']['batch_size']
        stop_step = len(self.tr_set) * stop_epoch // batch_size

        while self.step < self.max_step:
            ctc_loss, att_loss, emb_loss = None, None, None
            # Renew dataloader to enable random sampling

            if self.curriculum > 0 and self.n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(self.n_epochs))
                self.tr_set, _, _, _, _, _ = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      False, **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data,
                                                               train=True)

                self.timer.cnt('rd')
                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model( feat, feat_len, max(txt_len), tf_rate=tf_rate,
                                    teacher=txt, get_dec_state=self.emb_reg)
                # Clear not used objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state
                # Early stopping for CTC: disable the CTC branch after stop_step
                if self.early_stopping:
                    if self.step > stop_step:
                        ctc_output = None
                        self.model.ctc_weight = 0
                #print(ctc_output.shape)
                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            #[int(encode_len.max()) for _ in encode_len],
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    #print(att_output.shape)
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    # Sum each utterance, divide by length, then mean over the batch
                    # att_loss = torch.mean(torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(txt!=0,dim=-1).float())
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'\
                            .format(total_loss.cpu().item(),grad_norm,self.timer.show()))
                    self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt)
                        })
                        self.write_log(
                            'cer', {
                                'tr_att':
                                cal_er(self.tokenizer,
                                       att_output,
                                       txt,
                                       mode='cer')
                            })
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(
                            self.WER, {
                                'tr_ctc':
                                cal_er(
                                    self.tokenizer, ctc_output, txt, ctc=True)
                            })
                        self.write_log(
                            'cer', {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       mode='cer',
                                       ctc=True)
                            })
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))
                    # if self.step==1 or self.step % (self.PROGRESS_STEP * 5) == 0:
                    #     self.write_log('spec_train',feat_to_fig(feat[0].transpose(0,1).cpu().detach(), spec=True))
                    #del total_loss

                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if type(self.dv_set) is list:
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)
                if self.step % (len(self.tr_set) //
                                batch_size) == 0:  # one epoch
                    print('Finished epoch:', self.n_epochs)
                    self.n_epochs += 1

                if self.lr_scheduler is None:
                    lr = self.optimizer.opt.param_groups[0]['lr']

                    if self.step == 1:
                        print(
                            '[INFO]    using lr scheduler defined by Daniel, init lr = ',
                            lr)

                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO]     at step:', self.step)
                        print('[INFO]   lr reduce to', lr)

                    #self.lr_scheduler.step(total_loss)
                # End of step
                # if self.step % EMPTY_CACHE_STEP == 0:
                # Empty cuda cache after every fixed amount of steps
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step: break

            #update lr_scheduler

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')
Code Example #15
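A minimal end-to-end usage example: load the dataset, fit a BruiseDetector model for one epoch, then optionally save and evaluate it.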
from src.data import load_dataset
from src.model import BruiseDetector

# Load the data
train_data, test_data = load_dataset(batch_size=1)

# Load Model
model = BruiseDetector()
# Train Model
model.fit(train_data, batch_size=1, epochs=1)

# Save Model
# model.save()

# Evaluate The model
# model.evaluate()