Example #1
def test_sortagrad(swap_io):
    dummy_json = make_dummy_json(128, [1, 700], [1, 700])
    if swap_io:
        batchset = make_batchset(
            dummy_json,
            16,
            2**10,
            2**10,
            batch_sort_key="input",
            shortest_first=True,
            swap_io=True,
        )
        key = "output"
    else:
        batchset = make_batchset(dummy_json,
                                 16,
                                 2**10,
                                 2**10,
                                 shortest_first=True)
        key = "input"
    prev_start_ilen = batchset[0][0][1][key][0]["shape"][0]
    for batch in batchset:
        cur_start_ilen = batch[0][1][key][0]["shape"][0]
        assert cur_start_ilen >= prev_start_ilen
        prev_ilen = cur_start_ilen
        for sample in batch:
            cur_ilen = sample[1][key][0]["shape"][0]
            assert cur_ilen <= prev_ilen
            prev_ilen = cur_ilen
        prev_start_ilen = cur_start_ilen
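The assertions above rely on the nested structure returned by make_batchset: a list of minibatches, each a list of (uttid, info) pairs whose info["input"][0]["shape"][0] gives the utterance's input length. The following is a minimal sketch of that structure built from a hand-made dummy json; it assumes an ESPnet1 installation (recent versions expose make_batchset under espnet.utils.training.batchfy), and the utterance names and shapes are purely illustrative.

import pprint

from espnet.utils.training.batchfy import make_batchset

# three fake utterances with (input_length, output_length) = (120, 12), (40, 5), (300, 25)
dummy_json = {
    "utt%d" % i: {
        "input": [{"shape": [ilen, 40]}],
        "output": [{"shape": [olen, 30]}],
    }
    for i, (ilen, olen) in enumerate([(120, 12), (40, 5), (300, 25)])
}

# positional arguments as in the tests above: batch_size=2, max_length_in=800, max_length_out=150
batchset = make_batchset(dummy_json, 2, 800, 150, shortest_first=True)
for batch in batchset:
    # each minibatch is a list of (uttid, info) pairs; with shortest_first=True
    # the first batch starts from the shortest utterance
    pprint.pprint([(uttid, info["input"][0]["shape"][0]) for uttid, info in batch])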
Example #2
def get_iter(args):

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0,
                          perturb_sampling=args.perturb_sampling)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}, train=True  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}, train=False  # Switch the mode of preprocessing
    )
    # hack to keep the DataLoader batch_size argument at 1:
    # the actual batch size is already encoded in each dataset element,
    # and the default collate function would convert numpy arrays to pytorch tensors,
    # so we use a pass-through collate function that unwraps the single-element list instead
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    converter = CustomConverter(subsampling_factor=1, dtype=dtype)
    train_iter = {'main': ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1, num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad, collate_fn=lambda x: x[0])}
    valid_iter = {'main': ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes)}
    
    return train_iter, valid_iter
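get_iter only reads a fixed set of attributes from args, so it can also be driven outside the usual argparse setup. Below is a hedged sketch with a plain Namespace; the attribute names mirror the code above, while the values and json paths are placeholders, and the helpers used by get_iter (make_batchset, LoadInputsAndTargets, CustomConverter, ChainerDataLoader, TransformDataset) are assumed to be imported in the surrounding module.

from argparse import Namespace

args = Namespace(
    train_json="dump/train/data.json",   # placeholder paths
    valid_json="dump/dev/data.json",
    sortagrad=0,
    batch_size=32,
    maxlen_in=800,
    maxlen_out=150,
    minibatches=0,
    ngpu=1,
    batch_count="auto",
    batch_bins=0,
    batch_frames_in=0,
    batch_frames_out=0,
    batch_frames_inout=0,
    perturb_sampling=False,
    preprocess_conf=None,
    train_dtype="float32",
    n_iter_processes=0,
)

train_iter, valid_iter = get_iter(args)
print(list(train_iter.keys()))  # ['main'] -> a ChainerDataLoader over the training batches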
Example #3
def test_model_trainable_and_decodable(module, num_encs, model_dict):
    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)

    # test trainable
    m = importlib.import_module(module)
    model = m.E2E([40 for _ in range(num_encs)], 5, args)
    loss = model(*batch)
    loss.backward()  # trainable

    # test attention plot
    dummy_json = make_dummy_json(num_encs, [10, 20], [10, 20], idim=40, odim=5, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=40, odim=5, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0], tmpdir, None, None, None)
    for i in range(num_encs):
        # att-encoder
        att_w = plot.get_attention_weight(0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, '{}/att{}.png'.format(tmpdir, i))
    # han
    att_w = plot.get_attention_weight(0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, '{}/han.png'.format(tmpdir), han_mode=True)

    # test decodable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(10, 40) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)  # decodable
        if "pytorch" in module:
            batch_in_data = [[np.random.randn(10, 40), np.random.randn(5, 40)] for _ in range(num_encs)]
            model.recognize_batch(batch_in_data, args, args.char_list)  # batch decodable
Example #4
def test_sortagrad_trainable(module):
    args = make_arg(sortagrad=1)
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=6, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batchset = make_batchset(dummy_json,
                             2,
                             2**10,
                             2**10,
                             shortest_first=True,
                             mt=True,
                             iaxis=1,
                             oaxis=0)
    model = m.E2E(6, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=6, odim=5))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)

def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 6
    odim = 5
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems, shortest_first=True, mt=True, iaxis=1, oaxis=0)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['output'][1]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        assert n <= batch_elems  # total element count of the batch stays within batch_bins

    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)
Example #6
    def __call__(self, trainer):
        """Calls the enabler on the given iterator

        :param trainer: The iterator
        """
        args = self.args
        use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
        # make a new batch set for perturb_sampling mode
        # the following mirrors espnet.asr.pytorch_backend.asr:train
        train = make_batchset(self.train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0,
                              perturb_sampling=args.perturb_sampling,
                              rank=args.rank,
                              world_size=args.world_size)
        dataset = TransformDataset(
            train, lambda data: self.converter([self.load_tr(data)]))
        self.train_iter['main'].perturb_sampling_shuffle(dataset)
        logging.warning("Doing Perturb-Sampling shuffling")

def test_sortagrad_trainable_with_batch_frames(module, num_encs):
    args = make_arg(num_encs=num_encs, sortagrad=1)
    idim = 2
    odim = 2
    dummy_json = make_dummy_json(4, [2, 3], [2, 3],
                                 idim=idim,
                                 odim=odim,
                                 num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m

    batch_frames_in = 50
    batch_frames_out = 50
    batchset = make_batchset(
        dummy_json,
        batch_frames_in=batch_frames_in,
        batch_frames_out=batch_frames_out,
        shortest_first=True,
    )
    for batch in batchset:
        i = 0
        o = 0
        for uttid, info in batch:
            i += int(info["input"][0]["shape"][0])  # based on the first input
            o += int(info["output"][0]["shape"][0])
        assert i <= batch_frames_in
        assert o <= batch_frames_out

    model = m.E2E([2 for _ in range(num_encs)], 2, args)
    for batch in batchset:
        loss = model(
            *convert_batch(batch, module, idim=2, odim=2, num_inputs=num_encs))
        loss.backward()  # trainable
    with torch.no_grad():
        in_data = [np.random.randn(100, 2) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)

def test_sortagrad_trainable_with_batch_bins(module, num_encs):
    args = make_arg(num_encs=num_encs, sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json(4, [10, 20], [10, 20],
                                 idim=idim,
                                 odim=odim,
                                 num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json,
                             batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(
                info['input'][0]['shape'][0])  # based on the first input
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        assert n <= batch_elems  # total element count of the batch stays within batch_bins

    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    for batch in batchset:
        loss = model(*convert_batch(
            batch, module, idim=20, odim=5, num_inputs=num_encs))
        loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(100, 20) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)
Example #9
def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json_st(4, [10, 20], [10, 20], [10, 20],
                                    idim=idim,
                                    odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_st as m
    else:
        raise NotImplementedError
    batch_elems = 2000
    batchset = make_batchset(dummy_json,
                             batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        assert n <= batch_elems  # total element count of the batch stays within batch_bins

    model = m.E2E(20, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(100, 20)
        model.translate(in_data, args, args.char_list)
Example #10
def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json(8, [100, 200], [100, 200],
                                 idim=idim,
                                 odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_elems = 20000
    batchset = make_batchset(dummy_json,
                             batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        assert n <= batch_elems  # total element count of the batch stays within batch_bins

    model = m.E2E(20, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=20, odim=5))[0]
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(100, 20)
        model.recognize(in_data, args, args.char_list)
Example #11
def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 10
    odim = 5
    dummy_json = make_dummy_json(2, [3, 5], [3, 5], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json,
                             batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info["input"][0]["shape"][0])
            olen = int(info["output"][0]["shape"][0])
            n += ilen * idim + olen * odim
        assert n <= batch_elems  # total element count of the batch stays within batch_bins

    model = m.E2E(idim, odim, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=idim, odim=odim))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(10, idim)
        model.recognize(in_data, args, args.char_list)
Example #12
def test_sortagrad_trainable_with_batch_frames(module):
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json(4, [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_frames_in = 50
    batch_frames_out = 50
    batchset = make_batchset(dummy_json,
                             batch_frames_in=batch_frames_in,
                             batch_frames_out=batch_frames_out,
                             shortest_first=True)
    for batch in batchset:
        i = 0
        o = 0
        for uttid, info in batch:
            i += int(info['input'][0]['shape'][0])
            o += int(info['output'][0]['shape'][0])
        assert i <= batch_frames_in
        assert o <= batch_frames_out

    model = m.E2E(20, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(100, 20)
        model.recognize(in_data, args, args.char_list)
Example #13
def test_sortagrad_trainable_with_batch_frames(module):
    args = make_arg(sortagrad=1)
    idim = 6
    odim = 5
    dummy_json = make_dummy_json_mt(8, [100, 200], [100, 200],
                                    idim=idim,
                                    odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_frames_in = 200
    batch_frames_out = 200
    batchset = make_batchset(dummy_json,
                             batch_frames_in=batch_frames_in,
                             batch_frames_out=batch_frames_out,
                             shortest_first=True,
                             mt=True)
    for batch in batchset:
        i = 0
        o = 0
        for uttid, info in batch:
            i += int(info['output'][1]['shape'][0])
            o += int(info['output'][0]['shape'][0])
        assert i <= batch_frames_in
        assert o <= batch_frames_out

    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)
Example #14
def test_make_batchset(swap_io):
    dummy_json = make_dummy_json(128, [128, 512], [16, 128])
    # check w/o adaptive batch size
    batchset = make_batchset(dummy_json, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=1, swap_io=swap_io)
    assert sum([len(batch) >= 1 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])
    batchset = make_batchset(dummy_json, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])

    # check w/ adaptive batch size
    batchset = make_batchset(dummy_json, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])
    batchset = make_batchset(dummy_json, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)
Example #15
def test_context_residual(module):
    args = make_arg(context_residual=True)
    dummy_json = make_dummy_json(8, [1, 100], [1, 100], idim=20, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        raise NotImplementedError
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(20, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=20, odim=5))[0]
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(50, 20)
        model.recognize(in_data, args, args.char_list)
Example #16
def test_sortagrad_trainable(module):
    args = make_arg(sortagrad=1)
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=6, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True, mt=True)
    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)
Example #17
def test_sortagrad_trainable(module, num_encs):
    args = make_arg(num_encs=num_encs, sortagrad=1)
    dummy_json = make_dummy_json(6, [10, 20], [10, 20], idim=20, odim=5, num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    num_utts = 0
    for batch in batchset:
        num_utts += len(batch)
        loss = model(*convert_batch(batch, module, idim=20, odim=5, num_inputs=num_encs))
        loss.backward()  # trainable
    assert num_utts == 6
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(50, 20) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)
Example #18
def test_gradient_noise_injection(module, num_encs):
    args = make_arg(num_encs=num_encs, grad_noise=True)
    args_org = make_arg(num_encs=num_encs)
    dummy_json = make_dummy_json(num_encs, [10, 20], [10, 20], idim=20, odim=5, num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    model_org = m.E2E([20 for _ in range(num_encs)], 5, args_org)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5, num_inputs=num_encs))
        loss_org = model_org(*convert_batch(batch, module, idim=20, odim=5, num_inputs=num_encs))
        loss.backward()
        grad = [param.grad for param in model.parameters()][10]
        loss_org.backward()
        grad_org = [param.grad for param in model_org.parameters()][10]
        assert grad[0] != grad_org[0]
Example #19
def test_gradient_noise_injection(module):
    args = make_arg(grad_noise=True)
    args_org = make_arg()
    dummy_json = make_dummy_json(2, [3, 4], [3, 4], idim=10, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(10, 5, args)
    model_org = m.E2E(10, 5, args_org)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=10, odim=5))
        loss_org = model_org(*convert_batch(batch, module, idim=10, odim=5))
        loss.backward()
        grad = [param.grad for param in model.parameters()][10]
        loss_org.backward()
        grad_org = [param.grad for param in model_org.parameters()][10]
        assert grad[0] != grad_org[0]
Example #20
def test_calculate_plot_attention_ctc(module, num_encs, model_dict):
    args = make_arg(num_encs=num_encs, **model_dict)
    m = importlib.import_module(module)
    model = m.E2E([2 for _ in range(num_encs)], 2, args)

    # test attention plot
    dummy_json = make_dummy_json(num_encs, [2, 3], [2, 3],
                                 idim=2,
                                 odim=2,
                                 num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0],
                               tmpdir, None, None, None)
    for i in range(num_encs):
        # att-encoder
        att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, "{}/att{}.png".format(tmpdir, i))
    # han
    att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w,
                                  "{}/han.png".format(tmpdir),
                                  han_mode=True)

    # test CTC plot
    ctc_probs = model.calculate_all_ctc_probs(*convert_batch(
        batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotCTCReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotCTCReport(model.calculate_all_ctc_probs, batchset[0], tmpdir,
                         None, None, None)
    if args.mtlalpha > 0:
        for i in range(num_encs):
            # ctc-encoder
            plot._plot_and_save_ctc(ctc_probs[i][0],
                                    "{}/ctc{}.png".format(tmpdir, i))
Example #21
def test_sortagrad_trainable(module):
    args = make_arg(sortagrad=1)
    idim = 10
    odim = 5
    dummy_json = make_dummy_json(2, [3, 5], [3, 5], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(idim, odim, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=idim, odim=odim))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(10, idim)
        model.recognize(in_data, args, args.char_list)
Example #22
def load_asr_data(json_path, dic):
    with open(json_path, "rb") as f:
        train_feature = json.load(f)["utts"]

    converter = CustomConverter(subsampling_factor=1)
    train = make_batchset(train_feature, batch_size=len(train_feature))
    name = [train[0][i][0] for i in range(len(train_feature))]
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    dataset = TransformDataset(train, lambda data: converter([load_tr(data)]))
    data1 = dataset[0][1]  # input lengths (ilens) from the converter output (xs_pad, ilens, ys_pad)
    data2 = dataset[0][2]  # padded target sequences (ys_pad)
    for i in tqdm(range(len(name))):
        ilen = data1[i]
        y = data2[i]
        dic[name[i]] = [ilen, y]

    return dic
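A hedged usage sketch of load_asr_data above (the json path is a placeholder): it fills the given dict with [input_length, target_sequence] pairs keyed by utterance name.

asr_data = load_asr_data("dump/train/data.json", {})
uttid = next(iter(asr_data))
ilen, y = asr_data[uttid]
print(uttid, int(ilen), y.shape)  # utterance id, number of input frames, target tensor shape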
Example #23
def load_asr_data(json_path, dic, TMHINT=None):
    with open(json_path, "rb") as f:
        train_feature = json.load(f)["utts"]

    converter = CustomConverter(subsampling_factor=1)
    train = make_batchset(train_feature, batch_size=len(train_feature))
    name = [train[0][i][0] for i in range(len(train_feature))]
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    dataset = TransformDataset(train, lambda data: converter([load_tr(data)]))
    data1 = dataset[0][1]
    data2 = dataset[0][2]

    for i in tqdm(range(len(name))):
        ilen = data1[i]
        y = data2[i]

        if TMHINT:
            speaker = ["M1", "M2", "M3", "F1", "F2", "F3"]
            tmp = name[i].split("_")
            if tmp[0] in speaker:
                if int(tmp[2]) >= 13:
                    name_key = int(speaker.index(
                        tmp[0])) * 200 + (int(tmp[2]) - 13) * 10 + int(tmp[3])
                    dic[str(name_key)] = [ilen, y]
            else:
                dic["_".join(tmp[1:])] = [ilen, y]

        else:
            name_key = name[i]
            dic[name_key] = [ilen, y]
    return dic
Example #24
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad) *
            100.0 / sum(p.numel() for p in model.parameters()),
        ))

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
    # hack to keep the DataLoader batch_size argument at 1:
    # the actual batch size is already encoded in each dataset element,
    # and the default collate function would convert numpy arrays to pytorch tensors,
    # so we use a pass-through collate function that unwraps the single-element list instead
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(["main/loss", "validation/main/loss"],
                              "epoch",
                              file_name="loss.png"))
    trainer.extend(
        extensions.PlotReport(["main/acc", "validation/main/acc"],
                              "epoch",
                              file_name="acc.png"))
    trainer.extend(
        extensions.PlotReport(["main/ppl", "validation/main/ppl"],
                              "epoch",
                              file_name="ppl.png"))
    trainer.extend(
        extensions.PlotReport(["main/bleu", "validation/main/bleu"],
                              "epoch",
                              file_name="bleu.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Example #25
def load_multilingual_data(root_path, datasets, args, languages):
    def collate(minibatch):
        out = []
        for b in minibatch:
            fbanks = []
            tokens = []
            language = None
            for _, info in b:
                fbanks.append(
                    torch.tensor(
                        kaldiio.load_mat(info["input"][0]["feat"].replace(
                            data_config[dataset]["prefix"], root_path))))
                tokens.append(
                    torch.tensor([
                        int(s) for s in info["output"][0]["tokenid"].split()
                    ]))
                if language is not None:
                    assert language == info['category']
                else:
                    language = info['category']
            ilens = torch.tensor([x.shape[0] for x in fbanks])
            out.append((
                pad_sequence(fbanks, batch_first=True, padding_value=0),
                ilens,
                pad_sequence(tokens, batch_first=True, padding_value=-1),
                language,
            ))
        return out[0] if len(out) == 1 else out

    idim = None
    odim_dict = {}
    mtl_train_json, mtl_dev_json, mtl_test_json = {}, {}, {}
    for idx, dataset in enumerate(datasets):
        language = dataset
        if language in low_resource_languages:
            template_key = "template100"
        else:
            template_key = "template150"
        data_config[dataset] = data_config[template_key].copy()
        for key in ["train", "val", "test", "token"]:
            data_config[dataset][key] = data_config[template_key][key].replace(
                "template", dataset)

        train_json = os.path.join(root_path, data_config[dataset]["train"])
        dev_json = (os.path.join(root_path, data_config[dataset]["val"])
                    if data_config[dataset]["val"] else
                    f"{root_path}/tmp_dev_set_{dataset}.json")
        test_json = os.path.join(root_path, data_config[dataset]["test"])
        train_json, dev_json, test_json = load_json(train_json, dev_json,
                                                    test_json)
        for key in train_json.keys():
            train_json[key]['category'] = language
        for key in dev_json.keys():
            dev_json[key]['category'] = language
        for key in test_json.keys():
            test_json[key]['category'] = language
        #print(train_json)
        _, info = next(iter(train_json.items()))
        if idim is not None:
            assert idim == info["input"][0]["shape"][1]
        else:
            idim = info["input"][0]["shape"][1]
        odim_dict[language] = info["output"][0]["shape"][1]

        # Break if not in specified languages
        if dataset not in languages:
            continue

        mtl_train_json.update(train_json)
        mtl_dev_json.update(dev_json)
        mtl_test_json.update(test_json)
        #print(len(mtl_train_json), len(train_json))
    train_json, dev_json, test_json = mtl_train_json, mtl_dev_json, mtl_test_json
    use_sortagrad = False  # args.sortagrad == -1 or args.sortagrad > 0
    # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150)
    if args.ngpu > 1 and not args.dist_train:
        min_batch_size = args.ngpu
    else:
        min_batch_size = 1
    if args.meta_train:
        min_batch_size = 2 * min_batch_size
    trainset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=min_batch_size,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # devset = make_batchset(dev_json, batch_size, max_length_in=800, max_length_out=150)
    devset = make_batchset(
        dev_json,
        args.batch_size if args.ngpu <= 1 else int(args.batch_size /
                                                   args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    testset = make_batchset(
        test_json,
        args.batch_size if args.ngpu <= 1 else int(args.batch_size /
                                                   args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    if args.dist_train and args.ngpu > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    elif args.meta_train:
        train_sampler = BalancedBatchSampler(trainset)
    else:
        train_sampler = None
    train_loader = DataLoader(
        trainset,
        batch_size=1 if not args.meta_train else len(languages),
        collate_fn=collate,
        num_workers=args.n_iter_processes,
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
    )
    dev_loader = DataLoader(
        devset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    test_loader = DataLoader(
        testset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    return (train_loader, dev_loader, test_loader), (idim, odim_dict)
Example #26
def load_data(root_path, dataset, args):
    def collate(minibatch):
        fbanks = []
        tokens = []
        for _, info in minibatch[0]:
            fbanks.append(
                torch.tensor(
                    kaldiio.load_mat(info["input"][0]["feat"].replace(
                        data_config[dataset]["prefix"], root_path))))
            tokens.append(
                torch.tensor(
                    [int(s) for s in info["output"][0]["tokenid"].split()]))
        ilens = torch.tensor([x.shape[0] for x in fbanks])
        return (
            pad_sequence(fbanks, batch_first=True, padding_value=0),
            ilens,
            pad_sequence(tokens, batch_first=True, padding_value=-1),
        )

    language = dataset
    if language in low_resource_languages:
        template_key = "template100"
    else:
        template_key = "template150"
    data_config[dataset] = data_config[template_key].copy()
    for key in ["train", "val", "test", "token"]:
        data_config[dataset][key] = data_config[template_key][key].replace(
            "template", dataset)
    train_json = os.path.join(root_path, data_config[dataset]["train"])
    dev_json = (os.path.join(root_path, data_config[dataset]["val"])
                if data_config[dataset]["val"] else
                f"{root_path}/tmp_dev_set_{dataset}.json")
    test_json = os.path.join(root_path, data_config[dataset]["test"])
    train_json, dev_json, test_json = load_json(train_json, dev_json,
                                                test_json)
    _, info = next(iter(train_json.items()))
    idim = info["input"][0]["shape"][1]
    odim = info["output"][0]["shape"][1]

    use_sortagrad = False  # args.sortagrad == -1 or args.sortagrad > 0
    # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150)
    trainset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if
        (args.ngpu > 1 and not args.dist_train) else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # devset = make_batchset(dev_json, batch_size, max_length_in=800, max_length_out=150)
    devset = make_batchset(
        dev_json,
        args.batch_size if args.ngpu <= 1 else int(args.batch_size /
                                                   args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    testset = make_batchset(
        test_json,
        args.batch_size if args.ngpu <= 1 else int(args.batch_size /
                                                   args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    if args.dist_train and args.ngpu > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    else:
        train_sampler = None
    train_loader = DataLoader(
        trainset,
        batch_size=1,
        collate_fn=collate,
        num_workers=args.n_iter_processes,
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
    )
    dev_loader = DataLoader(
        devset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    test_loader = DataLoader(
        testset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    return (train_loader, dev_loader, test_loader), (idim, odim)
Example #27
def train(args):
    """Train E2E-TTS model."""
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to" + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # freeze modules, if specified
    if args.freeze_mods:
        for mod, param in model.state_dict().items():
            if any(mod.startswith(key) for key in args.freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
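    # SortaGrad: present minibatches shortest-first for the first args.sortagrad
    # epochs (-1 means every epoch); the ShufflingEnabler extension registered
    # near the end of this function switches back to random shuffling afterwards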
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
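    # swap_io=True swaps the roles of the json "input"/"output" entries, because
    # a TTS model consumes text (stored as "output") and predicts acoustic
    # features (stored as "input"); batch_sort_key picks the side used for sorting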
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

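    # the converter pads the variable-length text / feature sequences of one
    # minibatch into dense tensors (with length information) for the model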
    converter = CustomConverter()
    # hack to make the batch size argument 1
    # the actual batch size is encoded in each element of the batch list
    train_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(train_batchset,
                                     lambda data: converter([load_tr(data)])),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(valid_batchset,
                                     lambda data: converter([load_cv(data)])),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss",
                                                  trigger=eval_interval),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
            reduction_factor = model.module.reduction_factor
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
            reduction_factor = model.reduction_factor
        if reduction_factor > 1:
            # fix the length to crop attention weight plot correctly
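            # (attention is computed at the reduced frame rate, one decoder step
            # per reduction_factor output frames, so the stored feature length is
            # shortened to match the attention matrix)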
            data = copy.deepcopy(data)
            for idx in range(len(data)):
                ilen = data[idx][1]["input"][0]["shape"][0]
                data[idx][1]["input"][0]["shape"][0] = ilen // reduction_factor
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
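A side note on the "batch_size=1" pattern that recurs in these training functions: make_batchset already groups utterances into variable-size minibatches, so every dataset item is a complete minibatch, the loader runs with batch_size=1, and collate_fn simply unwraps the singleton list. Below is a minimal, self-contained sketch of that pattern with dummy tensors; all names are illustrative and this is not ESPnet code.

import torch
from torch.utils.data import DataLoader, Dataset


class MinibatchDataset(Dataset):
    """Each item is already a full, pre-padded minibatch."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]


# three pre-built "minibatches" with different numbers of utterances
batches = [torch.zeros(8, 83), torch.zeros(5, 83), torch.zeros(12, 83)]
loader = DataLoader(
    MinibatchDataset(batches),
    batch_size=1,               # one pre-built minibatch per step
    collate_fn=lambda x: x[0],  # unwrap the singleton list
    shuffle=True,
)
for batch in loader:
    print(batch.shape)  # e.g. torch.Size([8, 83]), the original minibatch shape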
Exemple #28
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0)
        # hack to make the batch size argument 1
        # the actual batch size is encoded in each element of the batch list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train, load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater
        maxlen = max([len(train_subset) for train_subset in train_subsets])
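        # shorter subsets are padded by repeating their leading batches so every
        # per-GPU iterator yields the same number of batches per epoch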
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make the batch size argument 1
        # the actual batch size is encoded in each element of the batch list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train_subsets[gid], load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
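        # VaswaniRule applies the Transformer learning-rate schedule: the Adam
        # alpha initialized to 0 above is raised linearly over the warmup steps
        # and then decayed with the inverse square root of the step number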
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid, load_cv),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
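    # (when the monitored validation metric stops improving, the model is first
    # rolled back to the best snapshot so far and Adadelta's eps is multiplied by
    # args.eps_decay, which acts like a learning-rate decay)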
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Exemple #29
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.rnn.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
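    # (O0-O3 are apex AMP optimization levels: O0 = pure FP32, O1 = mixed
    # precision via op patching, O2 = "almost FP16", O3 = pure FP16; the noam
    # wrapper is handled specially because amp must patch the inner torch optimizer)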
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
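    # (subsampling_factor makes the converter drop input frames to match the
    # encoder's frame-rate reduction; dtype casts the padded feature tensors)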
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make the batch size argument 1
    # the actual batch size is encoded in each element of the batch list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(TransformDataset(
            valid, load_cv),
                                                       batch_size=1,
                                                       repeat=False,
                                                       shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device,
                        args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
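For readers who want the Adadelta scheduling logic above outside the Chainer trainer machinery, here is a small, hypothetical sketch of the same idea in plain PyTorch: when the validation metric stops improving, roll the model back to the best snapshot and shrink the optimizer's eps. The function name and the 0.01 decay factor are assumptions for illustration, not ESPnet's implementation.

import copy

import torch


def maybe_decay_eps(model, optimizer, val_loss, state, eps_decay=0.01):
    """Restore the best weights and decay Adadelta's eps when val_loss worsens."""
    if state.get("best_loss") is None or val_loss < state["best_loss"]:
        # new best: remember the loss and keep a copy of the weights
        state["best_loss"] = val_loss
        state["best_state"] = copy.deepcopy(model.state_dict())
    else:
        # no improvement: roll back and make the optimizer more conservative
        model.load_state_dict(state["best_state"])
        for group in optimizer.param_groups:
            group["eps"] *= eps_decay
    return state


# toy usage with a dummy model
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=1e-8)
state = {}
for val_loss in [1.0, 0.8, 0.9]:  # pretend these came from a validation pass
    state = maybe_decay_eps(model, optimizer, val_loss, state)
print(optimizer.param_groups[0]["eps"])  # 1e-10 after the non-improving epoch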
Exemple #30
0
def train(args):
    """Train E2E-TTS model."""
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make the batch size argument 1
    # the actual batch size is encoded in each element of the batch list
    train_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            train_batchset, lambda data: converter([load_tr(data)])),
                          batch_size=1,
                          num_workers=args.num_iter_processes,
                          shuffle=not use_sortagrad,
                          collate_fn=lambda x: x[0])
    }
    valid_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            valid_batchset, lambda data: converter([load_cv(data)])),
                          batch_size=1,
                          shuffle=False,
                          collate_fn=lambda x: x[0],
                          num_workers=args.num_iter_processes)
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, 'epoch')
    save_interval = (args.save_interval_epochs, 'epoch')
    report_interval = (args.report_interval_iters, 'iteration')

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger(
                       'validation/main/loss', trigger=eval_interval))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + '/att_ws',
                                  converter=converter,
                                  transform=load_cv,
                                  device=device,
                                  reverse=True)
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ['main/' + key, 'validation/main/' + key]
        trainer.extend(extensions.PlotReport(plot_key,
                                             'epoch',
                                             file_name=key + '.png'),
                       trigger=eval_interval)
        plot_keys += plot_key
    trainer.extend(extensions.PlotReport(plot_keys,
                                         'epoch',
                                         file_name='all_loss.png'),
                   trigger=eval_interval)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)