Example #1
def finetune(args):
    # Construct Solver
    # data
    token2idx_src, idx2token_src = load_vocab(args.vocab_src)
    token2idx_tgt, idx2token_tgt = load_vocab(args.vocab_tgt)
    args.n_src = len(idx2token_src)
    args.n_tgt = len(idx2token_tgt)

    tr_dataset = VQ_Pred_Dataset(args.train_src,
                                 args.train_tgt,
                                 token2idx_src,
                                 token2idx_tgt,
                                 args.batch_size,
                                 args.maxlen_in,
                                 args.maxlen_out,
                                 down_sample_rate=args.down_sample_rate)
    cv_dataset = VQ_Pred_Dataset(args.valid_src,
                                 args.valid_tgt,
                                 token2idx_src,
                                 token2idx_tgt,
                                 args.batch_size,
                                 args.maxlen_in,
                                 args.maxlen_out,
                                 down_sample_rate=args.down_sample_rate)
    tr_loader = DataLoader(tr_dataset,
                           batch_size=1,
                           collate_fn=f_xy_pad,
                           num_workers=args.num_workers,
                           shuffle=args.shuffle)
    cv_loader = DataLoader(cv_dataset,
                           batch_size=1,
                           collate_fn=f_xy_pad,
                           num_workers=args.num_workers)

    # load dictionary and generate char_list, sos_id, eos_id
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    if args.structure == 'BERT':
        from mask_lm.Mask_LM import Mask_LM as Model
        from mask_lm.solver import Mask_LM_Solver as Solver

        model = Model.create_model(args)

    print(model)
    model.cuda()

    # optimizer
    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        args.k, args.d_model, args.warmup_steps)

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
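
Note: load_vocab is not shown in this example; it is expected to return the (token2idx, idx2token) maps consumed above. A minimal sketch, assuming a plain-text vocabulary with one token per line (the real file format and any special tokens may differ):

def load_vocab(vocab_path):
    """Build token<->index maps from a vocabulary file (assumed: one token per line)."""
    token2idx, idx2token = {}, {}
    with open(vocab_path, encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            idx = len(token2idx)
            token2idx[parts[0]] = idx
            idx2token[idx] = parts[0]
    return token2idx, idx2token
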
Example #2
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        # model
        encoder = Encoder(n_src_vocab,
                          args.n_layers_enc,
                          args.n_head,
                          args.d_k,
                          args.d_v,
                          args.d_model,
                          args.d_inner,
                          dropout=args.dropout,
                          pe_maxlen=args.pe_maxlen)
        decoder = Decoder(
            sos_id,
            eos_id,
            n_tgt_vocab,
            args.d_word_vec,
            args.n_layers_dec,
            args.n_head,
            args.d_k,
            args.d_v,
            args.d_model,
            args.d_inner,
            dropout=args.dropout,
            tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
            pe_maxlen=args.pe_maxlen)
        model = Transformer(encoder, decoder)
        # print(model)
        # model = nn.DataParallel(model)

        # optimizer
        optimizer = TransformerOptimizer(
            torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09))

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = AiChallenger2017Dataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               collate_fn=pad_collate,
                                               shuffle=True,
                                               num_workers=args.num_workers)
    valid_dataset = AiChallenger2017Dataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               collate_fn=pad_collate,
                                               shuffle=False,
                                               num_workers=args.num_workers)

    # Epochs
    for epoch in range(start_epoch, args.epochs):
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger,
                           writer=writer)

        writer.add_scalar('epoch/train_loss', train_loss, epoch)
        writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch)

        print('\nLearning rate: {}'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)
        writer.add_scalar('epoch/valid_loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best)
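
save_checkpoint is called above but not shown. A sketch that stays consistent with the keys the resume branch reads back ('epoch', 'epochs_since_improvement', 'model', 'optimizer'); the checkpoint file names are assumptions:

import torch

def save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                    best_loss, is_best):
    # persist everything the resume branch above expects to find
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'best_loss': best_loss,
             'model': model,
             'optimizer': optimizer}
    torch.save(state, 'checkpoint.tar')
    if is_best:
        # keep a separate copy of the best-scoring checkpoint
        torch.save(state, 'BEST_checkpoint.tar')
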
Example #3
def main():
    global char2index
    global index2char
    global SOS_token
    global EOS_token
    global PAD_token

    parser = argparse.ArgumentParser(description='Speech hackathon Baseline')

    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size in training (default: 32)')
    parser.add_argument(
        '--workers',
        type=int,
        default=4,
        help='number of workers in dataset loader (default: 4)')
    parser.add_argument('--max_epochs',
                        type=int,
                        default=10,
                        help='number of max epochs in training (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--teacher_forcing',
                        type=float,
                        default=0.5,
                        help='teacher forcing ratio in decoder (default: 0.5)')
    parser.add_argument('--max_len',
                        type=int,
                        default=WORD_MAXLEN,
                        help='maximum characters of sentence (default: 80)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--save_name',
                        type=str,
                        default='model',
                        help='the name of model in nsml or local')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument("--pause", type=int, default=0)
    parser.add_argument(
        '--word',
        action='store_true',
        help='Train/Predict model using word based label (default: False)')
    parser.add_argument('--gen_label_index',
                        action='store_true',
                        help='Generate word label index map (default: False)')
    parser.add_argument('--iteration', type=str, help='Iteration')
    parser.add_argument('--premodel_session',
                        type=str,
                        help='Session name of premodel')

    # transformer model parameter
    parser.add_argument('--d_model',
                        type=int,
                        default=128,
                        help='transformer_d_model')
    parser.add_argument('--n_head',
                        type=int,
                        default=8,
                        help='transformer_n_head')
    parser.add_argument('--num_encoder_layers',
                        type=int,
                        default=4,
                        help='num_encoder_layers')
    parser.add_argument('--num_decoder_layers',
                        type=int,
                        default=4,
                        help='transformer_num_decoder_layers')
    parser.add_argument('--dim_feedforward',
                        type=int,
                        default=2048,
                        help='transformer_dim_feedforward')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.1,
                        help='transformer_dropout')

    # transformer warmup parameter
    parser.add_argument('--warmup_multiplier',
                        type=int,
                        default=3,
                        help='transformer_warmup_multiplier')
    parser.add_argument('--warmup_epoch',
                        type=int,
                        default=10,
                        help='transformer_warmup_epoch')

    args = parser.parse_args()
    char_loader = CharLabelLoader()
    char_loader.load_char2index('./hackathon.labels')
    label_loader = char_loader
    if args.word:
        if args.gen_label_index:
            generate_word_label_index_file(char_loader, TRAIN_LABEL_CHAR_PATH)
            from subprocess import call
            call(f'cat {TRAIN_LABEL_CHAR_PATH}', shell=True)
        # load the word-level label index map
        word_loader = CharLabelLoader()
        word_loader.load_char2index('./hackathon.pos.labels')
        label_loader = word_loader
        if os.path.exists(TRAIN_LABEL_CHAR_PATH):
            generate_word_label_file(char_loader, word_loader,
                                     TRAIN_LABEL_POS_PATH,
                                     TRAIN_LABEL_CHAR_PATH)
    char2index = label_loader.char2index
    index2char = label_loader.index2char
    SOS_token = char2index['<s>']
    EOS_token = char2index['</s>']
    PAD_token = char2index['_']
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if args.cuda else 'cpu')

    ############ model
    print("model: transformer")
    # model = Transformer(d_model= args.d_model, n_head= args.n_head, num_encoder_layers= args.num_encoder_layers, num_decoder_layers= args.num_decoder_layers,
    #                     dim_feedforward= args.dim_feedforward, dropout= args.dropout, vocab_size= len(char2index), sound_maxlen= SOUND_MAXLEN, word_maxlen= WORD_MAXLEN)

    encoder = Encoder(d_input=128,
                      n_layers=6,
                      n_head=4,
                      d_k=128,
                      d_v=128,
                      d_model=128,
                      d_inner=2048,
                      dropout=0.1,
                      pe_maxlen=SOUND_MAXLEN)
    decoder = Decoder(sos_id=SOS_token,
                      eos_id=EOS_token,
                      n_tgt_vocab=len(char2index),
                      d_word_vec=128,
                      n_layers=6,
                      n_head=4,
                      d_k=128,
                      d_v=128,
                      d_model=128,
                      d_inner=2048,
                      dropout=0.1,
                      tgt_emb_prj_weight_sharing=True,
                      pe_maxlen=SOUND_MAXLEN)
    model = Transformer(encoder, decoder)

    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(),
                         lr=0.0004,
                         betas=(0.9, 0.98),
                         eps=1e-09))

    ############/

    for param in model.parameters():
        param.data.uniform_(-0.08, 0.08)

    model = nn.DataParallel(model).to(device)
    """
    optimizer = optim.Adam(model.module.parameters(), lr=args.lr)

    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epochs)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=args.warmup_multiplier, total_epoch=args.warmup_epoch, after_scheduler=scheduler_cosine)
    
    
    criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_token).to(device)
    """

    bind_model(model, optimizer)

    if args.pause == 1:
        nsml.paused(scope=locals())

    if args.mode != "train":
        return

    data_list = os.path.join(DATASET_PATH, 'train_data', 'data_list.csv')
    wav_paths = list()
    script_paths = list()

    with open(data_list, 'r') as f:
        for line in f:
            # line: "aaa.wav,aaa.label"

            wav_path, script_path = line.strip().split(',')
            wav_paths.append(os.path.join(DATASET_PATH, 'train_data',
                                          wav_path))
            script_paths.append(
                os.path.join(DATASET_PATH, 'train_data', script_path))

    best_loss = 1e10
    begin_epoch = 0

    # load all target scripts for reducing disk i/o
    # target_path = os.path.join(DATASET_PATH, 'train_label')
    target_path = TRAIN_LABEL_CHAR_PATH
    if args.word:
        target_path = TRAIN_LABEL_POS_PATH
    load_targets(target_path)

    train_batch_num, train_dataset_list, valid_dataset = split_dataset(
        args, wav_paths, script_paths, valid_ratio=0.05)

    if args.iteration:
        if args.premodel_session:
            nsml.load(args.iteration, session=args.premodel_session)
            logger.info(f'Load {args.premodel_session} {args.iteration}')
        else:
            nsml.load(args.iteration)
            logger.info(f'Load {args.iteration}')
    logger.info('start')

    train_begin = time.time()

    for epoch in range(begin_epoch, args.max_epochs):
        # learning rate scheduler

        train_queue = queue.Queue(args.workers * 2)

        train_loader = MultiLoader(train_dataset_list, train_queue,
                                   args.batch_size, args.workers)
        train_loader.start()

        train_loss, train_cer = train(model, train_batch_num, train_queue,
                                      optimizer, device, train_begin,
                                      args.workers, 10, args.teacher_forcing)
        logger.info('Epoch %d (Training) Loss %0.4f CER %0.4f' %
                    (epoch, train_loss, train_cer))

        train_loader.join()

        print("~~~~~~~~~~~~")

        if epoch == 10 or (epoch > 48 and epoch % 10 == 9):
            valid_queue = queue.Queue(args.workers * 2)
            valid_loader = BaseDataLoader(valid_dataset, valid_queue,
                                          args.batch_size, 0)
            valid_loader.start()

            eval_loss, eval_cer = evaluate(model, valid_loader, valid_queue,
                                           device, args.max_len,
                                           args.batch_size)
            logger.info('Epoch %d (Evaluate) Loss %0.4f CER %0.4f' %
                        (epoch, eval_loss, eval_cer))

            valid_loader.join()

            nsml.report(False,
                        step=epoch,
                        train_epoch__loss=train_loss,
                        train_epoch__cer=train_cer,
                        eval__loss=eval_loss,
                        eval__cer=eval_cer)

            best_model = (eval_loss < best_loss)
            nsml.save(args.save_name)

            if best_model:
                nsml.save('best')
                best_loss = eval_loss
Example #4
def train_net(args):
    # seed the RNGs to make results reproducible
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint

    start_epoch = 0
    writer = SummaryWriter()

    if checkpoint is None:
        # model
        encoder = Encoder(Config.vocab_size, args.n_layers_enc, args.n_head,
                          args.d_k, args.d_v, args.d_model, args.d_inner,
                          dropout=args.dropout, pe_maxlen=args.pe_maxlen)

        decoder = Decoder(Config.sos_id, Config.eos_id, Config.vocab_size,
                          args.d_word_vec, args.n_layers_dec, args.n_head,
                          args.d_k, args.d_v, args.d_model, args.d_inner,
                          dropout=args.dropout,
                          tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
                          pe_maxlen=args.pe_maxlen)

        model = Transformer(encoder, decoder)

        # optimizer
        optimizer = TransformerOptimizer(
            torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09))

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to GPU, if available
    model = model.to(Config.device)

    # Custom dataloaders: note the collate_fn argument, which pads the variable-length batch data
    train_dataset = TranslateDataset()

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
                                               shuffle=True, num_workers=args.num_workers)

    # Epochs
    Loss_list = []
    for epoch in range(start_epoch, args.epochs):
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger,
                           writer=writer)

        l = str(train_loss)
        Loss_list.append(l)

        l_temp = l + '\n'
        with open('loss_epoch.txt', 'a+') as f:
            f.write(l_temp)

        writer.add_scalar('epoch/train_loss', train_loss, epoch)
        writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch)

        print('\nLearning rate: {}'.format(optimizer.lr))
        print('Step num: {}\n'.format(optimizer.step_num))

        # Save checkpoint
        save_checkpoint(epoch, model, optimizer, train_loss)
    with open('loss.txt', 'w') as f:
        f.write('\n'.join(Loss_list))
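
The pad_collate function passed as collate_fn above is not shown. A minimal sketch of such a padding collate function; the per-sample layout, pad values, and the IGNORE_ID constant are assumptions for illustration:

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1  # assumed ignore/pad label id, matching its use in Example #8

def pad_collate(batch):
    # each sample is assumed to be a (source_tensor, target_tensor) pair
    srcs, tgts = zip(*batch)
    input_lengths = torch.tensor([len(s) for s in srcs])
    padded_input = pad_sequence(list(srcs), batch_first=True, padding_value=0)
    padded_target = pad_sequence(list(tgts), batch_first=True, padding_value=IGNORE_ID)
    return padded_input, input_lengths, padded_target
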
Example #5
def main(args):
    # Construct Solver
    # data
    token2idx, idx2token = load_vocab(args.vocab)
    args.vocab_size = len(token2idx)
    args.sos_id = token2idx['<sos>']
    args.eos_id = token2idx['<eos>']

    tr_dataset = AudioDataset(args.train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              batch_frames=args.batch_frames)
    cv_dataset = AudioDataset(args.valid_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              batch_frames=args.batch_frames)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                token2idx=token2idx,
                                label_type=args.label_type,
                                num_workers=args.num_workers,
                                shuffle=args.shuffle,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                token2idx=token2idx,
                                label_type=args.label_type,
                                num_workers=args.num_workers,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n)
    # load dictionary and generate char_list, sos_id, eos_id
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    if args.structure == 'transformer':
        from transformer.Transformer import Transformer
        from transformer.solver import Transformer_Solver as Solver

        model = Transformer.create_model(args)

    elif args.structure == 'transformer-ctc':
        from transformer.Transformer import CTC_Transformer as Transformer
        from transformer.solver import Transformer_CTC_Solver as Solver

        model = Transformer.create_model(args)

    elif args.structure == 'conv-transformer-ctc':
        from transformer.Transformer import Conv_CTC_Transformer as Transformer
        from transformer.solver import Transformer_CTC_Solver as Solver

        model = Transformer.create_model(args)

    elif args.structure == 'cif':
        from transformer.CIF_Model import CIF_Model
        from transformer.solver import CIF_Solver as Solver

        model = CIF_Model.create_model(args)

    print(model)
    model.cuda()

    # optimizer
    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        args.k, args.d_model, args.warmup_steps)

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Example #6
        eos_id,
        vocab_size,
        args.d_word_vec,
        args.n_layers_dec,
        args.n_head,
        args.d_k,
        args.d_v,
        args.d_model,
        args.d_inner,
        dropout=args.dropout,
        tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
        pe_maxlen=args.pe_maxlen)
    model = Transformer(encoder, decoder)

    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        args.k, args.d_model, args.warmup_steps)

    print(args.k)
    print(args.d_model)
    print(args.warmup_steps)

    lr_list = []
    # reproduce the schedule with the same constants the optimizer was built from
    k, init_lr, warmup_steps = args.k, args.d_model ** (-0.5), args.warmup_steps
    for step_num in range(1, 50000):
        # print(step_num)
        lr_1 = k * init_lr * min(step_num**(-0.5),
                                 step_num * (warmup_steps**(-1.5)))
        optimizer.step()
        lr_2 = optimizer.lr
        # print(lr_1)
        # print(lr_2)
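
The schedule checked in the loop above is the standard Transformer warm-up rule, lr = k * d_model**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5)). A minimal sketch of a wrapper implementing that rule (an illustration under this assumption, not the project's actual TransformerOptimizer):

class WarmupOptimizerSketch:
    """Wraps a torch optimizer and sets its lr from the warm-up formula at every step."""

    def __init__(self, optimizer, k, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.k = k
        self.init_lr = d_model ** (-0.5)
        self.warmup_steps = warmup_steps
        self.step_num = 0
        self.lr = 0.0

    def step(self):
        self.step_num += 1
        self.lr = self.k * self.init_lr * min(self.step_num ** (-0.5),
                                              self.step_num * self.warmup_steps ** (-1.5))
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()
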
Example #7
def __init__(self):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    self.train_json = os.path.join(dir_path, self.train_json)
    self.valid_json = os.path.join(dir_path, self.valid_json)
    self.dict_txt = os.path.join(dir_path, self.dict_txt)
    self.char_list, self.sos_id, self.eos_id = process_dict(self.dict_txt)
    self.vocab_size = len(self.char_list)
    self.tr_dataset = AudioDataset(self.train_json,
                                   self.batch_size,
                                   self.maxlen_in,
                                   self.maxlen_out,
                                   batch_frames=self.batch_frames)
    self.cv_dataset = AudioDataset(self.valid_json,
                                   self.batch_size,
                                   self.maxlen_in,
                                   self.maxlen_out,
                                   batch_frames=self.batch_frames)
    self.tr_loader = AudioDataLoader(self.tr_dataset,
                                     batch_size=1,
                                     num_workers=self.num_workers,
                                     shuffle=self.shuffle,
                                     LFR_m=self.LFR_m,
                                     LFR_n=self.LFR_n)
    self.cv_loader = AudioDataLoader(self.cv_dataset,
                                     batch_size=1,
                                     num_workers=self.num_workers,
                                     LFR_m=self.LFR_m,
                                     LFR_n=self.LFR_n)
    self.data = {'tr_loader': self.tr_loader, 'cv_loader': self.cv_loader}
    self.encoder = Encoder(self.d_input * self.LFR_m,
                           self.n_layers_enc,
                           self.n_head,
                           self.d_k,
                           self.d_v,
                           self.d_model,
                           self.d_inner,
                           dropout=self.dropout,
                           pe_maxlen=self.pe_maxlen)
    self.decoder = Decoder(
        self.sos_id,
        self.eos_id,
        self.vocab_size,
        self.d_word_vec,
        self.n_layers_dec,
        self.n_head,
        self.d_k,
        self.d_v,
        self.d_model,
        self.d_inner,
        dropout=self.dropout,
        tgt_emb_prj_weight_sharing=self.tgt_emb_prj_weight_sharing,
        pe_maxlen=self.pe_maxlen)
    self.tr_loss = torch.Tensor(self.epochs)
    self.cv_loss = torch.Tensor(self.epochs)
    self.model = Transformer(self.encoder, self.decoder)
    self.optimizer = TransformerOptimizer(
        torch.optim.Adam(self.model.parameters(),
                         betas=(0.9, 0.98),
                         eps=1e-09), self.k, self.d_model,
        self.warmup_steps)
    self._reset()
Example #8
class SpeechTransformerTrainConfig:
    # Low Frame Rate
    LFR_m = 4
    LFR_n = 3
    # Network Architecture - Encoder
    d_input = 80
    n_layers_enc = 6
    n_head = 8
    d_k = 64
    d_v = 64
    d_model = 512
    d_inner = 2048
    dropout = 0.1
    pe_maxlen = 5000
    d_word_vec = 512
    n_layers_dec = 6
    tgt_emb_prj_weight_sharing = 1
    label_smoothing = 0.1
    # minibatch
    shuffle = 1
    batch_size = 16
    batch_frames = 15000
    maxlen_in = 800
    maxlen_out = 150
    num_workers = 4
    # optimizer
    k = 0.2
    warmup_steps = 1
    # solver configs
    epochs = 5
    save_folder = "output_data"
    checkpoint = False
    continue_from = False
    model_path = 'final.pth.tar'
    print_freq = 10
    visdom = 0
    visdom_lr = 0
    visdom_epoch = 0
    visdom_id = 0
    # The input files. Their paths are relative to the directory of __file__
    train_json = "input_data/train/data.json"
    valid_json = "input_data/dev/data.json"
    dict_txt = "input_data/lang_1char/train_chars.txt"

    def __init__(self):
        dir_path = os.path.dirname(os.path.realpath(__file__))
        self.train_json = os.path.join(dir_path, self.train_json)
        self.valid_json = os.path.join(dir_path, self.valid_json)
        self.dict_txt = os.path.join(dir_path, self.dict_txt)
        self.char_list, self.sos_id, self.eos_id = process_dict(self.dict_txt)
        self.vocab_size = len(self.char_list)
        self.tr_dataset = AudioDataset(self.train_json,
                                       self.batch_size,
                                       self.maxlen_in,
                                       self.maxlen_out,
                                       batch_frames=self.batch_frames)
        self.cv_dataset = AudioDataset(self.valid_json,
                                       self.batch_size,
                                       self.maxlen_in,
                                       self.maxlen_out,
                                       batch_frames=self.batch_frames)
        self.tr_loader = AudioDataLoader(self.tr_dataset,
                                         batch_size=1,
                                         num_workers=self.num_workers,
                                         shuffle=self.shuffle,
                                         LFR_m=self.LFR_m,
                                         LFR_n=self.LFR_n)
        self.cv_loader = AudioDataLoader(self.cv_dataset,
                                         batch_size=1,
                                         num_workers=self.num_workers,
                                         LFR_m=self.LFR_m,
                                         LFR_n=self.LFR_n)
        self.data = {'tr_loader': self.tr_loader, 'cv_loader': self.cv_loader}
        self.encoder = Encoder(self.d_input * self.LFR_m,
                               self.n_layers_enc,
                               self.n_head,
                               self.d_k,
                               self.d_v,
                               self.d_model,
                               self.d_inner,
                               dropout=self.dropout,
                               pe_maxlen=self.pe_maxlen)
        self.decoder = Decoder(
            self.sos_id,
            self.eos_id,
            self.vocab_size,
            self.d_word_vec,
            self.n_layers_dec,
            self.n_head,
            self.d_k,
            self.d_v,
            self.d_model,
            self.d_inner,
            dropout=self.dropout,
            tgt_emb_prj_weight_sharing=self.tgt_emb_prj_weight_sharing,
            pe_maxlen=self.pe_maxlen)
        self.tr_loss = torch.Tensor(self.epochs)
        self.cv_loss = torch.Tensor(self.epochs)
        self.model = Transformer(self.encoder, self.decoder)
        self.optimizer = TransformerOptimizer(
            torch.optim.Adam(self.model.parameters(),
                             betas=(0.9, 0.98),
                             eps=1e-09), self.k, self.d_model,
            self.warmup_steps)
        self._reset()

    def _reset(self):
        self.prev_val_loss = float("inf")
        self.best_val_loss = float("inf")
        self.halving = False

    def _run_one_epoch(self, cross_valid=False):
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        for i, (data) in enumerate(data_loader):
            padded_input, input_lengths, padded_target = data
            padded_input = padded_input.cuda()
            input_lengths = input_lengths.cuda()
            padded_target = padded_target.cuda()
            pred, gold = self.model(padded_input, input_lengths, padded_target)
            loss, n_correct = cal_performance(pred,
                                              gold,
                                              smoothing=self.label_smoothing)
            if not cross_valid:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            total_loss += loss.item()
            non_pad_mask = gold.ne(IGNORE_ID)
            n_word = non_pad_mask.sum().item()

        # average the loss over all batches seen this epoch
        return total_loss / (i + 1)

    def train(self, epoch=1):
        self.model.train()
        tr_avg_loss = self._run_one_epoch()
        # Cross validation
        self.model.eval()
        val_loss = self._run_one_epoch(cross_valid=True)
        self.tr_loss[epoch] = tr_avg_loss
        self.cv_loss[epoch] = val_loss
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
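
A short usage sketch for the class above, assuming its module-level dependencies (AudioDataset, AudioDataLoader, Encoder, Decoder, Transformer, TransformerOptimizer, cal_performance, IGNORE_ID) are importable:

if __name__ == '__main__':
    # hypothetical driver loop; the epoch range matches the tr_loss/cv_loss
    # tensors, which are sized to `epochs`
    config = SpeechTransformerTrainConfig()
    for epoch in range(config.epochs):
        config.train(epoch=epoch)
    print('best validation loss:', config.best_val_loss)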