def test_model_trainable_and_decodable(module, num_encs, model_dict):
    """Smoke-test an ``E2E`` model: backward pass, attention plots, decoding.

    :param str module: Dotted name of the module that provides ``E2E``
    :param int num_encs: Number of encoders (one input stream per encoder)
    :param dict model_dict: Extra keyword arguments forwarded to ``make_arg``
    """
    feat_dim = 40
    out_dim = 5
    max_len = 2 ** 10

    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)

    # 1) trainable: a forward pass must yield a loss we can backpropagate
    e2e_module = importlib.import_module(module)
    model = e2e_module.E2E([feat_dim] * num_encs, out_dim, args)
    loss = model(*batch)
    loss.backward()

    # 2) attention plot
    dummy_json = make_dummy_json(
        num_encs, [10, 20], [10, 20],
        idim=feat_dim, odim=out_dim, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, max_len, max_len,
                             shortest_first=True)
    first_batch = batchset[0]
    converted = convert_batch(first_batch, "pytorch",
                              idim=feat_dim, odim=out_dim,
                              num_inputs=num_encs)
    att_ws = model.calculate_all_attentions(*converted)

    from espnet.asr.asr_utils import PlotAttentionReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, first_batch,
                               tmpdir, None, None, None)
    # one figure per encoder attention ...
    for enc_idx in range(num_encs):
        enc_att = plot.get_attention_weight(0, att_ws[enc_idx][0])
        plot._plot_and_save_attention(
            enc_att, '{}/att{}.png'.format(tmpdir, enc_idx))
    # ... and one for the entry at index ``num_encs``, plotted in han mode
    han_att = plot.get_attention_weight(0, att_ws[num_encs][0])
    plot._plot_and_save_attention(
        han_att, '{}/han.png'.format(tmpdir), han_mode=True)

    # 3) decodable: single-utterance decoding, plus batch decoding for
    #    pytorch modules
    with torch.no_grad(), chainer.no_backprop_mode():
        single_utt = [np.random.randn(10, feat_dim) for _ in range(num_encs)]
        model.recognize(single_utt, args, args.char_list)
        if "pytorch" in module:
            two_utts = [
                [np.random.randn(10, feat_dim), np.random.randn(5, feat_dim)]
                for _ in range(num_encs)
            ]
            model.recognize_batch(two_utts, args, args.char_list)
def train(args):
    """Train with the given args

    Reads the feature dimensions from ``args.valid_json``, writes the model
    configuration to ``<args.outdir>/model.json``, builds a ``Tacotron2``
    model wrapped in ``Tacotron2Loss``, and runs a trainer with evaluation,
    snapshot, plotting and reporting extensions.

    Side effects on ``args``: ``spc_dim`` (when ``use_cbhg``),
    ``spk_embed_dim``, ``batch_size`` (multiplied by ``ngpu`` when
    ``ngpu > 1``) and ``batch_sort_key`` (forced to ``"input"`` when
    sortagrad is enabled) are overwritten.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    # (the model input dim is read from the json 'output' entry and the
    # model output dim from the json 'input' entry)
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    if args.use_cbhg:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    if args.use_speaker_embedding:
        # NOTE(review): reads the same 'input'[1] entry as use_cbhg above --
        # confirm the json layout when both options are enabled
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        # trailing space keeps the path separated from the message text
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    tacotron2 = Tacotron2(idim, odim, args)
    logging.info(tacotron2)

    # check the use of multi-gpu
    if args.ngpu > 1:
        tacotron2 = torch.nn.DataParallel(
            tacotron2, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # define loss
    model = Tacotron2Loss(tacotron2, args.use_masking, args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(), args.lr, eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    # (attaches `target` / `serialize` attributes to the torch optimizer)
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(
        return_targets=True,
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_cbhg,
        preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # sortagrad is enabled for -1 (all epochs) or any positive epoch count
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches, args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    valid_batchset = make_batchset(
        valid_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches, args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(tacotron2, 'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # unwrap DataParallel to reach the model's own method
        if hasattr(tacotron2, "module"):
            att_vis_fn = tacotron2.module.calculate_all_attentions
        else:
            att_vis_fn = tacotron2.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + '/att_ws',
            converter=CustomConverter(
                return_targets=False,
                use_speaker_embedding=args.use_speaker_embedding,
                preprocess_conf=args.preprocess_conf),
            device=device, reverse=True)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss',
        'main/l1_loss', 'validation/main/l1_loss',
        'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch', file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch', file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch', file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch', file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch', file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = plot_keys[:]
    # prepend the bookkeeping columns to the reported keys
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    if use_sortagrad:
        # sortagrad == -1 means the trigger is set to the final epoch
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train with the given args

    Reads the feature dimensions from ``args.valid_json``, builds an ``E2E``
    model (optionally with an RNNLM attached), writes the model configuration
    to ``<args.outdir>/model.json`` and runs a trainer with evaluation,
    snapshot, plotting and reporting extensions.

    Side effect on ``args``: ``batch_size`` is multiplied by ``ngpu`` when
    ``ngpu > 1``.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.opt`` is neither ``'adadelta'``
        nor ``'adam'``
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # use the project loader that fills `rnnlm` from the file;
        # torch.load(path, rnnlm) would treat the module as `map_location`
        # and leave its weights untouched
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # keep a handle on the reporter before the model may get wrapped
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # fail here instead of with an unbound `optimizer` further down
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    # (attaches `target` / `serialize` attributes to the torch optimizer)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # sortagrad is enabled for -1 (all epochs) or any positive epoch count
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    # (named *_batchset so the locals do not shadow this function's name)
    train_batchset = make_batchset(
        train_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    valid_batchset = make_batchset(
        valid_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # sortagrad == -1 means the trigger is set to the final epoch
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # unwrap DataParallel to reach the model's own method
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(att_vis_fn, data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss',
            'main/loss_ctc', 'validation/main/loss_ctc',
            'main/loss_att', 'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch', file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare by value: `is not` against a str literal tests identity
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    # (when the monitored value gets worse than the best one: restore the
    # best snapshot, then decay eps)
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.acc.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.loss.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # report the current eps of the first parameter group
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').
                param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def test_calculate_plot_attention_ctc(module, num_encs, model_dict):
    """Smoke-test attention and CTC plotting of a multi-encoder ``E2E`` model.

    :param str module: Dotted name of the module that provides ``E2E``
    :param int num_encs: Number of encoders (one input stream per encoder)
    :param dict model_dict: Extra keyword arguments forwarded to ``make_arg``
    """
    dim = 2
    max_len = 2 ** 10

    args = make_arg(num_encs=num_encs, **model_dict)
    e2e_module = importlib.import_module(module)
    model = e2e_module.E2E([dim] * num_encs, dim, args)

    dummy_json = make_dummy_json(num_encs, [2, 3], [2, 3],
                                 idim=dim, odim=dim, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, max_len, max_len,
                             shortest_first=True)
    first_batch = batchset[0]

    def model_inputs():
        # a fresh conversion for every model call, as two separate calls
        return convert_batch(first_batch, "pytorch",
                             idim=dim, odim=dim, num_inputs=num_encs)

    # --- attention plot
    att_ws = model.calculate_all_attentions(*model_inputs())
    from espnet.asr.asr_utils import PlotAttentionReport
    att_dir = tempfile.mkdtemp()
    att_plot = PlotAttentionReport(model.calculate_all_attentions,
                                   first_batch, att_dir, None, None, None)
    uttid = "utt_%d" % 0
    # one figure per encoder attention ...
    for enc_idx in range(num_encs):
        enc_att = att_plot.trim_attention_weight(uttid, att_ws[enc_idx][0])
        att_plot._plot_and_save_attention(
            enc_att, "{}/att{}.png".format(att_dir, enc_idx))
    # ... and one for the entry at index ``num_encs``, plotted in han mode
    han_att = att_plot.trim_attention_weight(uttid, att_ws[num_encs][0])
    att_plot._plot_and_save_attention(
        han_att, "{}/han.png".format(att_dir), han_mode=True)

    # --- CTC plot
    ctc_probs = model.calculate_all_ctc_probs(*model_inputs())
    from espnet.asr.asr_utils import PlotCTCReport
    ctc_dir = tempfile.mkdtemp()
    ctc_plot = PlotCTCReport(model.calculate_all_ctc_probs, first_batch,
                             ctc_dir, None, None, None)
    # CTC figures are only drawn when mtlalpha is positive
    if args.mtlalpha > 0:
        for enc_idx in range(num_encs):
            ctc_plot._plot_and_save_ctc(
                ctc_probs[enc_idx][0],
                "{}/ctc{}.png".format(ctc_dir, enc_idx))
def train(args):
    """Train with the given args

    Reads the feature dimensions from ``args.valid_json``, builds a chainer
    ``E2E`` model, writes the model configuration to
    ``<args.outdir>/model.json`` and runs a trainer (single device or
    multi-GPU parallel updater) with evaluation, snapshot, plotting and
    reporting extensions.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.atype`` is not one of ``noatt``,
        ``dot``, ``location``, or if ``args.opt`` is neither ``'adadelta'``
        nor ``'adam'``
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # check attention type
    if args.atype not in ['noatt', 'dot', 'location']:
        raise NotImplementedError(
            'chainer supports only noatt, dot, and location attention.')

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args, flag_return=False)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # device map for the parallel updater: 'main' plus 'sub_<gid>'
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    else:
        # fail here instead of with an unbound `optimizer` on the next line
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    converter = CustomConverter(subsampling_factor=model.subsample[0],
                                preprocess_conf=args.preprocess_conf)
    # sortagrad is enabled for -1 (all epochs) or any positive epoch count
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if ngpu <= 1:
        # make minibatch list (variable length)
        # (named *_batchset so the local does not shadow this function)
        train_batchset = make_batchset(
            train_json, args.batch_size, args.maxlen_in, args.maxlen_out,
            args.minibatches, shortest_first=use_sortagrad)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        #
        # The iterator is wrapped in a list literal: `list(iterator)` would
        # consume the iterator instead of producing the one-element list of
        # iterators that `train_iters[0]` and ShufflingEnabler expect.
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_batchset, converter.transform),
                    batch_size=1, n_processes=args.n_iter_processes,
                    n_prefetch=8, maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_batchset, converter.transform),
                    batch_size=1, shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0], optimizer,
                                converter=converter, device=gpu_id)
    else:
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: utterance i goes to gpu (i % ngpu)
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater
        # (shorter subsets are padded by repeating their leading batches)
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid],
                                     converter.transform),
                    batch_size=1, n_processes=args.n_iter_processes,
                    n_prefetch=8, maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid],
                                     converter.transform),
                    batch_size=1, shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters, optimizer,
                                        converter=converter, devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # sortagrad == -1 means the trigger is set to the final epoch
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid_batchset = make_batchset(valid_json, args.batch_size,
                                   args.maxlen_in, args.maxlen_out,
                                   args.minibatches)
    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        extensions.Evaluator(valid_iter, model, converter=converter,
                             device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(att_vis_fn, data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss',
            'main/loss_ctc', 'validation/main/loss_ctc',
            'main/loss_att', 'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare by value: `is not` against a str literal tests identity
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    # (when the monitored value gets worse than the best one: restore the
    # best snapshot, then decay eps)
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.acc.best'),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.loss.best'),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # report the optimizer's current eps
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').eps),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Run training

    Reads the feature dimensions from ``args.valid_json`` (including the
    adversarial output dimension when ``args.adv`` is set), builds an ``E2E``
    model wrapped in ``Loss``, writes the model configuration to
    ``<args.outdir>/model.json`` and runs a trainer with an adversarial
    training schedule plus evaluation, snapshot, plotting and reporting
    extensions.

    Side effect on ``args``: ``batch_size`` is multiplied by ``ngpu`` when
    ``ngpu > 1``.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.opt`` is neither ``'adadelta'``
        nor ``'adam'``
    """
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))
    # the adversarial output dim comes from the second 'output' entry
    odim_adv = None
    if args.adv:
        odim_adv = int(valid_json[utts[0]]['output'][1]['shape'][1])
        logging.info('#output dims adversarial: ' + str(odim_adv))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args, odim_adv=odim_adv)
    model = Loss(e2e, args.mtlalpha)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        e2e.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, odim_adv, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # keep a handle on the reporter before the model may get wrapped
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    # First distinguish between learning rates: encoder and decoder use
    # asr_lr, the adversarial branch uses adv_lr.  With ngpu > 1 the model
    # is wrapped in DataParallel, so the predictor sits behind `.module`.
    predictor = (model.module if args.ngpu > 1 else model).predictor
    param_grp = [
        {'params': predictor.enc.parameters(), 'lr': args.asr_lr},
        {'params': predictor.dec.parameters(), 'lr': args.asr_lr},
        {'params': predictor.adv.parameters(), 'lr': args.adv_lr},
    ]
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(param_grp, rho=0.95, eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(param_grp)
    else:
        # fail here instead of with an unbound `optimizer` further down
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    # (attaches `target` / `serialize` attributes to the torch optimizer)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(e2e.subsample[0])

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    # (named *_batchset so the locals do not shadow this function's name)
    train_batchset = make_batchset(
        train_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid_batchset = make_batchset(
        valid_json, args.batch_size, args.maxlen_in, args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Prepare adversarial training schedule dictionary
    adv_schedule = get_advsched(args.adv, args.epochs)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu,
                            adv_schedule=adv_schedule,
                            max_grlalpha=args.grlalpha)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer,
                     weight_sharing=args.weight_sharing,
                     reinit_adv=args.reinit_adv)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # unwrap DataParallel to reach the predictor's own method
        if hasattr(model, "module"):
            att_vis_fn = model.module.predictor.calculate_all_attentions
        else:
            att_vis_fn = model.predictor.calculate_all_attentions
        trainer.extend(
            PlotAttentionReport(att_vis_fn, data, args.outdir + "/att_ws",
                                converter=converter, device=device),
            trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss',
            'main/loss_ctc', 'validation/main/loss_ctc',
            'main/loss_att', 'validation/main/loss_att',
            'main/loss_adv', 'validation/main/loss_adv'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport([
            'main/acc', 'validation/main/acc',
            'main/acc_adv', 'validation/main/acc_adv'
        ], 'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare by value: `is not` against a str literal tests identity
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    # (when the monitored value gets worse than the best one: restore the
    # best snapshot, then decay eps)
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.acc.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.loss.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # report the current eps of the first parameter group
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').
                param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    if args.adv:
        report_keys.extend([
            'main/loss_adv', 'main/acc_adv',
            'validation/main/loss_adv', 'validation/main/acc_adv'
        ])
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()