Example #1
def load_test_train_data(audio_conf,
                         train_manifest,
                         val_manifest,
                         labels,
                         batch_size=20,
                         augment=False,
                         normalize=True,
                         num_workers=4):

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=train_manifest,
                                       labels=labels,
                                       normalize=normalize,
                                       augment=augment)

    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=val_manifest,
                                      labels=labels,
                                      normalize=normalize,
                                      augment=False)

    train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)

    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers)

    return (train_loader, test_loader, train_sampler)
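# A minimal usage sketch (hypothetical manifest paths and typical parameter
# values; the audio_conf and labels construction mirrors the later examples
# on this page):
import json

with open("labels.json") as label_file:
    labels = str(''.join(json.load(label_file)))
audio_conf = dict(sample_rate=16000, window_size=0.02, window_stride=0.01,
                  window='hamming', noise_dir=None, noise_prob=0.4,
                  noise_levels=(0.0, 0.5))
train_loader, test_loader, train_sampler = load_test_train_data(
    audio_conf, 'train_manifest.csv', 'val_manifest.csv', labels)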
Example #2
def init_train_set(epoch, from_iter):
    #train_dataset.set_curriculum_epoch(epoch, sample=True)
    train_dataset.set_curriculum_epoch(epoch, sample=False)
    global train_loader, train_sampler
    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset,
                                         batch_size=args.batch_size)
        train_sampler.bins = train_sampler.bins[from_iter:]  # skip batches already consumed when resuming mid-epoch
    else:
        train_sampler = DistributedBucketingSampler(
            train_dataset,
            batch_size=args.batch_size,
            num_replicas=args.world_size,
            rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)
    if (not args.no_shuffle and epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
        train_sampler.shuffle(epoch)
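# A usage sketch, assuming start_epoch and start_iter were restored from a
# checkpoint as in the later examples, so training resumes mid-epoch:
init_train_set(start_epoch, from_iter=start_iter)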
Example #3
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=args.bidirectional)

    decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True,
                                      augment=False)
    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset,
                                         batch_size=args.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(
            train_dataset,
            batch_size=args.batch_size,
            num_replicas=args.world_size,
            rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
Example #4
# import torch.utils.data.distributed
import json
import torch  # 20200113: torch.nn.CTCLoss

from data.data_loader import AudioDataLoader, SpectrogramDataset, BucketingSampler, DistributedBucketingSampler
from decoder import GreedyDecoder
# from logger import VisdomLogger, TensorBoardLogger
# from model import DeepSpeech, supported_rnns
# from test import evaluate
# from utils import reduce_tensor, check_loss

if __name__ == '__main__':
    labels_path = "labels.json"
    with open(labels_path) as label_file:
        labels = str(''.join(json.load(label_file)))
    print("labels : ", labels)
    decoder = GreedyDecoder(labels)
    print("decoder : ", decoder)
    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
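# The import note above mentions torch.nn.CTCLoss, the built-in replacement
# for the external warpctc CTCLoss used in Example #10. A minimal sketch,
# assuming (T, N, C) network outputs and blank index 0:
criterion = torch.nn.CTCLoss(blank=0, reduction='sum', zero_infinity=True)
# out, output_sizes = model(inputs, input_sizes)   # out: (N, T, C)
# log_probs = out.transpose(0, 1).log_softmax(-1)  # -> (T, N, C)
# loss = criterion(log_probs, targets, output_sizes, target_sizes)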
Example #5
parser.add_argument('--num-workers', default=4, type=int, help='Number of workers used in dataloading')
parser.add_argument('--verbose', action="store_true", help="print out decoded output and error of each sample")
parser.add_argument('--output-path', default=None, type=str, help="Where to save raw acoustic output")
args = parser.parse_args()

if __name__ == '__main__':
    torch.set_grad_enabled(False)
    model, _ = load_model(args.model_path)
    device = torch.device("cuda" if args.cuda else "cpu")
    label_decoder = LabelDecoder(model.labels)
    model.eval()
    model = model.to(device)

    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, manifest_filepath=args.test_manifest,
                                      labels=model.labels)
    test_sampler = BucketingSampler(test_dataset, batch_size=args.batch_size)
    test_loader = AudioDataLoader(test_dataset, batch_sampler=test_sampler,
                                  num_workers=args.num_workers)
    test_sampler.shuffle(1)

    total_wer, total_cer, total_ler, num_words, num_chars, num_labels = 0, 0, 0, 0, 0, 0
    output_data = []

    for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader), ascii=True):
        inputs, targets, input_sizes, target_sizes, filenames = data
        inputs = inputs.to(device)
        input_sizes = input_sizes.to(device)
        outputs = model.transcribe(inputs, input_sizes)

        for j, target in enumerate(targets):
            reference = label_decoder.decode(target[:target_sizes[j]].tolist())
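# The totals declared above are accumulated from per-sample error rates. A
# self-contained sketch of the metric, assuming the decoders use a standard
# Levenshtein edit distance normalized as in Example #10:
def levenshtein(a, b):
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1,
                           prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

def word_error_rate(transcript, reference):
    ref_words = reference.split()
    return levenshtein(transcript.split(), ref_words) / float(len(ref_words))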
Example #6
        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=None,
                          noise_prob=1,
                          noise_levels=(0.5, 0.8))
    else:
        print("must load model!!!")
        exit()

    # decoder = GreedyDecoder(labels)
    original_dataset = SpectrogramDataset(
        audio_conf=audio_conf,
        manifest_filepath=args.train_manifest,
        labels=labels,
        normalize=False,
        augment=args.augment)
    original_train_sampler = BucketingSampler(original_dataset,
                                              batch_size=args.batch_size)

    original_train_loader = AudioDataLoader(
        original_dataset,
        num_workers=args.num_workers,
        batch_sampler=original_train_sampler)
    # NOTE: enumerate's start= only offsets the index i; it does not skip batches
    for i, (data) in enumerate(original_train_loader, start=start_iter):
        original_inputs, _, _, _ = data  # keeps only the final batch's inputs

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)

    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
Example #7
def main():

    opt = TrainOptions().parse()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")

    #import fake_opt
    #opt = fake_opt.Asr_train()

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'],
                                            'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'],
                                             'loss.png')

    # data
    logging.info("Building dataset.")
    # the train directory and the dict directory serve as inputs
    train_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    val_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )

    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)

    train_loader = SequentialDataLoader(train_dataset,
                                        num_workers=opt.num_workers,
                                        batch_sampler=train_sampler)
    val_loader = SequentialDataLoader(val_dataset,
                                      batch_size=int(opt.batch_size / 2),
                                      num_workers=opt.num_workers,
                                      shuffle=False)

    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)

    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup a model
    asr_model = E2E(opt)
    fbank_model = FbankModel(opt)
    lr = opt.lr  # default=0.005
    eps = opt.eps  # default=1e-8
    iters = opt.iters  # default=0
    start_epoch = opt.start_epoch  # default=0
    best_loss = opt.best_loss  # default=float('inf')
    best_acc = opt.best_acc  # default=0

    # resume from a checkpoint if one was given
    if opt.resume:
        model_path = os.path.join(opt.works_dir, opt.resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            best_acc = package.get('best_acc', 0)
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))

            acc_report = package.get('acc_report', acc_report)
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(acc_report, 'acc.png')
            visualizer.set_plot_report(loss_report, 'loss.png')

            asr_model = E2E.load_model(model_path, 'asr_state_dict')
            fbank_model = FbankModel.load_model(model_path, 'fbank_state_dict')
            logging.info('Loading model {} and iters {}'.format(
                model_path, iters))
        else:
            print("no checkpoint found at {}".format(model_path))
    # convert to cuda
    asr_model.cuda()
    fbank_model.cuda()
    print(asr_model)
    print(fbank_model)

    # Setup an optimizer
    parameters = filter(
        lambda p: p.requires_grad,
        itertools.chain(asr_model.parameters(), fbank_model.parameters()))
    #parameters = filter(lambda p: p.requires_grad, itertools.chain(asr_model.parameters()))
    if opt.opt_type == 'adadelta':
        optimizer = torch.optim.Adadelta(parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        optimizer = torch.optim.Adam(parameters,
                                     lr=lr,
                                     betas=(opt.beta1, 0.999))

    asr_model.train()
    fbank_model.train()
    #NOTE sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_epoch,
                                           opt.sche_samp_final_epoch,
                                           opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    # compute the CMVN stats for the fbank inputs
    fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    if os.path.exists(fbank_cmvn_file):
        # reuse cached fbank CMVN stats if they exist
        fbank_cmvn = np.load(fbank_cmvn_file)
    else:
        # otherwise compute them from the training data
        for i, (data) in enumerate(train_loader, start=0):
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
            # The if below is from the original code; deciding when to break
            # by whether fbank_cmvn is None is a fragile check,
            if fbank_cmvn is not None:
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
            # so we also compare cmvn_processed_num against cmvn_num
            if fbank_model.cmvn_processed_num >= fbank_model.cmvn_num:
                # run compute_cmvn one final time
                fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
    print(fbank_model.cmvn_processed_num)  # 3944
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)

    # start training
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for i, (data) in enumerate(train_loader,
                                   start=(iters * opt.batch_size) %
                                   len(train_dataset)):
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_features = fbank_model(inputs, fbank_cmvn)
            # NOTE: the original unpacking below does not match the variables in data
            # utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data

            # asr_model returns 3 values, while an earlier version unpacked 4;
            # check e2e_model to confirm:
            # the forward pass returns no context value at all
            # (the same applies to the other asr_model call below)
            loss_ctc, loss_att, acc = asr_model(fbank_features, targets,
                                                input_sizes, target_sizes,
                                                sche_samp_rate)
            loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att

            optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()  # compute backwards

            # compute the gradient norm to check that it is finite
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()

            iters += 1
            errors = {
                'train/loss': loss.item(),
                'train/loss_ctc': loss_ctc.item(),
                'train/acc': acc,
                'train/loss_att': loss_att.item()
            }
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(
                    iters, sche_samp_rate))
                asr_model.eval()
                fbank_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for _, (data) in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
                    fbank_features = fbank_model(inputs, fbank_cmvn)
                    # utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data
                    loss_ctc, loss_att, acc = asr_model(
                        fbank_features, targets, input_sizes, target_sizes,
                        0.0)
                    loss = opt.mtlalpha * loss_ctc + (1 -
                                                      opt.mtlalpha) * loss_att
                    errors = {
                        'val/loss': loss.item(),
                        'val/loss_ctc': loss_ctc.item(),
                        'val/acc': acc,
                        'val/loss_att': loss_att.item()
                    }
                    visualizer.set_current_errors(errors)

                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(
                                fbank_features, targets, input_sizes,
                                target_sizes)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(
                                    att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                asr_model.train()
                fbank_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                filename = None
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(
                            val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(
                            optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(
                            val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(
                            optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                ##filename='epoch-{}_iters-{}_loss-{:.4f}_acc-{:.4f}.pth'.format(epoch, iters, val_loss, val_acc)
                ##utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
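# A tidier structure for the CMVN pass criticized in the comments above (a
# sketch; it assumes, as the original does, that compute_cmvn returns None
# until fbank_model has processed cmvn_num frames):
def compute_fbank_cmvn(fbank_model, train_loader, fbank_cmvn_file):
    for _, _, inputs, _, _, input_sizes, _ in train_loader:
        fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
        if fbank_cmvn is None and \
                fbank_model.cmvn_processed_num >= fbank_model.cmvn_num:
            # enough frames seen; run one final compute_cmvn
            fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
        if fbank_cmvn is not None:
            np.save(fbank_cmvn_file, fbank_cmvn)
            print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
            return fbank_cmvn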
Example #8
def main(config: DictConfig) -> None:
    warnings.filterwarnings('ignore')
    print(OmegaConf.to_yaml(config))

    torch.manual_seed(config.train.seed)
    torch.cuda.manual_seed_all(config.train.seed)
    np.random.seed(config.train.seed)
    random.seed(config.train.seed)

    use_cuda = config.train.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    char2id, id2char = load_label(config.train.label_path,
                                  config.train.blank_id)
    train_audio_paths, train_transcripts, valid_audio_paths, valid_transcripts = load_dataset(
        config.train.dataset_path, config.train.mode)

    train_dataset = SpectrogramDataset(
        config.train.audio_path,
        train_audio_paths,
        train_transcripts,
        config.audio.sampling_rate,
        config.audio.n_mfcc
        if config.audio.feature_extraction == 'mfcc' else config.audio.n_mel,
        config.audio.frame_length,
        config.audio.frame_stride,
        config.audio.extension,
        config.audio.feature_extraction,
        config.audio.normalize,
        config.audio.spec_augment,
        config.audio.freq_mask_parameter,
        config.audio.num_time_mask,
        config.audio.num_freq_mask,
        config.train.sos_id,
        config.train.eos_id,
    )

    train_sampler = BucketingSampler(train_dataset,
                                     batch_size=config.train.batch_size)
    train_loader = AudioDataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.train.num_workers,
    )

    valid_dataset = SpectrogramDataset(
        config.train.audio_path,
        valid_audio_paths,
        valid_transcripts,
        config.audio.sampling_rate,
        config.audio.n_mfcc
        if config.audio.feature_extraction == 'mfcc' else config.audio.n_mel,
        config.audio.frame_length,
        config.audio.frame_stride,
        config.audio.extension,
        config.audio.feature_extraction,
        config.audio.normalize,
        config.audio.spec_augment,
        config.audio.freq_mask_parameter,
        config.audio.num_time_mask,
        config.audio.num_freq_mask,
        config.train.sos_id,
        config.train.eos_id,
    )
    valid_sampler = BucketingSampler(valid_dataset,
                                     batch_size=config.train.batch_size)
    valid_loader = AudioDataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        num_workers=config.train.num_workers,
    )

    model = build_model(config, device)

    optimizer = optim.Adam(model.parameters(), lr=config.train.lr)

    print('Start Train !!!')
    for epoch in range(0, config.train.epochs):
        train(config, model, device, train_loader, valid_loader, train_sampler,
              optimizer, epoch, id2char, epoch)

        if epoch % 2 == 0:
            torch.save(
                model,
                os.path.join(os.getcwd(), config.train.model_save_path +
                             str(epoch) + '.pt'))
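# torch.save(model, ...) above pickles the entire module. A state_dict
# checkpoint is usually more robust across code changes (a sketch, not what
# this example's author does):
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
           os.path.join(os.getcwd(),
                        config.train.model_save_path + str(epoch) + '.pt'))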
Example #9
def test_dataloader():
    args = parser.parse_args()

    # Set seeds for determinism
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")
    if args.mixed_precision and not args.cuda:
        raise ValueError(
            'If using mixed precision training, CUDA must be enabled!')
    args.distributed = args.world_size > 1
    main_proc = True
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None

    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)

        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
            else:
                start_iter += 1
            avg_loss = int(package.get('avg_loss', 0))
            loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], \
                                                     package['wer_results']
            if main_proc and args.visdom:  # Add previous scores to visdom graph
                visdom_logger.load_previous_values(start_epoch, package)
            if main_proc and args.tensorboard:  # Previous scores to tensorboard logs
                tensorboard_logger.load_previous_values(start_epoch, package)

        print("Loading label from %s" % args.labels_path)
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
    else:
        print("must load model!!!")
        exit()

    # decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)

    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    for i, (data) in enumerate(train_loader, start=start_iter):
        # fetch an initial batch of inputs
        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(device)
        size = inputs.size()
        print(size)
        # initialize the model
        model = M_Noise_Deepspeech(package, size)
        for para in model.deepspeech_net.parameters():
            para.requires_grad = False
        model = model.to(device)

        # get the initial (reference) outputs
        out_star = model.deepspeech_net(inputs, input_sizes)[0]
        out_star = out_star.transpose(0, 1)  # TxNxH
        float_out_star = out_star.float()
        break

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=1e-5)
    print(model)
Example #10
def train_main(args):
    args.distributed = args.world_size > 1
    main_proc = True
    if args.distributed:
        if args.gpu_rank:
            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder

    loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
        args.epochs)
    best_wer = None
    if args.visdom and main_proc:
        from visdom import Visdom

        viz = Visdom()
        opts = dict(title=args.id, ylabel='', xlabel='Epoch', legend=['Loss', 'WER', 'CER'])
        viz_window = None
        epochs = torch.arange(1, args.epochs + 1)
    if args.tensorboard and main_proc:
        os.makedirs(args.log_dir, exist_ok=True)
        from tensorboardX import SummaryWriter

        tensorboard_writer = SummaryWriter(args.log_dir)
    os.makedirs(save_folder, exist_ok=True)

    avg_loss, start_epoch, start_iter = 0, 0, 0
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = DeepSpeech.get_labels(model)
        audio_conf = DeepSpeech.get_audio_conf(model)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters, lr=args.lr,
                                    momentum=args.momentum, nesterov=True)
        if not args.finetune:  # Don't want to restart training
            if args.cuda:
                model.cuda()
            optimizer.load_state_dict(package['optim_dict'])
            start_epoch = int(package.get('epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
            else:
                start_iter += 1
            avg_loss = int(package.get('avg_loss', 0))
            loss_results, cer_results, wer_results = package['loss_results'], package[
                'cer_results'], package['wer_results']
            if main_proc and args.visdom and package['loss_results'] is not None \
                    and start_epoch > 0:  # Add previous scores to visdom graph
                x_axis = epochs[0:start_epoch]
                y_axis = torch.stack(
                    (loss_results[0:start_epoch], wer_results[0:start_epoch], cer_results[0:start_epoch]),
                    dim=1)
                viz_window = viz.line(
                    X=x_axis,
                    Y=y_axis,
                    opts=opts,
                )
            if main_proc and args.tensorboard and package['loss_results'] is not None \
                    and start_epoch > 0:  # Previous scores to tensorboard logs
                for i in range(start_epoch):
                    values = {
                        'Avg Train Loss': loss_results[i],
                        'Avg WER': wer_results[i],
                        'Avg CER': cer_results[i]
                    }
                    tensorboard_writer.add_scalars(args.id, values, i + 1)
    else:
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))

        rnn_type = args.rnn_type.lower()
        assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=args.bidirectional)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters, lr=args.lr,
                                    momentum=args.momentum, nesterov=True)
    criterion = CTCLoss()
    decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
                                       normalize=True, augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
                                      normalize=True, augment=False)
    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size,
                                                    num_replicas=args.world_size, rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers, batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
        train_sampler.shuffle(start_epoch)

    if args.cuda:
        model.cuda()
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model,
                                                              device_ids=(int(args.gpu_rank),) if args.rank else None)

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        for i, (data) in enumerate(train_loader, start=start_iter):
            if i == len(train_sampler):
                break
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)

            if args.cuda:
                inputs = inputs.cuda()

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            loss = criterion(out, targets, output_sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            inf = float("inf")
            if args.distributed:
                loss_value = reduce_tensor(loss, args.world_size)[0]
            else:
                loss_value = loss.item()
            if loss_value == inf or loss_value == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                    (epoch + 1), (i + 1), len(train_sampler), batch_time=batch_time, data_time=data_time, loss=losses))
            if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth' % (save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, iteration=i,
                                                loss_results=loss_results,
                                                wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss),
                           file_path)
            del loss
            del out
        avg_loss /= len(train_sampler)

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        with torch.no_grad():
            for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
                inputs, targets, input_percentages, target_sizes = data
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                # unflatten targets
                split_targets = []
                offset = 0
                for size in target_sizes:
                    split_targets.append(targets[offset:offset + size])
                    offset += size

                if args.cuda:
                    inputs = inputs.cuda()

                out, output_sizes = model(inputs, input_sizes)

                decoded_output, _ = decoder.decode(out.data, output_sizes)
                target_strings = decoder.convert_to_strings(split_targets)
                wer, cer = 0, 0
                for x in range(len(target_strings)):
                    transcript, reference = decoded_output[x][0], target_strings[x][0]
                    wer += decoder.wer(transcript, reference) / float(len(reference.split()))
                    cer += decoder.cer(transcript, reference) / float(len(reference))
                total_cer += cer
                total_wer += wer
                del out
            wer = total_wer / len(test_loader.dataset)
            cer = total_cer / len(test_loader.dataset)
            wer *= 100
            cer *= 100
            loss_results[epoch] = avg_loss
            wer_results[epoch] = wer
            cer_results[epoch] = cer
            print('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

            if args.visdom and main_proc:
                x_axis = epochs[0:epoch + 1]
                y_axis = torch.stack(
                    (loss_results[0:epoch + 1], wer_results[0:epoch + 1], cer_results[0:epoch + 1]), dim=1)
                if viz_window is None:
                    viz_window = viz.line(
                        X=x_axis,
                        Y=y_axis,
                        opts=opts,
                    )
                else:
                    viz.line(
                        X=x_axis.unsqueeze(0).expand(y_axis.size(1), x_axis.size(0)).transpose(0, 1),  # Visdom fix
                        Y=y_axis,
                        win=viz_window,
                        update='replace',
                    )
            if args.tensorboard and main_proc:
                values = {
                    'Avg Train Loss': avg_loss,
                    'Avg WER': wer,
                    'Avg CER': cer
                }
                tensorboard_writer.add_scalars(args.id, values, epoch + 1)
                if args.log_params:
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1)
                        tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1)
            if args.checkpoint and main_proc:
                file_path = '%s/deepspeech_%d.pth' % (save_folder, epoch + 1)
                torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                                wer_results=wer_results, cer_results=cer_results),
                           file_path)
                # anneal lr
                optim_state = optimizer.state_dict()
                optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
                optimizer.load_state_dict(optim_state)
                print('Learning rate annealed to: {lr:.6f}'.format(lr=optim_state['param_groups'][0]['lr']))

            if (best_wer is None or best_wer > wer) and main_proc:
                print("Found better validated model, saving to %s" % args.model_path)
                torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
                                                wer_results=wer_results, cer_results=cer_results), args.model_path)
                best_wer = wer

            avg_loss = 0  # reset running loss for the next epoch
            if not args.no_shuffle:
                print("Shuffling batches...")
                train_sampler.shuffle(epoch)
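# to_np (used in the tensorboard histogram logging above) is not defined in
# this snippet; it is presumably a small helper along these lines:
def to_np(x):
    return x.detach().cpu().numpy()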
Example #11
        model = Model(model_conf, audio_conf, labels)

    model = model.to(device)

    # Data inputs configuration
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels)
    val_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                     manifest_filepath=args.val_manifest,
                                     labels=labels)
    label_decoder = LabelDecoder(labels)

    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)
        val_sampler = BucketingSampler(val_dataset, batch_size=batch_size)
    else:
        train_sampler = DistributedBucketingSampler(
            train_dataset,
            batch_size=batch_size,
            num_replicas=args.world_size,
            rank=args.rank)
        val_sampler = DistributedBucketingSampler(val_dataset,
                                                  batch_size=batch_size,
                                                  num_replicas=args.world_size,
                                                  rank=args.rank)

    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)
Example #12
    decoder = MyDecoder(labels)

    train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, metadata_file_path=metadata_path, labels=labels,
                                       normalize=True, speed_volume_perturb=args.speed_volume_perturb, spec_augment=args.spec_augment)

    test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest,  metadata_file_path=metadata_path, labels=labels,
                                      normalize=True, speed_volume_perturb=False, spec_augment=False)

    print('\n Num samples train dataset: ', len(train_dataset))
    print('\n Num samples val   dataset: ', len(test_dataset))

    # for it in range(num_iter_cv):

    if not distributed:
        if sampler == 'bucketing':
            train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size, shuffle=True)
        elif sampler == 'random':
            train_sampler = RandomBucketingSampler(train_dataset, batch_size=args.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size, num_replicas=args.world_size, rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers, batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers, shuffle=True)

    if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
        train_sampler.shuffle(start_epoch)
Example #13
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=args.bidirectional)
        parameters = model.parameters()
        optimizer = torch.optim.SGD(parameters, lr=args.lr,
                                    momentum=args.momentum, nesterov=True)
    criterion = CTCLoss()
    decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
                                       normalize=True, augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
                                      normalize=True, augment=False)
    if not args.distributed:
        train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    else:
        train_sampler = DistributedBucketingSampler(train_dataset, batch_size=args.batch_size,
                                                    num_replicas=args.world_size, rank=args.rank)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers, batch_sampler=train_sampler)
    test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    if (not args.no_shuffle and start_epoch != 0) or args.no_sorta_grad:
        print("Shuffling batches for the following epochs")
        train_sampler.shuffle(start_epoch)

    if args.cuda:
        model.cuda()
        if args.distributed:
Example #14
 # NOTE: the clean-manifest argument name below is an assumption, inferred
 # from the parallel adversarial dataset that follows
 train_dataset_clean = SpectrogramDataset(
     audio_conf=audio_conf,
     manifest_filepath=args.train_manifest_clean,
     labels=labels,
     normalize=True,
     augment=args.augment)
 train_dataset_adv = SpectrogramDataset(
     audio_conf=audio_conf,
     manifest_filepath=args.train_manifest_adv,
     labels=labels,
     normalize=True,
     augment=args.augment)
 test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                   manifest_filepath=args.val_manifest,
                                   labels=labels,
                                   normalize=True,
                                   augment=False)
 if not args.distributed:
     train_sampler_clean = BucketingSampler(train_dataset_clean,
                                            batch_size=args.batch_size)
     train_sampler_adv = BucketingSampler(train_dataset_adv,
                                          batch_size=args.batch_size)
 else:
     train_sampler_clean = DistributedBucketingSampler(
         train_dataset_clean,
         batch_size=args.batch_size,
         num_replicas=args.world_size,
         rank=args.rank)
     train_sampler_adv = DistributedBucketingSampler(
         train_dataset_adv,
         batch_size=args.batch_size,
         num_replicas=args.world_size,
         rank=args.rank)
 train_loader_clean = AudioDataLoader(train_dataset_clean,
                                      num_workers=args.num_workers,
                                      batch_sampler=train_sampler_clean)
Example #15
    def train(self, **kwargs):
        """
        Run optimization to train the model.

        Parameters
        ----------


        """
        world_size = kwargs.pop('world_size', 1)
        gpu_rank = kwargs.pop('gpu_rank', 0)
        rank = kwargs.pop('rank', 0)
        dist_backend = kwargs.pop('dist_backend', 'nccl')
        dist_url = kwargs.pop('dist_url', None)

        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '1234'

        main_proc = True
        self.distributed = world_size > 1

        if self.distributed:
            if gpu_rank:
                torch.cuda.set_device(int(gpu_rank))
            dist.init_process_group(backend=dist_backend,
                                    init_method=dist_url,
                                    world_size=world_size,
                                    rank=rank)
            print('Initiated process group')
            main_proc = rank == 0  # Only the first proc should save models

        if main_proc and self.tensorboard:
            tensorboard_logger = TensorBoardLogger(self.id,
                                                   self.log_dir,
                                                   self.log_params,
                                                   comment=self.sufix)

        if self.distributed:
            train_sampler = DistributedBucketingSampler(
                self.data_train,
                batch_size=self.batch_size,
                num_replicas=world_size,
                rank=rank)
        else:
            if self.sampler_type == 'bucketing':
                train_sampler = BucketingSampler(self.data_train,
                                                 batch_size=self.batch_size,
                                                 shuffle=True)
            elif self.sampler_type == 'random':
                train_sampler = RandomBucketingSampler(
                    self.data_train, batch_size=self.batch_size)

        print("Shuffling batches for the following epochs..")
        train_sampler.shuffle(self.start_epoch)

        train_loader = AudioDataLoader(self.data_train,
                                       num_workers=self.num_workers,
                                       batch_sampler=train_sampler)
        val_loader = AudioDataLoader(self.data_val,
                                     batch_size=self.batch_size_val,
                                     num_workers=self.num_workers,
                                     shuffle=True)

        if self.tensorboard and self.generate_graph:  # TODO: also log some audio clips
            with torch.no_grad():
                inputs, targets, input_percentages, target_sizes = next(
                    iter(train_loader))
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
                tensorboard_logger.add_image(inputs,
                                             input_sizes,
                                             targets,
                                             network=self.model)

        self.model = self.model.to(self.device)
        parameters = self.model.parameters()

        if self.update_rule == 'adam':
            optimizer = torch.optim.Adam(parameters,
                                         lr=self.lr,
                                         weight_decay=self.reg)
        elif self.update_rule == 'sgd':
            optimizer = torch.optim.SGD(parameters,
                                        lr=self.lr,
                                        weight_decay=self.reg)

        self.model, self.optimizer = amp.initialize(
            self.model,
            optimizer,
            opt_level=self.opt_level,
            keep_batchnorm_fp32=self.keep_batchnorm_fp32,
            loss_scale=self.loss_scale)

        if self.optim_state is not None:
            self.optimizer.load_state_dict(self.optim_state)

        if self.amp_state is not None:
            amp.load_state_dict(self.amp_state)

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        print(self.model)

        if self.criterion_type == 'cross_entropy_loss':
            self.criterion = torch.nn.CrossEntropyLoss()

        #  Useless for now because I don't save.
        accuracies_train_iters = []
        losses_iters = []

        avg_loss = 0
        batch_time = AverageMeter()
        epoch_time = AverageMeter()
        losses = AverageMeter()

        start_training = time.time()
        for epoch in range(self.start_epoch, self.num_epochs):
            print("Start epoch..")

            # Put model in train mode
            self.model.train()

            y_true_train_epoch = np.array([])
            y_pred_train_epoch = np.array([])

            start_epoch = time.time()
            for i, (data) in enumerate(train_loader, start=0):
                start_batch = time.time()

                print('Start batch..')

                if i == len(train_sampler):  # why is this needed??
                    break

                inputs, targets, input_percentages, _ = data

                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                output, loss_value = self._step(inputs, input_sizes, targets)

                print('Step finished.')

                avg_loss += loss_value

                with torch.no_grad():
                    y_pred = self.decoder.decode(output.detach()).cpu().numpy()

                    # import pdb; pdb.set_trace()

                    y_true_train_epoch = np.concatenate(
                        (y_true_train_epoch, targets.cpu().numpy()
                         ))  # maybe I should do it with tensors?
                    y_pred_train_epoch = np.concatenate(
                        (y_pred_train_epoch, y_pred))

                inputs_size = inputs.size(0)
                del output, inputs, input_percentages

                if self.intra_epoch_sanity_check:
                    with torch.no_grad():
                        acc, _ = self.check_accuracy(targets.cpu().numpy(),
                                                     y_pred=y_pred)
                        accuracies_train_iters.append(acc)
                        losses_iters.append(loss_value)

                        cm = confusion_matrix(targets.cpu().numpy(),
                                              y_pred,
                                              labels=self.labels)
                        print('[it %i/%i] Confusion matrix train step:' %
                              ((i + 1, len(train_sampler))))
                        print(pd.DataFrame(cm))

                        if self.tensorboard:
                            tensorboard_logger.update(
                                len(train_loader) * epoch + i + 1, {
                                    'Loss/through_iterations': loss_value,
                                    'Accuracy/train_through_iterations': acc
                                })

                del targets

                batch_time.update(time.time() - start_batch)

            epoch_time.update(time.time() - start_epoch)
            losses.update(loss_value, inputs_size)

            # Write elapsed time (and loss) to terminal
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Epoch {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1),
                      len(train_sampler),
                      batch_time=batch_time,
                      data_time=epoch_time,
                      loss=losses))

            # Loss log
            avg_loss /= len(train_sampler)
            self.loss_epochs.append(avg_loss)

            # Accuracy train log
            acc_train, _ = self.check_accuracy(y_true_train_epoch,
                                               y_pred=y_pred_train_epoch)
            self.accuracy_train_epochs.append(acc_train)

            # Accuracy val log
            with torch.no_grad():
                y_pred_val = np.array([])
                targets_val = np.array([])
                for data in val_loader:
                    inputs, targets, input_percentages, _ = data
                    input_sizes = input_percentages.mul_(int(
                        inputs.size(3))).int()
                    _, y_pred_val_batch = self.check_accuracy(
                        targets.cpu().numpy(),
                        inputs=inputs,
                        input_sizes=input_sizes)
                    y_pred_val = np.concatenate((y_pred_val, y_pred_val_batch))
                    targets_val = np.concatenate(
                        (targets_val, targets.cpu().numpy()
                         ))  # TO DO: think of a smarter way to do this later
                    del inputs, targets, input_percentages

            # import pdb; pdb.set_trace()
            acc_val, y_pred_val = self.check_accuracy(targets_val,
                                                      y_pred=y_pred_val)
            self.accuracy_val_epochs.append(acc_val)
            cm = confusion_matrix(targets_val, y_pred_val, labels=self.labels)
            print('Confusion matrix validation:')
            print(pd.DataFrame(cm))

            # Write epoch stuff to tensorboard
            if self.tensorboard:
                tensorboard_logger.update(
                    epoch + 1, {'Loss/through_epochs': avg_loss},
                    parameters=self.model.named_parameters)

                tensorboard_logger.update(epoch + 1, {
                    'train': acc_train,
                    'validation': acc_val
                },
                                          together=True,
                                          name='Accuracy/through_epochs')

            # Keep track of the best model
            if acc_val > self.best_acc_val:
                self.best_acc_val = acc_val
                self.best_params = {}
                for k, v in self.model.named_parameters(
                ):  # TO DO: actually copy model and save later? idk..
                    self.best_params[k] = v.clone()

            # Anneal learning rate. TO DO: find a better way to do this, specific to every parameter as cs231n does.
            for g in self.optimizer.param_groups:
                g['lr'] = g['lr'] / self.learning_anneal
            print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

            # Shuffle batches order
            print("Shuffling batches...")
            train_sampler.shuffle(epoch)

            # Rechoose batches elements
            if self.sampler_type == 'random':
                train_sampler.recompute_bins()

        end_training = time.time()

        if self.tensorboard:
            tensorboard_logger.close()

        print('Elapsed time in training: %.02f min' %
              ((end_training - start_training) / 60.0))