def test_sortagrad(swap_io):
    """Check shortest-first batching: batches ascend by their leading length,
    while utterances inside each batch are sorted in descending length order."""
    utts = make_dummy_json(128, [1, 700], [1, 700])
    if swap_io:
        batchset = make_batchset(
            utts, 16, 2**10, 2**10,
            batch_sort_key="input", shortest_first=True, swap_io=True,
        )
        key = "output"
    else:
        batchset = make_batchset(utts, 16, 2**10, 2**10, shortest_first=True)
        key = "input"

    def length_of(sample):
        # leading axis of the sorted feature's shape
        return sample[1][key][0]["shape"][0]

    prev_head = length_of(batchset[0][0])
    for batch in batchset:
        head = length_of(batch[0])
        # batches are ordered shortest-first by their leading utterance
        assert head >= prev_head
        last = head
        for sample in batch:
            cur = length_of(sample)
            # within a batch, utterances are longest-first
            assert cur <= last
            last = cur
        prev_head = head
def get_iter(args):
    """Build the training and validation iterators from the JSON manifests in ``args``.

    :param args: argparse namespace with data paths, batching options
        (batch_size, maxlen_*, batch_count/bins/frames_*), ngpu,
        sortagrad, preprocess_conf, train_dtype and n_iter_processes
    :return: tuple ``(train_iter, valid_iter)``, each a dict
        ``{'main': ChainerDataLoader}``
    """
    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    # sortagrad: -1 means "always on", a positive value means "first N epochs"
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0,
                          perturb_sampling=args.perturb_sampling)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)
    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},
        train=True  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},
        train=False  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we use an empty collate function instead, which returns a list
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    converter = CustomConverter(subsampling_factor=1, dtype=dtype)
    train_iter = {'main': ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1, num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,  # sortagrad needs the sorted order kept
        collate_fn=lambda x: x[0])}
    valid_iter = {'main': ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes)}
    return train_iter, valid_iter
def test_model_trainable_and_decodable(module, num_encs, model_dict):
    """Smoke-test a multi-encoder E2E model: one training step, attention
    plotting, and single/batch decoding."""
    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)
    # test trainable
    m = importlib.import_module(module)
    model = m.E2E([40 for _ in range(num_encs)], 5, args)
    loss = model(*batch)
    loss.backward()  # trainable
    # test attention plot
    dummy_json = make_dummy_json(num_encs, [10, 20], [10, 20], idim=40,
                                 odim=5, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10,
                             shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=40, odim=5, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0],
                               tmpdir, None, None, None)
    for i in range(num_encs):
        # att-encoder: one attention plot per encoder stream
        att_w = plot.get_attention_weight(0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, '{}/att{}.png'.format(tmpdir, i))
    # han: the entry after the per-encoder ones is plotted in HAN mode
    att_w = plot.get_attention_weight(0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, '{}/han.png'.format(tmpdir),
                                  han_mode=True)
    # test decodable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(10, 40) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)  # decodable
        if "pytorch" in module:
            batch_in_data = [[np.random.randn(10, 40),
                              np.random.randn(5, 40)]
                             for _ in range(num_encs)]
            model.recognize_batch(batch_in_data, args,
                                  args.char_list)  # batch decodable
def test_sortagrad_trainable(module):
    """MT model stays trainable and translatable with shortest-first batches."""
    args = make_arg(sortagrad=1)
    utts = make_dummy_json_mt(4, [10, 20], [10, 20], idim=6, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    minibatches = make_batchset(
        utts, 2, 2**10, 2**10, shortest_first=True, mt=True, iaxis=1, oaxis=0)
    model = m.E2E(6, 5, args)
    for minibatch in minibatches:
        loss = model(*convert_batch(minibatch, module, idim=6, odim=5))
        # chainer returns several values as a tuple; backprop through the first
        (loss[0] if isinstance(loss, tuple) else loss).backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        src = np.random.randint(0, 5, (1, 100))
        model.translate(src, args, args.char_list)
def test_sortagrad_trainable_with_batch_bins(module):
    """MT model trains on batches built by total bin budget (batch_bins)."""
    args = make_arg(sortagrad=1)
    idim = 6
    odim = 5
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=idim,
                                    odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems,
                             shortest_first=True, mt=True, iaxis=1, oaxis=0)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            # MT json keeps the source sentence in output[1], target in output[0]
            ilen = int(info['output'][1]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # every batch must fit within the requested bin budget
        # (fixed: was `assert olen < batch_elems`, which left `n` unused)
        assert n <= batch_elems
    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)
def __call__(self, trainer):
    """Rebuild the training batch set and reshuffle it for perturb-sampling.

    :param trainer: the trainer this extension is invoked from (not used
        directly; the live iterator is reached via ``self.train_iter``)
    """
    args = self.args
    # sortagrad: -1 means "always on", a positive value means "first N epochs"
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make new batch set for perturb_sampling mode
    # the following are imported from espnet.asr.pytorch_backend.asr:train
    train = make_batchset(self.train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0,
                          perturb_sampling=args.perturb_sampling,
                          rank=args.rank, world_size=args.world_size)
    dataset = TransformDataset(
        train, lambda data: self.converter([self.load_tr(data)]))
    # swap the freshly rebuilt dataset into the live training iterator
    self.train_iter['main'].perturb_sampling_shuffle(dataset)
    logging.warning("Doing Perturb-Sampling shuffling")
def test_sortagrad_trainable_with_batch_frames(module, num_encs):
    """Multi-encoder model trains on batches limited by frame budgets."""
    args = make_arg(num_encs=num_encs, sortagrad=1)
    idim, odim = 2, 2
    dummy_json = make_dummy_json(4, [2, 3], [2, 3], idim=idim, odim=odim,
                                 num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batch_frames_in = 50
    batch_frames_out = 50
    batchset = make_batchset(
        dummy_json,
        batch_frames_in=batch_frames_in,
        batch_frames_out=batch_frames_out,
        shortest_first=True,
    )
    for batch in batchset:
        # frame budgets hold per batch (input counted on the first stream)
        in_frames = sum(int(info["input"][0]["shape"][0]) for _, info in batch)
        out_frames = sum(int(info["output"][0]["shape"][0]) for _, info in batch)
        assert in_frames <= batch_frames_in
        assert out_frames <= batch_frames_out
    model = m.E2E([2 for _ in range(num_encs)], 2, args)
    for batch in batchset:
        loss = model(
            *convert_batch(batch, module, idim=2, odim=2, num_inputs=num_encs))
        loss.backward()  # trainable
    with torch.no_grad():
        in_data = [np.random.randn(100, 2) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)
def test_sortagrad_trainable_with_batch_bins(module, num_encs):
    """Multi-encoder model trains on batches built by total bin budget."""
    args = make_arg(num_encs=num_encs, sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json(4, [10, 20], [10, 20], idim=idim, odim=odim,
                                 num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])  # based on the first input
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # every batch must fit within the requested bin budget
        # (fixed: was `assert olen < batch_elems`, which left `n` unused)
        assert n <= batch_elems
    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    for batch in batchset:
        loss = model(*convert_batch(
            batch, module, idim=20, odim=5, num_inputs=num_encs))
        loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(100, 20) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)
def test_sortagrad_trainable_with_batch_bins(module):
    """ST model trains on batches built by total bin budget (batch_bins)."""
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json_st(4, [10, 20], [10, 20], [10, 20],
                                    idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_st as m
    else:
        raise NotImplementedError
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # every batch must fit within the requested bin budget
        # (fixed: was `assert olen < batch_elems`, which left `n` unused)
        assert n <= batch_elems
    model = m.E2E(20, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(100, 20)
        model.translate(in_data, args, args.char_list)
def test_sortagrad_trainable_with_batch_bins(module):
    """ASR model trains on batches built by total bin budget (batch_bins)."""
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json(8, [100, 200], [100, 200], idim=idim,
                                 odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_elems = 20000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # every batch must fit within the requested bin budget
        # (fixed: was `assert olen < batch_elems`, which left `n` unused)
        assert n <= batch_elems
    model = m.E2E(20, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=20, odim=5))[0]
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(100, 20)
        model.recognize(in_data, args, args.char_list)
def test_sortagrad_trainable_with_batch_bins(module):
    """ASR model trains on small batch_bins batches and stays decodable."""
    args = make_arg(sortagrad=1)
    idim = 10
    odim = 5
    dummy_json = make_dummy_json(2, [3, 5], [3, 5], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems,
                             shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info["input"][0]["shape"][0])
            olen = int(info["output"][0]["shape"][0])
            n += ilen * idim + olen * odim
        # every batch must fit within the requested bin budget
        # (fixed: was `assert olen < batch_elems`, which left `n` unused)
        assert n <= batch_elems
    model = m.E2E(idim, odim, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=idim, odim=odim))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(10, idim)
        model.recognize(in_data, args, args.char_list)
def test_sortagrad_trainable_with_batch_frames(module):
    """ASR model trains on batches limited by input/output frame budgets."""
    args = make_arg(sortagrad=1)
    idim, odim = 20, 5
    utts = make_dummy_json(4, [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batch_frames_in = 50
    batch_frames_out = 50
    batchset = make_batchset(utts,
                             batch_frames_in=batch_frames_in,
                             batch_frames_out=batch_frames_out,
                             shortest_first=True)
    for batch in batchset:
        # both frame budgets hold for every batch
        in_frames = sum(int(info['input'][0]['shape'][0]) for _, info in batch)
        out_frames = sum(int(info['output'][0]['shape'][0]) for _, info in batch)
        assert in_frames <= batch_frames_in
        assert out_frames <= batch_frames_out
    model = m.E2E(20, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        # chainer returns several values as a tuple; backprop through the first
        (loss[0] if isinstance(loss, tuple) else loss).backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        feats = np.random.randn(100, 20)
        model.recognize(feats, args, args.char_list)
def test_sortagrad_trainable_with_batch_frames(module):
    """MT model trains on batches limited by source/target token budgets."""
    args = make_arg(sortagrad=1)
    idim, odim = 6, 5
    utts = make_dummy_json_mt(8, [100, 200], [100, 200], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_frames_in = 200
    batch_frames_out = 200
    batchset = make_batchset(utts,
                             batch_frames_in=batch_frames_in,
                             batch_frames_out=batch_frames_out,
                             shortest_first=True, mt=True)
    for batch in batchset:
        # MT json keeps the source sentence in output[1], target in output[0]
        src_tokens = sum(int(info['output'][1]['shape'][0]) for _, info in batch)
        tgt_tokens = sum(int(info['output'][0]['shape'][0]) for _, info in batch)
        assert src_tokens <= batch_frames_in
        assert tgt_tokens <= batch_frames_out
    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        src = np.random.randint(0, 5, (1, 100))
        model.translate(src, args, args.char_list)
def test_make_batchset(swap_io):
    """make_batchset honours min_batch_size with and without adaptive sizing."""
    utts = make_dummy_json(128, [128, 512], [16, 128])

    def check_min_size(batchset, smallest):
        # every produced minibatch respects the minimum size
        assert sum([len(batch) >= smallest for batch in batchset]) == len(batchset)

    # check w/o adaptive batch size
    batchset = make_batchset(utts, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=1, swap_io=swap_io)
    check_min_size(batchset, 1)
    print([len(batch) for batch in batchset])
    batchset = make_batchset(utts, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=10, swap_io=swap_io)
    check_min_size(batchset, 10)
    print([len(batch) for batch in batchset])
    # check w/ adaptive batch size
    batchset = make_batchset(utts, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    check_min_size(batchset, 10)
    print([len(batch) for batch in batchset])
    batchset = make_batchset(utts, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    check_min_size(batchset, 10)
def test_context_residual(module):
    """E2E with context_residual enabled still trains and decodes."""
    args = make_arg(context_residual=True)
    utts = make_dummy_json(8, [1, 100], [1, 100], idim=20, odim=5)
    if module != "pytorch":
        raise NotImplementedError
    import espnet.nets.pytorch_backend.e2e_asr as m
    batchset = make_batchset(utts, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(20, 5, args)
    for minibatch in batchset:
        attn_loss = model(*convert_batch(minibatch, module, idim=20, odim=5))[0]
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        feats = np.random.randn(50, 20)
        model.recognize(feats, args, args.char_list)
def test_sortagrad_trainable(module):
    """MT model trains and translates with shortest-first minibatches."""
    args = make_arg(sortagrad=1)
    utts = make_dummy_json_mt(4, [10, 20], [10, 20], idim=6, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    minibatches = make_batchset(utts, 2, 2 ** 10, 2 ** 10,
                                shortest_first=True, mt=True)
    model = m.E2E(6, 5, args)
    for minibatch in minibatches:
        attn_loss = model(*convert_batch(minibatch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        src = np.random.randint(0, 5, (1, 100))
        model.translate(src, args, args.char_list)
def test_sortagrad_trainable(module, num_encs):
    """Multi-encoder model trains shortest-first without losing utterances."""
    args = make_arg(num_encs=num_encs, sortagrad=1)
    utts = make_dummy_json(6, [10, 20], [10, 20], idim=20, odim=5,
                           num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batchset = make_batchset(utts, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    seen = 0
    for batch in batchset:
        seen += len(batch)
        loss = model(*convert_batch(batch, module, idim=20, odim=5,
                                    num_inputs=num_encs))
        loss.backward()  # trainable
    # batching must neither drop nor duplicate any of the 6 dummy utterances
    assert seen == 6
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(50, 20) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)
def test_gradient_noise_injection(module, num_encs):
    """Gradient noise must make gradients differ from a noise-free model."""
    args = make_arg(num_encs=num_encs, grad_noise=True)
    args_org = make_arg(num_encs=num_encs)
    utts = make_dummy_json(num_encs, [10, 20], [10, 20], idim=20, odim=5,
                           num_inputs=num_encs)
    import espnet.nets.pytorch_backend.e2e_asr_mulenc as m
    batchset = make_batchset(utts, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    model = m.E2E([20 for _ in range(num_encs)], 5, args)
    model_org = m.E2E([20 for _ in range(num_encs)], 5, args_org)

    def probed_grad(net):
        # gradient of the 11th parameter tensor
        return [p.grad for p in net.parameters()][10]

    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5,
                                    num_inputs=num_encs))
        loss_org = model_org(*convert_batch(batch, module, idim=20, odim=5,
                                            num_inputs=num_encs))
        loss.backward()
        grad = probed_grad(model)
        loss_org.backward()
        grad_org = probed_grad(model_org)
        assert grad[0] != grad_org[0]
def test_gradient_noise_injection(module):
    """Gradient noise must make gradients differ from a noise-free model."""
    args = make_arg(grad_noise=True)
    args_org = make_arg()
    utts = make_dummy_json(2, [3, 4], [3, 4], idim=10, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    batchset = make_batchset(utts, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(10, 5, args)
    model_org = m.E2E(10, 5, args_org)

    def probed_grad(net):
        # gradient of the 11th parameter tensor
        return [p.grad for p in net.parameters()][10]

    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=10, odim=5))
        loss_org = model_org(*convert_batch(batch, module, idim=10, odim=5))
        loss.backward()
        grad = probed_grad(model)
        loss_org.backward()
        grad_org = probed_grad(model_org)
        assert grad[0] != grad_org[0]
def test_calculate_plot_attention_ctc(module, num_encs, model_dict):
    """Exercise attention and CTC visualisation of a multi-encoder model."""
    args = make_arg(num_encs=num_encs, **model_dict)
    m = importlib.import_module(module)
    model = m.E2E([2 for _ in range(num_encs)], 2, args)
    # test attention plot
    dummy_json = make_dummy_json(num_encs, [2, 3], [2, 3], idim=2, odim=2,
                                 num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0],
                               tmpdir, None, None, None)
    for i in range(num_encs):
        # att-encoder: one plot per encoder stream
        att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, "{}/att{}.png".format(tmpdir, i))
    # han: the entry after the per-encoder ones is plotted in HAN mode
    att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, "{}/han.png".format(tmpdir),
                                  han_mode=True)
    # test CTC plot
    ctc_probs = model.calculate_all_ctc_probs(*convert_batch(
        batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotCTCReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotCTCReport(model.calculate_all_ctc_probs, batchset[0],
                         tmpdir, None, None, None)
    # CTC probabilities only exist when the CTC branch is active
    if args.mtlalpha > 0:
        for i in range(num_encs):
            # ctc-encoder
            plot._plot_and_save_ctc(ctc_probs[i][0],
                                    "{}/ctc{}.png".format(tmpdir, i))
def test_sortagrad_trainable(module):
    """ASR model trains and decodes with shortest-first minibatches."""
    args = make_arg(sortagrad=1)
    idim, odim = 10, 5
    utts = make_dummy_json(2, [3, 5], [3, 5], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_asr as m
    else:
        import espnet.nets.chainer_backend.e2e_asr as m
    minibatches = make_batchset(utts, 2, 2**10, 2**10, shortest_first=True)
    model = m.E2E(idim, odim, args)
    for minibatch in minibatches:
        loss = model(*convert_batch(minibatch, module, idim=idim, odim=odim))
        # chainer returns several values as a tuple; backprop through the first
        (loss[0] if isinstance(loss, tuple) else loss).backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        feats = np.random.randn(10, idim)
        model.recognize(feats, args, args.char_list)
def load_asr_data(json_path, dic):
    """Fill ``dic`` with ``name -> [ilen, y]`` for every utterance in the json.

    All utterances are packed into a single batch so the converter runs once.
    Returns the same ``dic`` for convenience.
    """
    with open(json_path, "rb") as f:
        utts = json.load(f)["utts"]
    converter = CustomConverter(subsampling_factor=1)
    # a single batch holding the whole utterance set
    train = make_batchset(utts, batch_size=len(utts))
    names = [entry[0] for entry in train[0][:len(utts)]]
    loader = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    dataset = TransformDataset(train, lambda data: converter([loader(data)]))
    ilens = dataset[0][1]
    ys = dataset[0][2]
    for idx in tqdm(range(len(names))):
        dic[names[idx]] = [ilens[idx], ys[idx]]
    return dic
def load_asr_data(json_path, dic, TMHINT=None):
    """Fill ``dic`` with ``key -> [ilen, y]`` for every utterance in the json.

    :param json_path: path to an ESPnet data json
    :param dic: dict filled in place (also returned)
    :param TMHINT: when truthy, remap TMHINT-corpus utterance names to
        flat integer keys / suffix keys instead of the raw names
    """
    with open(json_path, "rb") as f:
        train_feature = json.load(f)["utts"]
    converter = CustomConverter(subsampling_factor=1)
    # one batch holding the whole set, so the converter is applied once
    train = make_batchset(train_feature, batch_size=len(train_feature))
    name = [train[0][i][0] for i in range(len(train_feature))]
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=None,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    dataset = TransformDataset(train, lambda data: converter([load_tr(data)]))
    data1 = dataset[0][1]  # input lengths
    data2 = dataset[0][2]  # targets
    for i in tqdm(range(len(name))):
        ilen = data1[i]
        y = data2[i]
        if TMHINT:
            speaker = ["M1", "M2", "M3", "F1", "F2", "F3"]
            tmp = name[i].split("_")
            if tmp[0] in speaker:
                if int(tmp[2]) >= 13:
                    # map (speaker, session >= 13, index) to a flat integer key
                    name_key = int(speaker.index(
                        tmp[0])) * 200 + (int(tmp[2]) - 13) * 10 + int(tmp[3])
                    dic[str(name_key)] = [ilen, y]
                else:
                    dic["_".join(tmp[1:])] = [ilen, y]
            # NOTE(review): in TMHINT mode, names whose prefix is not a known
            # speaker are silently skipped — confirm this is intended.
        else:
            name_key = name[i]
            dic[name_key] = [ilen, y]
    return dic
def train(args):
    """Train an MT model with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    # NOTE: the MT json keeps the source text in output[1] and the target
    # text in output[0]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad) *
            100.0 / sum(p.numel() for p in model.parameters()),
        ))

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            # noam wraps the real optimizer; amp must initialize the inner one
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # sortagrad: -1 means "always on", a positive value means "first N epochs"
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we use an empty collate function instead, which returns a list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        # switch from sorted batches to shuffling after the sortagrad epochs
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        # DataParallel wraps the model in .module
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(["main/loss", "validation/main/loss"],
                              "epoch",
                              file_name="loss.png"))
    trainer.extend(
        extensions.PlotReport(["main/acc", "validation/main/acc"],
                              "epoch",
                              file_name="acc.png"))
    trainer.extend(
        extensions.PlotReport(["main/ppl", "validation/main/ppl"],
                              "epoch",
                              file_name="ppl.png"))
    trainer.extend(
        extensions.PlotReport(["main/bleu", "validation/main/bleu"],
                              "epoch",
                              file_name="bleu.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer: when validation degrades, restore the
    # best snapshot and decay the optimizer hyperparameter
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def load_multilingual_data(root_path, datasets, args, languages):
    """Build merged multilingual train/dev/test DataLoaders.

    Iterates over every dataset in ``datasets`` to record its input dim and
    per-language output dim, but only merges the JSON of datasets listed in
    ``languages`` into the combined train/dev/test sets.

    Args:
        root_path: Filesystem root that replaces each dataset's feature-path
            prefix (from module-level ``data_config``).
        datasets: All dataset/language names to scan for dimensions.
        args: Program arguments (batching, ngpu, dist_train, meta_train, ...).
        languages: Subset of ``datasets`` actually merged for training.

    Returns:
        ((train_loader, dev_loader, test_loader), (idim, odim_dict)) where
        ``odim_dict`` maps every scanned language to its output dimension.

    Side effects:
        Mutates the module-level ``data_config`` dict (adds one entry per
        dataset derived from a template entry).
    """

    def collate(minibatch):
        # Each element of ``minibatch`` is one ESPnet batch: a list of
        # (utt_id, info) pairs.  Produces (padded fbanks, input lengths,
        # padded token ids, language) per batch.
        # NOTE(review): this closure reads the loop variable ``dataset``
        # lazily (late binding) — at call time it is the LAST dataset of the
        # enumerate loop below, so the prefix replacement assumes all
        # templates share one prefix.  Confirm that assumption holds.
        out = []
        for b in minibatch:
            fbanks = []
            tokens = []
            language = None
            for _, info in b:
                fbanks.append(
                    torch.tensor(
                        kaldiio.load_mat(info["input"][0]["feat"].replace(
                            data_config[dataset]["prefix"], root_path))))
                tokens.append(
                    torch.tensor([
                        int(s) for s in info["output"][0]["tokenid"].split()
                    ]))
                # All utterances in one batch must share a language tag.
                if language is not None:
                    assert language == info['category']
                else:
                    language = info['category']
            ilens = torch.tensor([x.shape[0] for x in fbanks])
            out.append((
                pad_sequence(fbanks, batch_first=True, padding_value=0),
                ilens,
                # -1 padding marks ignored target positions.
                pad_sequence(tokens, batch_first=True, padding_value=-1),
                language,
            ))
        return out[0] if len(out) == 1 else out

    idim = None
    odim_dict = {}
    mtl_train_json, mtl_dev_json, mtl_test_json = {}, {}, {}
    for idx, dataset in enumerate(datasets):
        language = dataset
        # Pick the config template by resource level (module-level
        # ``low_resource_languages`` — not visible in this file chunk).
        if language in low_resource_languages:
            template_key = "template100"
        else:
            template_key = "template150"
        # Materialize a per-dataset config from the template.
        data_config[dataset] = data_config[template_key].copy()
        for key in ["train", "val", "test", "token"]:
            data_config[dataset][key] = data_config[template_key][key].replace(
                "template", dataset)
        train_json = os.path.join(root_path, data_config[dataset]["train"])
        # Fall back to a temporary dev-set file when no "val" entry exists.
        dev_json = (os.path.join(root_path, data_config[dataset]["val"])
                    if data_config[dataset]["val"] else
                    f"{root_path}/tmp_dev_set_{dataset}.json")
        test_json = os.path.join(root_path, data_config[dataset]["test"])
        train_json, dev_json, test_json = load_json(train_json, dev_json,
                                                    test_json)
        # Tag every utterance with its language so collate() can assert
        # batch homogeneity.
        for key in train_json.keys():
            train_json[key]['category'] = language
        for key in dev_json.keys():
            dev_json[key]['category'] = language
        for key in test_json.keys():
            test_json[key]['category'] = language
        #print(train_json)
        # All datasets must agree on the input feature dimension.
        _, info = next(iter(train_json.items()))
        if idim is not None:
            assert idim == info["input"][0]["shape"][1]
        else:
            idim = info["input"][0]["shape"][1]
        odim_dict[language] = info["output"][0]["shape"][1]
        # Skip merging datasets that are not in the requested ``languages``;
        # their idim check and odim_dict entry above are still recorded.
        if dataset not in languages:
            continue
        mtl_train_json.update(train_json)
        mtl_dev_json.update(dev_json)
        mtl_test_json.update(test_json)
        #print(len(mtl_train_json), len(train_json))
    train_json, dev_json, test_json = mtl_train_json, mtl_dev_json, mtl_test_json
    use_sortagrad = False  # args.sortagrad == -1 or args.sortagrad > 0
    # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150)
    if args.ngpu > 1 and not args.dist_train:
        min_batch_size = args.ngpu
    else:
        min_batch_size = 1
    if args.meta_train:
        # Meta-training needs batches splittable into support/query halves.
        min_batch_size = 2 * min_batch_size
    trainset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=min_batch_size,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # devset = make_batchset(dev_json, batch_size, max_length_in=800, max_length_out=150)
    # Eval sets scale batch size down per GPU instead of using min_batch_size.
    devset = make_batchset(
        dev_json,
        args.batch_size
        if args.ngpu <= 1 else int(args.batch_size / args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    testset = make_batchset(
        test_json,
        args.batch_size
        if args.ngpu <= 1 else int(args.batch_size / args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # Sampler selection: distributed > meta-learning balanced > plain shuffle.
    if args.dist_train and args.ngpu > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    elif args.meta_train:
        train_sampler = BalancedBatchSampler(trainset)
    else:
        train_sampler = None
    train_loader = DataLoader(
        trainset,
        # meta-training draws one ESPnet batch per language at a time
        batch_size=1 if not args.meta_train else len(languages),
        collate_fn=collate,
        num_workers=args.n_iter_processes,
        # DataLoader forbids shuffle together with an explicit sampler.
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
    )
    dev_loader = DataLoader(
        devset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    test_loader = DataLoader(
        testset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    return (train_loader, dev_loader, test_loader), (idim, odim_dict)
def load_data(root_path, dataset, args):
    """Build train/dev/test DataLoaders for a single dataset/language.

    Monolingual counterpart of ``load_multilingual_data``: resolves the
    dataset's config from a resource-level template, loads the ESPnet JSONs,
    and wraps them in batched DataLoaders.

    Args:
        root_path: Filesystem root replacing the dataset's feature-path
            prefix (from module-level ``data_config``).
        dataset: Dataset/language name.
        args: Program arguments (batching, ngpu, dist_train, ...).

    Returns:
        ((train_loader, dev_loader, test_loader), (idim, odim)).

    Side effects:
        Mutates the module-level ``data_config`` dict (adds an entry for
        ``dataset`` derived from a template entry).
    """

    def collate(minibatch):
        # ``minibatch`` holds exactly one ESPnet batch (DataLoader
        # batch_size=1 below), i.e. a list of (utt_id, info) pairs.
        fbanks = []
        tokens = []
        for _, info in minibatch[0]:
            fbanks.append(
                torch.tensor(
                    kaldiio.load_mat(info["input"][0]["feat"].replace(
                        data_config[dataset]["prefix"], root_path))))
            tokens.append(
                torch.tensor(
                    [int(s) for s in info["output"][0]["tokenid"].split()]))
        ilens = torch.tensor([x.shape[0] for x in fbanks])
        return (
            pad_sequence(fbanks, batch_first=True, padding_value=0),
            ilens,
            # -1 padding marks ignored target positions.
            pad_sequence(tokens, batch_first=True, padding_value=-1),
        )

    language = dataset
    # Pick the config template by resource level (module-level
    # ``low_resource_languages`` — not visible in this file chunk).
    if language in low_resource_languages:
        template_key = "template100"
    else:
        template_key = "template150"
    # Materialize a per-dataset config from the template.
    data_config[dataset] = data_config[template_key].copy()
    for key in ["train", "val", "test", "token"]:
        data_config[dataset][key] = data_config[template_key][key].replace(
            "template", dataset)
    train_json = os.path.join(root_path, data_config[dataset]["train"])
    # Fall back to a temporary dev-set file when no "val" entry exists.
    dev_json = (os.path.join(root_path, data_config[dataset]["val"])
                if data_config[dataset]["val"] else
                f"{root_path}/tmp_dev_set_{dataset}.json")
    test_json = os.path.join(root_path, data_config[dataset]["test"])
    train_json, dev_json, test_json = load_json(train_json, dev_json,
                                                test_json)
    # Read feature/output dimensions from the first utterance.
    _, info = next(iter(train_json.items()))
    idim = info["input"][0]["shape"][1]
    odim = info["output"][0]["shape"][1]
    use_sortagrad = False  # args.sortagrad == -1 or args.sortagrad > 0
    # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150)
    trainset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        # distributed training shards data instead of enforcing a per-GPU
        # minimum batch size
        min_batch_size=args.ngpu if
        (args.ngpu > 1 and not args.dist_train) else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    # devset = make_batchset(dev_json, batch_size, max_length_in=800, max_length_out=150)
    # Eval sets scale batch size down per GPU instead of using min_batch_size.
    devset = make_batchset(
        dev_json,
        args.batch_size
        if args.ngpu <= 1 else int(args.batch_size / args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    testset = make_batchset(
        test_json,
        args.batch_size
        if args.ngpu <= 1 else int(args.batch_size / args.ngpu),
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    if args.dist_train and args.ngpu > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
    else:
        train_sampler = None
    train_loader = DataLoader(
        trainset,
        batch_size=1,
        collate_fn=collate,
        num_workers=args.n_iter_processes,
        # DataLoader forbids shuffle together with an explicit sampler.
        shuffle=(train_sampler is None),
        pin_memory=True,
        sampler=train_sampler,
    )
    dev_loader = DataLoader(
        devset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    test_loader = DataLoader(
        testset,
        batch_size=1,
        collate_fn=collate,
        shuffle=False,
        num_workers=args.n_iter_processes,
        pin_memory=True,
    )
    return (train_loader, dev_loader, test_loader), (idim, odim)
def train(args):
    """Train E2E-TTS model.

    Args:
        args (namespace): The program arguments (dims are read from
            ``args.valid_json``; model/optimizer/trainer settings from the
            remaining attributes).

    Fixes vs. previous revision:
        * module freezing iterated ``model.state_dict().items()`` —
          ``state_dict()`` returns detached tensors, so setting
          ``requires_grad = False`` on them never froze anything; it now
          iterates ``model.named_parameters()``.
        * missing space in the "writing a model config file to" log message.
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension (TTS: text in, features out)
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimenstion
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # freeze modules, if specified
    if args.freeze_mods:
        # Iterate named_parameters(), NOT state_dict(): state_dict() yields
        # detached tensors, so flipping requires_grad on them is a no-op and
        # the parameters would silently stay trainable.
        # NOTE(review): under DataParallel the names gain a "module." prefix;
        # confirm args.freeze_mods entries account for that.
        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in args.freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK — the trainer expects these attributes on the
    # optimizer for reporting/serialization.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )
    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    train_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(train_batchset,
                                     lambda data: converter([load_tr(data)])),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(valid_batchset,
                                     lambda data: converter([load_cv(data)])),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss",
                                                  trigger=eval_interval),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
            reduction_factor = model.module.reduction_factor
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
            reduction_factor = model.reduction_factor
        if reduction_factor > 1:
            # fix the length to crop attention weight plot correctly
            data = copy.deepcopy(data)
            for idx in range(len(data)):
                ilen = data[idx][1]["input"][0]["shape"][0]
                data[idx][1]["input"][0]["shape"][0] = ilen // reduction_factor
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        # Re-enable shuffling once the sortagrad warm-up epochs are done.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train an ASR model with the chainer backend.

    Args:
        args (namespace): The program arguments.

    Builds the model, optimizer, (multi-)GPU updater, and a chainer Trainer
    with evaluation, snapshot, plotting, LR/eps-decay, and logging
    extensions, then runs training to ``args.epochs``.
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)
    set_deterministic_chainer(args)
    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info (from the first valid utterance)
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        # NOTE(review): only logged here — the effective increase comes from
        # the per-GPU data subsets built below; args.batch_size itself is
        # not modified.
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha=0 is a placeholder; VaswaniRule below drives the LR schedule.
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions: pick converter/updater for the architecture
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train, load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
            ]
        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches: shard utterances round-robin across GPUs
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]
        # each subset must have same length for MultiprocessParallelUpdater
        # (shorter subsets are padded by repeating their leading batches)
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train_subsets[gid], load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # Re-enable shuffling once the sortagrad warm-up epochs are done.
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid, load_cv),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch (not meaningful in pure-CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models (no attention accuracy in pure-CTC mode)
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the monitored metric worsens,
    # restore the best snapshot and shrink AdaDelta's eps.
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value >
                               current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value >
                               current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value <
                               current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value <
                               current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps',
            lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train an E2E-ASR model with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first valid utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # build the model: either from pre-trained encoder/decoder modules,
    # a full pre-trained ASR model, or from scratch via dynamic import
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list),
                             rnnlm_args.layer,
                             rnnlm_args.unit))
        # NOTE: was `torch.load(args.rnnlm, rnnlm)`, which passes the model
        # as torch.load's `map_location` argument and fails at runtime;
        # the project helper torch_load(path, model) loads the state dict
        # into the model (it is already used below as load_fn=torch_load).
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu: scale the batch size so each GPU keeps
    # its per-device batch size
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device and dtype
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt-levels ("O0".."O3") keep the model in float32 here
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.rnn.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp for mixed-precision training
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam wrapper holds the real optimizer in .optimizer
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer,
        converter, device, args.ngpu, args.grad_noise, args.accum_grad,
        use_apex=use_apex)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        # re-enable shuffling once the sortagrad warmup epochs are done
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(
        model, valid_iter, reporter, converter, device, args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, transform=load_cv, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss',
            'main/loss_ctc', 'validation/main/loss_ctc',
            'main/loss_att', 'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch', file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the monitored criterion worsens,
    # restore the best snapshot and decay adadelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train E2E-TTS model.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first valid utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension (TTS consumes text, emits features)
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        # NOTE: was 'writing a model config file to' + model_conf (missing
        # separator space, e.g. "...file toexp/model.json")
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu: wrap in DataParallel and scale the batch
    # size so each GPU keeps its per-device batch size
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        # sort by input (text) length so the shortest utterances come first
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json, args.batch_size,
        args.maxlen_in, args.maxlen_out, args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True, iaxis=0, oaxis=0)
    valid_batchset = make_batchset(
        valid_json, args.batch_size,
        args.maxlen_in, args.maxlen_out, args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True, iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )
    load_cv = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = {
        'main': ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0])
    }
    valid_iter = {
        'main': ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes)
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, 'epoch')
    save_interval = (args.save_interval_epochs, 'epoch')
    report_interval = (args.report_interval_iters, 'iteration')

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger(
                       'validation/main/loss', trigger=eval_interval))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn, data, args.outdir + '/att_ws',
                                  converter=converter,
                                  transform=load_cv,
                                  device=device, reverse=True)
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ['main/' + key, 'validation/main/' + key]
        trainer.extend(extensions.PlotReport(plot_key, 'epoch',
                                             file_name=key + '.png'),
                       trigger=eval_interval)
        plot_keys += plot_key
    trainer.extend(extensions.PlotReport(plot_keys, 'epoch',
                                         file_name='all_loss.png'),
                   trigger=eval_interval)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        # re-enable shuffling once the sortagrad warmup epochs are done
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)