Example #1
    def train(self):
        self.model.train()
        dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                                FRAME_STRIDE, TEST_SIZE, self.device)
        dataset.init_dataset(test_mode=False)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        for epoch in range(MAX_EPOCHS):
            for i, data in enumerate(data_loader):
                x, y, cond = data
                pred_y = self.model(x, cond)
                loss = F.cross_entropy(pred_y, y)
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(self.model.parameters(), MAX_NORM)
                optimizer.step()

                if i % PRINT_FREQ == 0:
                    self.logger.info(
                        'epoch: %d, step:%d, tot_step:%d, loss: %f' %
                        (epoch, i, self.tot_steps, loss.item()))
                if i % VALID_FREQ == 0:
                    self.validate()
                    self.model.train()  # validate() switches to eval mode; switch back for training

                self.tot_steps += 1
                if self.tot_steps % 100 == 0:
                    self.save_model()
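This snippet (and Examples #3, #4, and #14 below) references module-level constants that are not shown. A minimal sketch with hypothetical values, assuming 8-bit quantized audio at 16 kHz; the real repository defines its own settings:

N_CLASS = 256                # quantization levels (hypothetical: 8-bit)
SLICE_LENGTH = 16000         # samples per training slice
FRAME_LENGTH = 400           # analysis frame size in samples
FRAME_STRIDE = 160           # hop between frames in samples
TEST_SIZE = 0.1              # held-out fraction for validation
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_EPOCHS = 100
MAX_NORM = 1.0               # gradient-clipping threshold
PRINT_FREQ = 10              # steps between log lines
VALID_FREQ = 100             # steps between validation passes
MAX_VALID = 10               # validation batches per pass
MAX_GENERATE = 1             # batches to draw in generate()
MAX_GENERATE_LENGTH = 16000  # samples to generate per conditioning input
SAMPLE_RATE = 16000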
Example #2
def main(DEVICE):

    # define model, optimizer, scheduler
    model = VQVC().to(DEVICE)

    recon_loss = nn.L1Loss().to(DEVICE)
    vocoder = get_vocgan(
        ckpt_path=args.vocoder_pretrained_model_path).to(DEVICE)

    mel_stat = torch.tensor(np.load(args.mel_stat_path)).to(DEVICE)

    optimizer = Adam(model.parameters(), lr=args.init_lr)
    scheduler = WarmupScheduler(optimizer,
                                warmup_epochs=args.warmup_steps,
                                initial_lr=args.init_lr,
                                max_lr=args.max_lr,
                                milestones=args.milestones,
                                gamma=args.gamma)

    global_step = load_checkpoint(checkpoint_path=args.model_checkpoint_path,
                                  model=model,
                                  optimizer=optimizer,
                                  scheduler=scheduler)

    # load dataset & dataloader
    train_dataset = SpeechDataset(mem_mode=args.mem_mode,
                                  meta_dir=args.prepro_meta_train,
                                  dataset_name=args.dataset_name,
                                  mel_stat_path=args.mel_stat_path,
                                  max_frame_length=args.max_frame_length)
    eval_dataset = SpeechDataset(mem_mode=args.mem_mode,
                                 meta_dir=args.prepro_meta_eval,
                                 dataset_name=args.dataset_name,
                                 mel_stat_path=args.mel_stat_path,
                                 max_frame_length=args.max_frame_length)

    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=args.train_batch_size,
                                   shuffle=True,
                                   drop_last=True,
                                   pin_memory=True,
                                   num_workers=args.n_workers)
    eval_data_loader = DataLoader(dataset=eval_dataset,
                                  batch_size=args.train_batch_size,
                                  shuffle=False,
                                  pin_memory=True,
                                  drop_last=True)

    # tensorboard
    writer = Writer(args.model_log_path) if args.log_tensorboard else None

    # train the model!
    train(train_data_loader, eval_data_loader, model, recon_loss, vocoder,
          mel_stat, optimizer, scheduler, global_step, writer, DEVICE)
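WarmupScheduler is a repo-specific class. A minimal sketch consistent with the constructor arguments above, assuming a linear ramp from initial_lr to max_lr over the warmup period followed by MultiStepLR-style decay at the milestones; the actual implementation may differ:

from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler

class WarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, initial_lr, max_lr,
                 milestones, gamma, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.initial_lr = initial_lr
        self.max_lr = max_lr
        self.milestones = sorted(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_epochs:
            # Linear warmup from initial_lr to max_lr.
            scale = step / max(1, self.warmup_epochs)
            lr = self.initial_lr + (self.max_lr - self.initial_lr) * scale
        else:
            # Step decay: multiply by gamma at each milestone passed.
            lr = self.max_lr * self.gamma ** bisect_right(self.milestones, step)
        return [lr for _ in self.optimizer.param_groups]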
Example #3
 def validate(self):
     self.model.eval()
     dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                             FRAME_STRIDE, TEST_SIZE, self.device)
     dataset.init_dataset(test_mode=True)
     data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
     res = []
     for i, data in enumerate(data_loader):
         if i == MAX_VALID:
             break
         x, y, cond = data
         pred_y = self.model(x, cond)
         loss = F.cross_entropy(pred_y.squeeze(), y.squeeze())
         res.append(loss.item())
     self.logger.info('valid loss: ' + str(sum(res) / len(res)))
Example #4
 def generate(self):
     self.model.eval()
     dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                             FRAME_STRIDE, TEST_SIZE, self.device)
     dataset.init_dataset(test_mode=True)
     data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
     for i, data in enumerate(data_loader):
         if i == MAX_GENERATE:
             break
         _, _, cond = data
         res = self.model.generate(cond, MAX_GENERATE_LENGTH)
         res = dequantize_signal(res, N_CLASS)
         for j in range(res.shape[0]):
             # Note: librosa.output.write_wav was removed in librosa 0.8;
             # on newer versions use soundfile.write instead.
             librosa.output.write_wav(
                 './samples/sample%d.wav' % (self.sample_count), res[j],
                 SAMPLE_RATE)
             self.sample_count += 1
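dequantize_signal is not shown above. A plausible sketch, assuming the model outputs integer class indices from uniform quantization of [-1, 1] into N_CLASS bins; the original may use mu-law expansion instead:

import numpy as np

def dequantize_signal(res, n_class):
    # Map integer bins [0, n_class - 1] back to floats in [-1.0, 1.0].
    res = res.detach().cpu().numpy().astype(np.float32)
    return res / (n_class - 1) * 2.0 - 1.0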
Example #5
def get_dataloader(n_jobs, noisy_list, clean_list, batch_size, shuffle=False):
    def collate_fn(samples):
        niy_samples = [s[0] for s in samples]
        cln_samples = [s[1] for s in samples]
        lengths = torch.LongTensor([len(s[0]) for s in samples])

        niy_samples = torch.nn.utils.rnn.pad_sequence(
            niy_samples, batch_first=True)
        cln_samples = torch.nn.utils.rnn.pad_sequence(
            cln_samples, batch_first=True)
        return lengths, niy_samples.transpose(-1, -2).contiguous(), cln_samples.transpose(-1, -2).contiguous()

    dataloader = torch.utils.data.DataLoader(SpeechDataset(
        noisy_list, clean_list), batch_size, collate_fn=collate_fn, num_workers=n_jobs, shuffle=shuffle)

    return dataloader
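A usage sketch for the helper above; the file lists are hypothetical, and SpeechDataset is assumed to yield (noisy, clean) pairs of waveform tensors:

noisy_list = ['noisy/0001.wav', 'noisy/0002.wav']  # hypothetical paths
clean_list = ['clean/0001.wav', 'clean/0002.wav']

loader = get_dataloader(n_jobs=2, noisy_list=noisy_list,
                        clean_list=clean_list, batch_size=2, shuffle=True)
for lengths, niy_batch, cln_batch in loader:
    # Batches arrive zero-padded, with the last two axes swapped
    # by the transpose in collate_fn.
    print(lengths, niy_batch.shape, cln_batch.shape)
    break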
Example #6
def train_model(cfg):
    tensorboard_path = Path(utils.to_absolute_path("tensorboard")) / cfg.checkpoint_dir
    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    writer = SummaryWriter(tensorboard_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    decoder = Decoder(**cfg.model.decoder)
    encoder.to(device)
    decoder.to(device)

    optimizer = optim.Adam(
        chain(encoder.parameters(), decoder.parameters()),
        lr=cfg.training.optimizer.lr)
    [encoder, decoder], optimizer = amp.initialize([encoder, decoder], optimizer, opt_level="O1")
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    if cfg.resume:
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    dataset = SpeechDataset(
        root=root_path,
        hop_length=cfg.preprocessing.hop_length,
        sr=cfg.preprocessing.sr,
        sample_frames=cfg.training.sample_frames)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=True)

    n_epochs = cfg.training.n_steps // len(dataloader) + 1
    start_epoch = global_step // len(dataloader) + 1

    for epoch in range(start_epoch, n_epochs + 1):
        average_recon_loss = average_vq_loss = average_perplexity = 0

        for i, (audio, mels, speakers) in enumerate(tqdm(dataloader), 1):
            audio, mels, speakers = audio.to(device), mels.to(device), speakers.to(device)

            optimizer.zero_grad()

            z, vq_loss, perplexity = encoder(mels)
            output = decoder(audio[:, :-1], z, speakers)
            recon_loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            loss = recon_loss + vq_loss

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
            optimizer.step()
            scheduler.step()

            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

            global_step += 1

            if global_step % cfg.training.checkpoint_interval == 0:
                save_checkpoint(
                    encoder, decoder, optimizer, amp,
                    scheduler, global_step, checkpoint_dir)

        writer.add_scalar("recon_loss/train", average_recon_loss, global_step)
        writer.add_scalar("vq_loss/train", average_vq_loss, global_step)
        writer.add_scalar("average_perplexity", average_perplexity, global_step)

        print("epoch:{}, recon loss:{:.2E}, vq loss:{:.2E}, perpexlity:{:.3f}"
              .format(epoch, average_recon_loss, average_vq_loss, average_perplexity))
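save_checkpoint is not shown, but the resume branch above reads the keys encoder, decoder, optimizer, amp, scheduler, and step, so a matching sketch would be the following (the filename pattern is an assumption):

def save_checkpoint(encoder, decoder, optimizer, amp, scheduler,
                    step, checkpoint_dir):
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # Keys mirror the ones read back in the resume branch.
    torch.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
        "amp": amp.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
    }, checkpoint_dir / "model.ckpt-{}.pt".format(step))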
Example #7
    args = parse_args()
    print('=' * 20)
    print('Input arguments:\n%s' % (args))

    # Validate arguments.
    if args.mode == 'test' and (not args.reload_model
                                or args.model_path is None):
        raise ValueError("Input Argument Error: Test mode specified but reload_model is %s and model_path is %s." \
                        % (args.reload_model, args.model_path))

    if args.reload_model and args.model_path is None:
        raise ValueError("Input Argument Error: reload_model is set but model_path is %s." \
                        % (args.model_path))

    # Instantiate speech dataset.
    speechTrainDataset = SpeechDataset(FRAME_CONTEXT_RANGE, mode='train')
    speechTestDataset = SpeechDataset(
        FRAME_CONTEXT_RANGE,
        mode='test',
    )
    speechValDataset = SpeechDataset(FRAME_CONTEXT_RANGE, mode='dev')
    train_loader = DataLoader(speechTrainDataset,
                              batch_size=args.train_batch_size,
                              shuffle=True,
                              num_workers=8)
    test_loader = DataLoader(speechTestDataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=8)
    val_loader = DataLoader(speechValDataset,
                            batch_size=args.train_batch_size,
Example #8
def train_model(resume):

    with open(Path("./cfg/cfg.json").absolute()) as file:
        para = json.load(file)
    tensorboard_path = Path("./tensorboard/writer").absolute()
    checkpoint_dir = Path("./checkpoint").absolute()
    writer = SummaryWriter(tensorboard_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(in_channels=para['encoder']['in_channels'], channels=para['encoder']['channels'],
                      n_embeddings=para['encoder']['n_embeddings'], embedding_dim=para['encoder']['embedding_dim'], jitter=para['encoder']['jitter'])
    decoder = Decoder(in_channels=para['decoder']['in_channels'], conditioning_channels=para['decoder']['conditioning_channels'],
                      n_speakers = para['decoder']['n_speakers'], speaker_embedding_dim=para['decoder']['speaker_embedding_dim'],
                      mu_embedding_dim=para['decoder']['mu_embedding_dim'], rnn_channels=para['decoder']['rnn_channels'], fc_channels=para['decoder']['fc_channels'],
                      bits=para['decoder']['bits'], hop_length=para['decoder']['hop_length'])


    encoder.to(device)
    decoder.to(device)



    if resume:

        resume_path = Path("./checkpoint/model.pt").absolute()
        print("Resume checkpoint from: {}:".format(str(resume_path)))
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        print(checkpoint.keys())
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer = optim.Adam(
            chain(encoder.parameters(), decoder.parameters()),
            lr=1e-5)

        # [encoder, decoder], optimizer = amp.initialize([encoder, decoder], optimizer, opt_level="O1")
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[300000, 400000],
            gamma=0.5)
        optimizer.load_state_dict(checkpoint["optimizer"])
        #amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]


    else:
        global_step = 0
        optimizer = optim.Adam(
            chain(encoder.parameters(), decoder.parameters()),
            lr=1e-5)

        # [encoder, decoder], optimizer = amp.initialize([encoder, decoder], optimizer, opt_level="O1")
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[300000, 400000],
            gamma=0.5)


    sdataset = SpeechDataset(
        root='./preprocessed_file/train',
        hop_length=para['preprocess']['hop_length'],
        sr=para['preprocess']['sr'],
        sample_frames=para['preprocess']['sample_frames'])
    print(len(sdataset))
    dataloader = DataLoader(
        dataset=sdataset,
        batch_size=16,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True)

    print(len(dataloader))
    n_epochs = 1
#    start_epoch = global_step // len(dataloader) + 1

    for epoch in range(global_step, global_step+n_epochs):
        average_recon_loss = average_vq_loss = average_perplexity = 0

        for i, (audio, mels, speakers) in enumerate(tqdm(dataloader), 1):
            # Inputs must live on the same device as the model.
            audio, mels, speakers = audio.to(device), mels.to(device), speakers.to(device)
            optimizer.zero_grad()
            z, vq_loss, perplexity = encoder(mels)
            output = decoder(audio[:, :-1], z, speakers)
            recon_loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            loss = recon_loss + vq_loss

            loss.backward()

            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #    scaled_loss.backward()

            # amp is never initialized here (the amp calls above are commented
            # out), so clip the model parameters directly.
            torch.nn.utils.clip_grad_norm_(
                chain(encoder.parameters(), decoder.parameters()), 1)
            optimizer.step()
            scheduler.step()

            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

            global_step += 1


        save_checkpoint(
                encoder, decoder, optimizer, amp,
                scheduler, global_step, checkpoint_dir)

        writer.add_scalar("recon_loss/train", average_recon_loss, global_step)
        writer.add_scalar("vq_loss/train", average_vq_loss, global_step)
        writer.add_scalar("average_perplexity", average_perplexity, global_step)

        print("epoch:{}, recon loss:{:.2E}, vq loss:{:.2E}, perpexlity:{:.3f}"
              .format(epoch, average_recon_loss, average_vq_loss, average_perplexity))
Example #9
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("-conf", type=str)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    debug = args.debug
    config = configparser.ConfigParser()
    config.read(args.conf)

    log_path = config["log"]["log_path"]
    log_step = int(config["log"]["log_step"])
    log_dir = os.path.dirname(log_path)
    os.makedirs(log_dir, exist_ok=True)

    save_prefix = config["save"]["save_prefix"]
    save_format = save_prefix + ".network.epoch{}"
    optimizer_save_format = save_prefix + ".optimizer.epoch{}"
    save_step = int(config["save"]["save_step"])
    save_dir = os.path.dirname(save_prefix)
    os.makedirs(save_dir, exist_ok=True)

    num_epochs = int(config["train"]["num_epochs"])
    batch_size = int(config["train"]["batch_size"])
    decay_start_epoch = int(config["train"]["decay_start_epoch"])
    decay_rate = float(config["train"]["decay_rate"])
    vocab_size = int(config["vocab"]["vocab_size"])
    ls_prob = float(config["train"]["ls_prob"])
    distill_weight = float(config["distill"]["distill_weight"])

    if debug:
        logging.basicConfig(format="%(asctime)s %(message)s",
                            level=logging.INFO)  # to stdout
    else:
        logging.basicConfig(filename=log_path,
                            format="%(asctime)s %(message)s",
                            level=logging.DEBUG)

    model = AttnModel(args.conf)
    model.apply(init_weight)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-5)

    dataset = SpeechDataset(args.conf)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=2,
                            pin_memory=True)
    num_steps = len(dataloader)

    for epoch in range(num_epochs):
        loss_sum = 0

        for step, data in enumerate(dataloader):
            loss_step = train_step(model, optimizer, data, vocab_size, ls_prob,
                                   distill_weight)
            loss_sum += loss_step

            if (step + 1) % log_step == 0:
                logging.info(
                    "epoch = {:>2} step = {:>6} / {:>6} loss = {:.3f}".format(
                        epoch + 1, step + 1, num_steps, loss_sum / log_step))
                loss_sum = 0

        if epoch == 0 or (epoch + 1) % save_step == 0:
            save_path = save_format.format(epoch + 1)
            torch.save(model.state_dict(), save_path)
            optimizer_save_path = optimizer_save_format.format(epoch + 1)
            torch.save(optimizer.state_dict(), optimizer_save_path)
            logging.info("model saved to: {}".format(save_path))
            logging.info("optimizer saved to: {}".format(optimizer_save_path))
        update_epoch(model, epoch + 1)
        decay_lr(optimizer, epoch + 1, decay_start_epoch, decay_rate)
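update_epoch and decay_lr are repo helpers. A minimal sketch of decay_lr consistent with the decay_start_epoch and decay_rate config values, assuming exponential decay once the start epoch is reached:

def decay_lr(optimizer, epoch, decay_start_epoch, decay_rate):
    # Multiply every parameter group's lr by decay_rate per epoch
    # after decay_start_epoch.
    if epoch >= decay_start_epoch:
        for group in optimizer.param_groups:
            group["lr"] *= decay_rate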
Example #10
    # Parse args.
    args = parse_args()
    print('=' * 20)
    print('Input arguments:\n%s' % (args))

    # Validate arguments.
    if args.mode == 'test' and args.model_path is None and not args.model_ensemble:
        raise ValueError(
            "Input Argument Error: Test mode specified but model_path is %s." %
            (args.model_path))

    # Check for CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create datasets and dataloaders.
    speechTrainDataset = SpeechDataset(mode='train')
    speechValDataset = SpeechDataset(mode='dev')
    speechTestDataset = SpeechDataset(mode='test')

    train_loader = DataLoader(speechTrainDataset,
                              batch_size=args.train_batch_size,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=SpeechCollateFn)
    val_loader = DataLoader(speechValDataset,
                            batch_size=args.train_batch_size,
                            shuffle=False,
                            num_workers=4,
                            collate_fn=SpeechCollateFn)
    test_loader = DataLoader(speechTestDataset,
                             batch_size=args.test_batch_size,
Example #11
def main(args):

    # Set seeds for determinism
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    labels = LABELS

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.rnn_hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=args.bidirectional)

    # Data setup
    evaluation_decoder = GreedyDecoder(
        model.labels)  # Decoder used for validation

    train_df = pd.read_csv(args.train_path)
    train_dataset = SpeechDataset(args=args, df=train_df)

    test_df = pd.read_csv(args.test_path)
    test_dataset = SpeechDataset(args=args, df=test_df)

    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=args.num_workers,
                                   batch_size=args.batch_size)

    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=args.num_workers,
                                  batch_size=args.batch_size)

    model = model.to(args.device)
    parameters = model.parameters()

    optimizer = torch.optim.AdamW(parameters,
                                  lr=args.learning_rate,
                                  betas=args.betas,
                                  eps=args.eps,
                                  weight_decay=args.weight_decay)

    criterion = CTCLoss()

    best_score = 99999

    for epoch in range(args.epochs):
        train_loss = train_fn(args, train_loader, model, optimizer, criterion)
        wer, cer, output_data = test_fn(args=args,
                                        model=model,
                                        decoder=evaluation_decoder,
                                        target_decoder=evaluation_decoder)

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if (wer + cer) / 2 < best_score:
            print("**** Model Improved !!!! Saving Model")
            torch.save(model.state_dict(), f"best_model.bin")
            best_score = (wer + cer) / 2
Example #12
dataset = {}
loader = {}
x_paths = {}  # assigned by key below, so initialize here
t_paths = {}

for type_ in types:
    print('Loading %s dataset ... ' % (type_), end='')
    x_paths[type_] = glob(join(data_dir, 'x_' + type_, '*.bin'))
    t_paths[type_] = glob(join(data_dir, 't_' + type_, '*.bin'))

    x_dim = config.get_feature_config().get_linguistic_dim(type_)
    t_dim = config.get_feature_config().get_parm_dim(type_)
    max_len = get_max_length(x_paths[type_], x_dim)
    pad_value = config.get_feature_config().pad_value

    dataset[type_] = SpeechDataset(x_paths[type_],
                                   t_paths[type_],
                                   x_dim=x_dim,
                                   t_dim=t_dim,
                                   max_len=max_len,
                                   pad_value=pad_value)
    batch_size = config.get_train_config().batch_size
    loader[type_] = DataLoader(dataset[type_],
                               batch_size=batch_size,
                               shuffle=True)
    print('done!')
    print('\tDataset Size\t%d' % (len(x_paths[type_])))
    print('\tInput Dim\t%d' % (x_dim))
    print('\tTarget Dim\t%d\n' % (t_dim))

# ---------- Make Model ----------

print('Creating models ...')
model = {}
Example #13
from torch.utils.data import DataLoader
from dataset import SpeechDataset, collate_fn_train, collate_fn_eval

"""
dataset = SpeechDataset(config_path="params.conf")

dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False, collate_fn=collate_fn_train,
                        num_workers=2, pin_memory=True)

for data in dataloader:
    print(data)
    print(data["x_batch"].shape, data["seq_lens"].shape, data["labels"].shape, data["lab_lens"].shape)

    break

"""
dataset_eval = SpeechDataset(config_path="params.conf", no_label=True)

dataloader = DataLoader(dataset=dataset_eval, batch_size=1, shuffle=False, collate_fn=collate_fn_eval, num_workers=2)

for data in dataloader:
    print(data)
    print(data["x_batch"].shape, data["seq_lens"].shape)

    break
Example #14
 def create_dataset(self):
     dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                             FRAME_STRIDE, TEST_SIZE, self.device)
     dataset.create_dataset(MAX_FILES, FILE_PREFIX)
Example #15
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, default="params.conf")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    config_path = args.config_path
    debug = args.debug

    # load configs
    config = configparser.ConfigParser()
    config.read(config_path)
    log_dir = config["log"]["log_path"]
    log_step = int(config["log"]["log_step"])
    save_dir = config["save"]["save_path"]
    save_step = int(config["save"]["save_step"])
    num_epochs = int(config["train"]["num_epochs"])
    batch_size = int(config["train"]["batch_size"])
    vocab_size = int(config["vocab"]["vocab_size"])

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)

    dt_now = datetime.now()
    dt_str = dt_now.strftime("%m%d%H%M%S")

    if debug:
        logging.basicConfig(format="%(asctime)s %(message)s",
                            level=logging.INFO)  # to stdout
    else:
        log_path = os.path.join(log_dir, "train_attn_{}.log".format(dt_str))

        logging.basicConfig(filename=log_path,
                            format="%(asctime)s %(message)s",
                            level=logging.DEBUG)

    logging.info("process id: {:d} is allocated".format(os.getpid()))

    model = AttnModel(config_path)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), weight_decay=1e-5)

    dataset = SpeechDataset(config_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn_train,
                            num_workers=2,
                            pin_memory=True)

    num_steps = len(dataset)

    for epoch in range(num_epochs):
        loss_sum = 0

        for step, data in enumerate(dataloader):
            loss_step = train_step(model, optimizer, data, vocab_size)
            loss_sum += loss_step

            if (step + 1) % log_step == 0:
                logging.info(
                    "epoch = {:>2} step = {:>6} / {:>6} loss = {:.3f}".format(
                        epoch + 1, (step + 1) * batch_size, num_steps,
                        loss_sum / log_step))
                loss_sum = 0

        if epoch == 0 or (epoch + 1) % save_step == 0:
            save_path = os.path.join(
                save_dir, "attention{}.network.epoch{}".format(dt_str, epoch + 1))
            torch.save(model.state_dict(), save_path)
            optimizer_save_path = os.path.join(
                save_dir, "attention{}.optimizer.epoch{}".format(dt_str, epoch + 1))
            torch.save(optimizer.state_dict(), optimizer_save_path)
Example #16
from os.path import join
from glob import glob

from dataset import SpeechDataset

work_dir = './'
data_dir = join(work_dir, 'data')

x_paths = glob(join(data_dir, 'x_acoustic', '*.bin'))
t_paths = glob(join(data_dir, 't_acoustic', '*.bin'))

dataset = SpeechDataset(x_paths,
                        t_paths,
                        x_dim=610,
                        t_dim=127,
                        max_len=5000,
                        pad_value=9999)

x, t, x_l, t_l = dataset[0]

x = x[:x_l]
t = t[:t_l]
print(t)
print(t.shape)
print(type(t))
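Frames past x_l and t_l hold pad_value (9999 here), so any loss over batched outputs has to be masked by the returned lengths. A minimal sketch, assuming batches shaped (batch, max_len, dim) and a LongTensor of lengths:

import torch

def masked_mse(pred, target, lengths):
    # Build a (batch, max_len, 1) mask that is True on real frames only.
    max_len = pred.size(1)
    mask = (torch.arange(max_len, device=lengths.device)[None, :]
            < lengths[:, None]).unsqueeze(-1)
    err = (pred - target) ** 2 * mask
    return err.sum() / (mask.sum() * pred.size(-1))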