def train(args):
    """Train E2E-TTS model.

    Reads the input/output dimensions from ``args.valid_json`` (for TTS the
    json "output" entry is the model input and vice versa), builds a
    ``TTSInterface`` model, sets up the optimizer, data loaders and a Chainer
    trainer with evaluation/snapshot/plot/report extensions, and runs
    training for ``args.epochs`` epochs.

    Args:
        args (Namespace): The program arguments. This function mutates it:
            ``spk_embed_dim`` and ``spc_dim`` are set before ``vars(args)``
            is written to ``<outdir>/model.json``; ``batch_size`` (multi-GPU)
            and ``batch_sort_key`` (SortaGrad) are changed afterwards.

    Raises:
        NotImplementedError: If ``args.opt`` is neither "adam" nor "noam".

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # freeze modules, if specified
    if args.freeze_mods:
        # DataParallel (applied above) prefixes parameter names with "module."
        if hasattr(model, "module"):
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods
        # Iterate the live parameters: ``state_dict()`` returns detached
        # tensors, so toggling ``requires_grad`` on them froze nothing.
        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False
        model_params = [p for p in model.parameters() if p.requires_grad]
    else:
        model_params = model.parameters()

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )
    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, device, args.accum_grad
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
            reduction_factor = model.module.reduction_factor
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
            reduction_factor = model.reduction_factor
        if reduction_factor > 1:
            # fix the length to crop attention weight plot correctly
            data = copy.deepcopy(data)
            for idx in range(len(data)):
                ilen = data[idx][1]["input"][0]["shape"][0]
                data[idx][1]["input"][0]["shape"][0] = ilen // reduction_factor
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train with the given args.

    Builds an ``E2E`` ASR model (optionally attaching a pre-trained RNNLM),
    writes ``<outdir>/model.json``, sets up the optimizer, Chainer iterators
    and trainer extensions, and runs training for ``args.epochs`` epochs.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither "adadelta" nor "adam"
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # torch.load(path, module) would treat the module as ``map_location``
        # and throw the loaded weights away; torch_load loads into ``rnnlm``.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)),
                           indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # without this branch ``optimizer`` would be unbound below
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8,
            maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device,
        args.ngpu)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # ``!=`` (value comparison); ``is not`` on a str literal tests identity
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc',
                   'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train an RNN language model with the given args.

    Reads the train/valid token files, builds an ``RNNLM`` wrapped in
    ``ClassifierWithState``, trains it with a BPTT updater, and, when
    ``args.test_label`` is set, reloads ``<outdir>/rnnlm.model.best`` and logs
    the test-set perplexity.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither "sgd" nor "adam"
    """
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    train = read_tokens(args.train_label, args.char_list_dict)
    val = read_tokens(args.valid_label, args.char_list_dict)
    # count tokens
    n_train_tokens, n_train_oovs = count_tokens(train, unk)
    n_val_tokens, n_val_oovs = count_tokens(val, unk)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%' % (n_train_oovs / n_train_tokens * 100))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%' % (n_val_oovs / n_val_tokens * 100))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators
    # (the <eos> id is passed as both the sos and eos symbol)
    train_iter = ParallelSentenceIterator(train, args.batchsize,
                                          max_length=args.maxlen, sos=eos, eos=eos,
                                          shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val, args.batchsize,
                                        max_length=args.maxlen, sos=eos, eos=eos,
                                        repeat=False)
    logging.info('#iterations per epoch = ' + str(len(train_iter.batch_indices)))
    logging.info('#total iterations = ' + str(args.epoch * len(train_iter.batch_indices)))

    # Prepare an RNNLM model
    rnn = RNNLM(args.n_vocab, args.layer, args.unit, args.type)
    model = ClassifierWithState(rnn)
    if args.ngpu > 1:
        logging.warning(
            "currently, multi-gpu is not supported. use single gpu.")
    if args.ngpu > 0:
        # Make the specified GPU current
        gpu_id = 0
        model.cuda(gpu_id)
    else:
        gpu_id = -1

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # without this branch ``optimizer`` would be unbound below
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    reporter = model.reporter
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(train_iter, model, optimizer, gpu_id,
                          gradclip=args.gradclip)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(REPORT_INTERVAL, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'perplexity', 'val_perplexity', 'elapsed_time']),
        trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(
        extensions.snapshot_object(model, 'rnnlm.model.{.updater.epoch}', savefun=torch_save))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(
        MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, 'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%' % (n_test_oovs / n_test_tokens * 100))
        test_iter = ParallelSentenceIterator(test, args.batchsize,
                                             max_length=args.maxlen, sos=eos, eos=eos,
                                             repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        logging.info('test perplexity: ' + str(np.exp(float(result['main/loss']))))
def train(args):
    """Train with the given args.

    Builds (or partially initializes) an ``ASRInterface`` model, optionally
    attaches a pre-trained RNNLM, casts it to ``args.train_dtype`` (or sets up
    apex.amp for "O0".."O3"), builds iterators and trainer extensions, and runs
    training for ``args.epochs`` epochs.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not "adadelta", "adam" or "noam".
        ImportError: If an apex opt-level is requested but apex is missing.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # torch.load(path, module) would treat the module as ``map_location``
        # and throw the loaded weights away; torch_load loads into ``rnnlm``.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' % (
                args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt-levels ("O0".."O3") keep float32 weights; amp handles casting
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        # same module the TTS trainer imports get_std_opt from
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(f"You need to install apex for --train-dtype {args.train_dtype}. "
                          "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam wrapper exposes the underlying torch optimizer as ``.optimizer``
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor, dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8,
            maxtasksperchild=20, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu,
        args.grad_noise, args.accum_grad, use_apex=use_apex)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device, args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, transform=load_cv, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))
    trainer.extend(extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                                         'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(snapshot_object(model, 'model.acc.best'),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc',
                   'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
                   'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train a language model with the given args.

    The LM class is resolved from ``args.model_module`` / ``args.backend`` and
    must implement ``LMInterface``. Training runs for ``args.epoch`` epochs and
    writes ``<outdir>/model.json``, per-epoch snapshots and model files, and a
    ``rnnlm.model.best`` symlink to the model with the lowest validation loss.
    When ``args.test_label`` is set, that best model is reloaded after training
    and its test-set perplexity is logged.

    :param Namespace args: The program arguments
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    def _oov_percent(n_oovs, n_tokens):
        # OOV rate in percent. An empty split yields 0.0 instead of raising
        # ZeroDivisionError from what is only a statistics log line.
        return n_oovs / n_tokens * 100 if n_tokens else 0.0

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path)
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%'
                 % _oov_percent(n_train_oovs, n_train_tokens))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%'
                 % _oov_percent(n_val_oovs, n_val_tokens))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators; the per-step batch is scaled by the
    # number of GPUs (at least 1)
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            f'batch size is automatically increased ({args.batchsize} -> {batch_size * args.accum_grad})'
        )
    train_iter = ParallelSentenceIterator(train,
                                          batch_size,
                                          max_length=args.maxlen,
                                          sos=eos,
                                          eos=eos,
                                          shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val,
                                        batch_size,
                                        max_length=args.maxlen,
                                        sos=eos,
                                        eos=eos,
                                        repeat=False)
    # one optimizer update happens every args.accum_grad batches
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info('#iterations per epoch = %d' % epoch_iters)
    logging.info('#total iterations = ' + str(args.epoch * epoch_iters))

    # Prepare an RNNLM model; apex opt levels ("O0".."O3") keep float32 here
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    if args.schedulers is None:
        schedulers = []
    else:
        schedulers = [
            dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers
        ]

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # was a stray debug print() to stdout; keep the information in the log
    logging.debug('gpu ids = %s', gpu_id)
    updater = BPTTUpdater(train_iter,
                          model,
                          optimizer,
                          schedulers,
                          gpu_id,
                          gradclip=args.gradclip,
                          use_apex=use_apex,
                          accum_grad=args.accum_grad)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'),
                               out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(args.report_interval_iters,
                                      'iteration')))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'perplexity', 'val_perplexity',
        'elapsed_time'
    ]),
                   trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save per-epoch trainer snapshots and model files
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # Link the model with the lowest validation loss as rnnlm.model.best
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(
        MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch,
                     'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer),
                       trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%'
                     % _oov_percent(n_test_oovs, n_test_tokens))
        test_iter = ParallelSentenceIterator(test,
                                             batch_size,
                                             max_length=args.maxlen,
                                             sos=eos,
                                             eos=eos,
                                             repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")
def train(args):
    """Train a Tacotron2 model with the given args.

    Dimensions are read from the first utterance of ``args.valid_json`` with
    input and output swapped: the json "output" entry gives the model input
    dimension and the json "input" entry gives the model output dimension.
    Writes ``<outdir>/model.json``, a trainer snapshot every epoch, and
    ``<outdir>/model.loss.best`` whenever the validation loss improves.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    # NOTE(review): both the CBHG target dim and the speaker-embedding dim are
    # read from input[1]; confirm the json layout when both options are on.
    if args.use_cbhg:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # write model config (same json layout/encoding as the other trainers)
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    tacotron2 = Tacotron2(idim, odim, args)
    logging.info(tacotron2)

    # check the use of multi-gpu
    if args.ngpu > 1:
        tacotron2 = torch.nn.DataParallel(tacotron2,
                                          device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # define loss
    model = Tacotron2Loss(tacotron2, args.use_masking, args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(
        return_targets=True,
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_cbhg,
        preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(tacotron2,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # unwrap DataParallel to reach the model's own method
        if hasattr(tacotron2, "module"):
            att_vis_fn = tacotron2.module.calculate_all_attentions
        else:
            att_vis_fn = tacotron2.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn,
            data,
            args.outdir + '/att_ws',
            converter=CustomConverter(
                return_targets=False,
                use_speaker_embedding=args.use_speaker_embedding,
                preprocess_conf=args.preprocess_conf),
            device=device,
            reverse=True)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss',
        'validation/main/mse_loss', 'main/bce_loss',
        'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch',
                              file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch',
                              file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch',
                              file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch',
                file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch',
                file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = plot_keys[:]
    # prepend bookkeeping columns to the printed report
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train a chainer-backend ASR model with the given args.

    Reads input/output dimensions from the first utterance of
    ``args.valid_json``, builds the model from ``args.model_module``, writes
    ``<outdir>/model.json``, sets up a single-GPU, multi-GPU or CPU updater,
    registers evaluation, snapshot, plotting and logging extensions, and runs
    the chainer trainer for ``args.epochs`` epochs.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not ``adadelta``, ``adam`` or
            ``noam``; or if more than one GPU is used with ``args.batch_size``
            of 0 and ``args.batch_count`` other than ``auto``/``seq``.
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info (from the first validation utterance)
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode:
    # mtlalpha == 1.0 -> pure CTC, 0.0 -> pure attention, otherwise multitask
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # device map for CustomParallelUpdater: 'main' -> GPU 0,
        # 'sub_<k>' -> GPU k for k = 1..ngpu-1
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        # args.batch_size itself is not changed; presumably the message
        # reports the effective batch across GPUs, since each GPU gets its
        # own batchset built with args.batch_size below
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha starts at 0; it is scheduled by the VaswaniRule extension
        # registered further down in this function
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions (transformer vs. RNN variants of the
    # converter and updaters)
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: utterance i goes to GPU (i % ngpu)
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            # NOTE(review): unlike the single-GPU path, shortest_first is not
            # passed here, so sortagrad ordering is not applied to these
            # batchsets even though shuffling is disabled for sortagrad.
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater;
        # shorter subsets are padded by repeating their own leading batches
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # with sortagrad, shuffling is off at first and ShufflingEnabler turns it
    # on after args.sortagrad epochs (or at args.epochs when sortagrad == -1)
    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch (skipped in pure CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models (accuracy-based only when an attention branch exists)
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the compare function below reports
    # the current validation metric is worse than the best one, reload the
    # best saved model and decay AdaDelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value >
                               current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value >
                               current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value <
                               current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value <
                               current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # also report the current AdaDelta eps
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args): """Train with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning('cuda is not available') # get paths to data lang_pairs = sorted(args.lang_pairs.split(',')) args.one_to_many = True if len(lang_pairs) > 1 else False tgt_langs = sorted([p.split('-')[-1] for p in lang_pairs]) src_lang = lang_pairs[0].split('-')[0] if args.one_to_many: train_jpaths = [ os.path.join(args.train_json, fname) for fname in sorted(os.listdir(args.train_json)) if fname.endswith('.json') ] valid_jpaths = [ os.path.join(args.valid_json, fname) for fname in sorted(os.listdir(args.valid_json)) if fname.endswith('.json') ] all_langs = list( sorted(set([l for p in lang_pairs for l in p.split('-')]))) args.langs_dict = {} offset = 2 # for <blank> and <unk> for i, lang in enumerate(all_langs): args.langs_dict[f'<2{lang}>'] = offset + i logging.info(f'| train_jpaths: {train_jpaths}') logging.info(f'| valid_jpaths: {valid_jpaths}') logging.info(f'| lang_pairs : {lang_pairs}') logging.info(f'| langs_dict : {args.langs_dict}') else: train_jpaths = [args.train_json] valid_jpaths = [args.valid_json] args.langs_dict = None # get input and output dimension info idim = 0 odim = 0 for i, jpath in enumerate(valid_jpaths): with open(jpath, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) idim_tmp = int(valid_json[utts[0]]['input'][0]['shape'][-1]) odim_tmp = int(valid_json[utts[0]]['output'][0]['shape'][-1]) logging.info('| pair {}: idim={}, odim={}'.format( lang_pairs[i], idim_tmp, odim_tmp)) if idim == 0: idim = idim_tmp else: assert idim == idim_tmp if odim < odim_tmp: odim = odim_tmp logging.info('#input dims : ' + str(idim)) logging.info('#output dims: ' + str(odim)) # Initialize with pre-trained ASR encoder and MT decoder if args.enc_init is not None or args.dec_init is not None: logging.info('Loading pretrained ASR encoder and/or MT 
decoder ...') model = load_trained_modules(idim, odim, args, interface=STInterface) logging.info(f'*** Model *** \n {model}') else: model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) logging.info(f'*** Model *** \n {model}') assert isinstance(model, STInterface) logging.info( f'| Number of model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}' ) subsampling_factor = model.subsample[0] logging.info(f'subsampling_factor={subsampling_factor}') if args.rnnlm is not None: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(args.char_list), rnnlm_args.layer, rnnlm_args.unit, getattr(rnnlm_args, "embed_unit", None), # for backward compatibility )) torch_load(args.rnnlm, rnnlm) model.rnnlm = rnnlm # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.info('writing a model config file to ' + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: if args.batch_size != 0: logging.warning( 'batch size is automatically increased (%d -> %d)' % (args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") if args.train_dtype in ("float16", "float32", "float64"): dtype = getattr(torch, args.train_dtype) else: dtype = torch.float32 model = model.to(device=device, dtype=dtype) # Setup an optimizer if args.opt == 'adadelta': optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == 'adam': optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
weight_decay=args.weight_decay) elif args.opt == 'noam': from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) # setup apex.amp if args.train_dtype in ("O0", "O1", "O2", "O3"): try: from apex import amp except ImportError as e: logging.error( f"You need to install apex for --train-dtype {args.train_dtype}. " "See https://github.com/NVIDIA/apex#linux") raise e if args.opt == 'noam': model, optimizer.optimizer = amp.initialize( model, optimizer.optimizer, opt_level=args.train_dtype) else: model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype) use_apex = True else: use_apex = False # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 logging.info(f'use_sortagrad: {use_sortagrad}') # read json data num_langs = len(tgt_langs) train_all_pairs = [None] * num_langs valid_all_pairs = [None] * num_langs # check_data = {} batch_size = args.batch_size // num_langs if num_langs > 1 else args.batch_size for i, jpath in enumerate(train_jpaths): with open(jpath, 'rb') as f: train_json = json.load(f)['utts'] train_all_pairs[i] = make_batchset( train_json, batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) # check_data[lang_pairs[i]] = list(train_json.keys()) for i, jpath in enumerate(valid_jpaths): with open(jpath, 'rb') as f: valid_json = json.load(f)['utts'] valid_all_pairs[i] = make_batchset( valid_json, batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=1, count=args.batch_count, 
batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) # check_data[lang_pairs[i]] = list(valid_json.keys()) # print(f'len(train_all_pairs) = {len(train_all_pairs)}') # print(f'len(valid_all_pairs) = {len(valid_all_pairs)}') # for i, batch_langs in enumerate(train_all_pairs): # print(f'batch for lang {lang_pairs[i]}') # for batch_lang in batch_langs: # print(f'len(batch_lang) = {len(batch_lang)}') # print('-'*5) if num_langs > 1: cycle_train = [cycle(x) for x in train_all_pairs] cycle_valid = [cycle(x) for x in valid_all_pairs] num_batches_train = max(len(i) for i in train_all_pairs) num_batches_valid = max(len(i) for i in valid_all_pairs) train = [None] * num_batches_train valid = [None] * num_batches_valid for i, s in enumerate(zip(*cycle_train)): x = [] for y in s: x.extend(y) train[i] = x if i >= num_batches_train - 1: break for i, s in enumerate(zip(*cycle_valid)): x = [] for y in s: x.extend(y) valid[i] = x if i >= num_batches_valid - 1: break else: train = train_all_pairs[0] valid = valid_all_pairs[0] # print(f'num_batches_train = {num_batches_train}') # print(f'num_batches_valid = {num_batches_valid}') # print(f'len(train) = {len(train)}') # print(f'len(valid) = {len(valid)}') # print('*** Checking results of make_batchset() ***') # for i, batch in enumerate(train): # # if i == 0: # # print(batch) # ids = [sample[0] for sample in batch] # langs = [sample[1]['lang'] for sample in batch] # pairs = ['en-'+l for l in langs] # for i in range(len(ids)): # r = ids[i] in list(check_data[pairs[i]]) # print(f'ids[i]={ids[i]} in {check_data[pairs[i]]}: {r}') # print('-') # if r: # check_data[pairs[i]].remove(ids[i]) # print(f'len(batch) = {len(batch)}') # print(f'langs in batch: {langs}') # print('-'*5) # # if i > 5: # # break # print('*** Samples that are not used yet ***') # for k, v in check_data.items(): # print(k, v) # print('-'*5) # print('-'*20) load_tr = 
LoadInputsAndTargets(mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True}, langs_dict=args.langs_dict, src_lang=src_lang) load_cv = LoadInputsAndTargets(mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False}, langs_dict=args.langs_dict, src_lang=src_lang) # print('LoadInputsAndTargets()') # features, targets = load_cv(train[0]) # print(f'*** features: {features} ***') # for f in features: # # print(f) # print(f'len(f) = {len(f)}') # print('---') # print(f'*** targets : {targets} ***') # y1, y2 = zip(*targets) # # print(f'y1 = {y1}') # # print(f'y2 = {y2}') # for s in zip(y1, y2): # print(len(s[0][1]), len(s[1][1])) # print('-'*20) # Setup a converter converter = CustomConverter(subsampling_factor=subsampling_factor, dtype=dtype, asr_task=args.asr_weight > 0) # hack to make batchsize argument as 1 # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list n_iter_processes = args.n_iter_processes if n_iter_processes < 0: n_iter_processes = multiprocessing.cpu_count() elif n_iter_processes > 0: n_iter_processes = min(n_iter_processes, multiprocessing.cpu_count()) print(f'n_iter_processes = {n_iter_processes}') train_iter = { 'main': ChainerDataLoader(dataset=TransformDataset( train, lambda data: converter([load_tr(data)])), batch_size=1, num_workers=n_iter_processes, shuffle=not use_sortagrad, collate_fn=lambda x: x[0], pin_memory=False) } valid_iter = { 'main': ChainerDataLoader(dataset=TransformDataset( valid, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=n_iter_processes, pin_memory=False) } # xs_pad, ilens, ys_pad, ys_pad_asr = converter([load_cv(valid[0])]) # print('*** xs_pad ***') # # print(xs_pad) # print(xs_pad.size()) # print('*** ilens ***') # print(ilens) # print('*** ys_pad ***') # # 
print(ys_pad) # print(ys_pad.size()) # print('*** ys_pad_asr ***') # print(ys_pad_asr) # print('-'*20) # print(train_iter['main']) # i=0 # for item in train_iter['main']: # print(item) # print('-'*5) # if i > 8: # break # i += 1 # Set up a trainer updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex) # trainer = training.Trainer( # updater, (args.epochs, 'epoch'), out=args.outdir) time_limit_trigger = TimeLimitTrigger(args) trainer = training.Trainer(updater, time_limit_trigger, out=args.outdir) logging.info(f'updater: {updater}') logging.info(f'trainer: {trainer}') if use_sortagrad: logging.info(f'use_sortagrad ...') trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch')) # Evaluate the model with the test dataset for each epoch if args.save_interval_iters > 0: trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu), trigger=(args.save_interval_iters, 'iteration')) else: trainer.extend( CustomEvaluator(model, valid_iter, reporter, device, args.ngpu)) # Save attention weight each epoch if args.num_save_attention > 0: data = sorted(list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class(att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device) trainer.extend(att_reporter, trigger=(1, 'epoch')) else: att_reporter = None # Make a plot for training and validation values trainer.extend( extensions.PlotReport([ 'main/loss', 'validation/main/loss', 'main/loss_asr', 'validation/main/loss_asr', 'main/loss_st', 'validation/main/loss_st' ], 'epoch', 
file_name='loss.png')) trainer.extend( extensions.PlotReport([ 'main/acc', 'validation/main/acc', 'main/acc_asr', 'validation/main/acc_asr' ], 'epoch', file_name='acc.png')) trainer.extend( extensions.PlotReport(['main/bleu', 'validation/main/bleu'], 'epoch', file_name='bleu.png')) # Save best models if args.report_interval_iters > 0: trainer.extend(snapshot_object(model, 'model.loss.best'), trigger=MinValueTrigger( 'validation/main/loss', trigger=(args.report_interval_iters, 'iteration'), best_value=None)) trainer.extend(snapshot_object(model, 'model.acc.best'), trigger=MaxValueTrigger( 'validation/main/acc', trigger=(args.report_interval_iters, 'iteration'), best_value=None)) else: trainer.extend(snapshot_object(model, 'model.loss.best'), trigger=MinValueTrigger('validation/main/loss', best_value=None)) trainer.extend(snapshot_object(model, 'model.acc.best'), trigger=MaxValueTrigger('validation/main/acc', best_value=None)) # save snapshot which contains model and optimizer states if args.save_interval_iters > 0: trainer.extend( torch_snapshot(filename='snapshot.iter.{.updater.iteration}'), trigger=(args.save_interval_iters, 'iteration')) else: trainer.extend(torch_snapshot(), trigger=(1, 'epoch')) # epsilon decay in the optimizer if args.opt == 'adadelta': if args.criterion == 'acc': trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) elif args.criterion == 'loss': trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), 
trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) elif args.opt == 'adam': if args.criterion == 'acc': trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) trainer.extend(adam_lr_decay(args.lr_decay), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) elif args.criterion == 'loss': trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) trainer.extend(adam_lr_decay(args.lr_decay), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) # Write a log of evaluation statistics for each epoch trainer.extend( extensions.LogReport(trigger=(args.report_interval_iters, 'iteration'))) report_keys = [ 'epoch', 'iteration', 'main/loss', 'main/loss_st', 'main/loss_asr', 'validation/main/loss', 'validation/main/loss_st', 'validation/main/loss_asr', 'main/acc', 'validation/main/acc' ] if args.asr_weight > 0: report_keys.append('main/acc_asr') report_keys.append('validation/main/acc_asr') report_keys += ['elapsed_time'] if args.opt == 'adadelta': trainer.extend(extensions.observe_value( 'eps', lambda trainer: trainer.updater.get_optimizer('main'). param_groups[0]["eps"]), trigger=(args.report_interval_iters, 'iteration')) report_keys.append('eps') elif args.opt in ['adam', 'noam']: trainer.extend(extensions.observe_value( 'lr', lambda trainer: trainer.updater.get_optimizer('main'). 
param_groups[0]["lr"]), trigger=(args.report_interval_iters, 'iteration')) report_keys.append('lr') if args.asr_weight > 0: if args.mtlalpha > 0: report_keys.append('main/cer_ctc') report_keys.append('validation/main/cer_ctc') if args.mtlalpha < 1: if args.report_cer: report_keys.append('validation/main/cer') if args.report_wer: report_keys.append('validation/main/wer') if args.report_bleu: report_keys.append('validation/main/bleu') trainer.extend(extensions.PrintReport(report_keys), trigger=(args.report_interval_iters, 'iteration')) trainer.extend( extensions.ProgressBar(update_interval=args.report_interval_iters)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), trigger=(args.report_interval_iters, "iteration")) # Resume from a snapshot if args.resume: logging.info('resumed from %s' % args.resume) torch_resume(args.resume, trainer) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
def train(args):
    """Train an ASR model with the given args.

    Builds the model (optionally attaching an RNNLM), writes
    ``<outdir>/model.json``, sets up the optimizer and optional apex AMP,
    builds minibatches and data loaders, registers the Chainer trainer
    extensions (evaluation, snapshots, plots, logging, plateau-based decay)
    and runs training.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: For an unknown ``args.opt``, or when ``ngpu > 1``
            is combined with a multi-encoder setup.
        ImportError: When an apex ``train_dtype`` (O0-O3) is requested but
            apex is not installed.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for stream, idim in enumerate(idim_list, 1):
        logging.info("stream{}: input dims : {}".format(stream, idim))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(
            idim_list[0] if args.num_encs == 1 else idim_list, odim, args
        )
    assert isinstance(model, ASRInterface)
    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
        )
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup."
            )

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt levels (O0-O3) keep float32 weights; AMP handles casting.
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    elif args.opt == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0008, alpha=0.95)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        if args.opt == "noam":
            # The noam optimizer is a wrapper; apex needs the inner torch optimizer.
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        # Keep the CTC loss computation in float32 under AMP.
        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype
        )

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # Each dataset item is already a whole minibatch (a list of utterances), so
    # the loader uses batch_size=1 and a collate_fn that just unwraps it instead
    # of the default collate (which would convert numpy arrays to tensors).
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    # Save attention weight each epoch
    if (
        args.num_save_attention > 0
        and args.mtlalpha != 1.0
        and "transformer" in args.model_module
    ):
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if args.num_encs > 1:
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]

    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ]
            + ([] if args.num_encs == 1 else report_keys_loss_ctc),
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"]
            # per-encoder CER curves (previously the CTC *loss* keys were used here)
            + ([] if args.num_encs == 1 else report_keys_cer_ctc),
            "epoch",
            file_name="cer.png",
        )
    )

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        # Pure CTC models do not report attention accuracy.
        trainer.extend(
            snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    def _extend_decay_on_plateau(decay_factory, decay_arg):
        """Register "restore best model" + decay extensions for ``args.criterion``.

        Both extensions fire whenever the monitored validation value gets worse
        than the best seen so far; nothing is registered for an unsupported
        criterion (or "acc" in pure CTC mode).
        """
        if args.criterion == "acc" and mtl_mode != "ctc":
            key = "validation/main/acc"
            best_file = "/model.acc.best"
            got_worse = lambda best_value, current_value: best_value > current_value  # noqa: E731
        elif args.criterion == "loss":
            key = "validation/main/loss"
            best_file = "/model.loss.best"
            got_worse = lambda best_value, current_value: best_value < current_value  # noqa: E731
        else:
            return
        # Each extension gets its own (stateful) trigger, as before.
        trainer.extend(
            restore_snapshot(model, args.outdir + best_file, load_fn=torch_load),
            trigger=CompareValueTrigger(key, got_worse),
        )
        trainer.extend(
            decay_factory(decay_arg),
            trigger=CompareValueTrigger(key, got_worse),
        )

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        _extend_decay_on_plateau(adadelta_eps_decay, args.eps_decay)
    # lr decay in rmsprop
    if args.opt == "rmsprop":
        _extend_decay_on_plateau(rmsprop_lr_decay, args.lr_decay)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "main/cer_ctc",
        "validation/main/cer_ctc",
        "elapsed_time",
    ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.opt == "rmsprop":
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train a Translacotron model with the given args.

    Builds batches with ``DataFeeder``, writes ``<outdir>/model.json``, wraps
    the network in ``TranslacotronLoss`` and runs a Chainer trainer with
    evaluation, snapshot, plot and report extensions.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output
    src_path = os.path.join(args.base_dir, args.source)
    trg_path = os.path.join(args.base_dir, args.target)
    feeder = DataFeeder(src_path, trg_path, args)

    # Setup a converter
    converter = CustomConverter()

    train_batchset = feeder.make_train_batches()
    valid_batchset = feeder.make_test_batches()

    # Each dataset item is already a full minibatch, hence batch_size=1.
    train_iter = SerialIterator(
        TransformDataset(train_batchset, converter.transform),
        batch_size=1, shuffle=True)
    valid_iter = SerialIterator(
        TransformDataset(valid_batchset, converter.transform),
        batch_size=1, repeat=False, shuffle=False)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    translacotron = Translacotron(args)
    logging.info(translacotron)
    logging.info("Make Translacotron Model Done")

    # check the use of multi-gpu
    if args.ngpu > 1:
        translacotron = torch.nn.DataParallel(
            translacotron, device_ids=list(range(args.ngpu)))
        # NOTE(review): train/valid batches were already built by the feeder
        # above, so this scaling only changes args.batch_size itself; confirm
        # whether the feeder is meant to be built after this block.
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    translacotron = translacotron.to(device)

    # define loss
    model = TranslacotronLoss(translacotron, args.use_masking,
                              args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(), args.lr, eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models (the bare network, not the loss wrapper)
    trainer.extend(
        extensions.snapshot_object(translacotron, 'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch', file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch', file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch', file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch', file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch', file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics every REPORT_INTERVAL iterations
    trainer.extend(
        extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, None))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args, teacher_args):
    """Train an FCL-taco2 model together with a frozen teacher model.

    The teacher is built from ``teacher_args``, loaded from
    ``teacher_args.amp_checkpoint``, moved to the training device with
    ``requires_grad`` disabled, and handed to both the updater and the
    evaluator alongside the student.

    Args:
        args (Namespace): Arguments for the student model and training loop.
        teacher_args (Namespace): Arguments for the teacher model; must
            provide ``model_module`` and ``amp_checkpoint`` (a file whose
            ``'model'`` entry is the teacher state dict).

    Raises:
        ValueError: If ``teacher_args.amp_checkpoint`` is None.
        NotImplementedError: If ``args.opt`` is not "adam", "noam" or "lamb".
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension (TTS reads the json swapped)
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args, args, teacher_args=teacher_args)

    # The teacher is required regardless of how the student was initialised.
    teacher_model_class = dynamic_import(teacher_args.model_module)
    teacher_model = teacher_model_class(idim, odim, teacher_args, teacher_args)
    if teacher_args.amp_checkpoint is None:
        raise ValueError('please provide the teacher-model-amp-checkpoint')
    logging.info("teacher-model resumed from %s" % teacher_args.amp_checkpoint)
    # Load onto CPU first so GPU-saved checkpoints also load on CPU-only hosts;
    # the model is moved to the training device below.
    teacher_checkpoint = torch.load(teacher_args.amp_checkpoint, map_location="cpu")
    teacher_model.load_state_dict(teacher_checkpoint['model'])

    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)
    teacher_model = teacher_model.to(device)
    for param in teacher_model.parameters():  # fix teacher model params
        param.requires_grad = False

    # freeze modules, if specified
    if args.freeze_mods:
        # DataParallel prefixes every parameter name with "module."
        if hasattr(model, "module"):
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods

        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

        model_params = filter(lambda x: x.requires_grad, model.parameters())
    else:
        model_params = model.parameters()

    # Setup an optimizer (every optimizer honours the freeze_mods filter)
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == 'lamb':
        from apex.optimizers import FusedLAMB

        kw = dict(lr=0.1, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-6)
        optimizer = FusedLAMB(model_params, **kw)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    if args.use_amp:
        opt_level = 'O1'
        if args.opt == "noam":
            # The noam optimizer is a wrapper, not a torch.optim.Optimizer;
            # apex.amp only accepts the wrapped torch optimizer.
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=opt_level
            )
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        if args.amp_checkpoint is not None:
            logging.info("resumed from %s" % args.amp_checkpoint)
            checkpoint = torch.load(args.amp_checkpoint, map_location="cpu")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            amp.load_state_dict(checkpoint['amp'])

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    logging.info("batch_sort_key: %s", args.batch_sort_key)
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    # Number of training minibatches handed to the updater. batch_size may be 0
    # (the multi-GPU branch above explicitly allows it), so fall back to the
    # actual batch count then instead of dividing by zero.
    if args.batch_size > 0:
        num_batches = len(train_json) // args.batch_size
    else:
        num_batches = len(train_batchset)

    from io_utils_fcl import LoadInputsAndTargets

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    converter = CustomConverter(
        reduction_factor=args.reduction_factor,
        use_fe_condition=args.use_fe_condition,
        append_position=args.append_position,
    )

    # Each dataset item is already a whole minibatch, so the loader uses
    # batch_size=1 and a collate_fn that simply unwraps it.
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        teacher_model,
        model,
        args.grad_clip,
        train_iter,
        optimizer,
        device,
        args.accum_grad,
        args.use_amp,
        num_batches,
        args.outdir,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset every eval_interval
    trainer.extend(
        CustomEvaluator(teacher_model, model, valid_iter, reporter, device),
        trigger=eval_interval,
    )

    # Save snapshot every save_interval
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics every report_interval
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    # NOTE: TensorBoard logging is currently disabled for this trainer.

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def _load_or_split_semi_json(json_path, utt_using_ratio, utts=None):
    """Load (or create and cache) a labeled/unlabeled split of a data json.

    The split is cached next to ``json_path`` as ``<root>_labeled_<P><ext>``
    and ``<root>_unlabeled_<100-P><ext>`` with ``P = int(utt_using_ratio * 100)``.
    If the labeled cache file exists, both cache files are read back.
    Otherwise the first ``int(len(utts) * utt_using_ratio)`` utterances (in the
    json's key order) become the labeled subset and the remainder the
    unlabeled subset, and both are written to the cache files.

    Args:
        json_path (str): Path of the full data json (``{"utts": {...}}``).
        utt_using_ratio (float): Fraction of utterances treated as labeled.
        utts (dict, optional): Already-loaded ``"utts"`` dict of ``json_path``.
            Only needed when the split has to be created; it is loaded from
            ``json_path`` on demand when omitted.

    Returns:
        tuple: ``(labeled_utts, unlabeled_utts)`` dicts.

    """
    percent = int(utt_using_ratio * 100)
    # Derive cache names from the extension-stripped path so that a path that
    # does not end with ".json" can never resolve to the source file itself.
    root, ext = os.path.splitext(json_path)
    labeled_path = "{}_labeled_{}{}".format(root, percent, ext)
    unlabeled_path = "{}_unlabeled_{}{}".format(root, 100 - percent, ext)

    if os.path.exists(labeled_path):
        with open(labeled_path, "rb") as f:
            labeled = json.load(f)["utts"]
        with open(unlabeled_path, "rb") as f:
            unlabeled = json.load(f)["utts"]
        return labeled, unlabeled

    if utts is None:
        with open(json_path, "rb") as f:
            utts = json.load(f)["utts"]
    items = list(utts.items())
    split_point = int(len(items) * utt_using_ratio)
    labeled = dict(items[:split_point])
    unlabeled = dict(items[split_point:])
    for path, subset in ((labeled_path, labeled), (unlabeled_path, unlabeled)):
        with codecs.open(path, "w+", encoding="utf8") as f:
            json.dump(
                {"utts": subset},
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                separators=(",", ": "),
            )
    return labeled, unlabeled


def train(args):
    """Train an ASR model with Mean-Teacher / ICT semi-supervised learning.

    Labeled utterances feed the "main" iterator and unlabeled utterances the
    "sub" iterator of the updater/evaluator. When ``args.train_json`` is set it
    is split into labeled/unlabeled subsets by ``args.utt_using_ratio`` (cached
    on disk); otherwise ``args.label_train_json`` / ``args.unlabel_train_json``
    are used directly. The validation json is always split the same way.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i, idim in enumerate(idim_list):
        logging.info("stream{}: input dims : {}".format(i + 1, idim))
    logging.info("#output dims: " + str(odim))

    # report which semi-supervised method the mixup_alpha setting selects
    assert 0.0 <= args.mixup_alpha <= 1.0, "mixup-alpha should be [0.0, 1.0]"
    if args.mixup_alpha == 0.0:
        logging.info("Pure Mean-Teacher mode")
    else:
        logging.info("Interpolation Consistency Training mode")

    # specify model architecture
    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim_list[0] if args.num_encs == 1 else idim_list, odim, args)
    assert isinstance(model, ASRInterface)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup.")

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer (only the parameters of ``model.enc`` are optimized)
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.enc.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.enc.parameters(), weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.enc, args.adim, args.transformer_warmup_steps, args.transformer_lr)
    elif args.opt == "rmsprop":
        optimizer = torch.optim.RMSprop(model.enc.parameters(), lr=0.0008, alpha=0.95)
    elif args.opt == "sgd":
        optimizer = torch.optim.SGD(
            model.enc.parameters(), lr=0.5, momentum=0.9, nesterov=True)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)
    use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype)

    # read json data and split it into labeled / unlabeled subsets
    assert 0.0 < args.utt_using_ratio < 1.0, "utt-using-ratio should not be 0 or 1"
    if args.train_json is not None:
        train_labeled_json, train_unlabeled_json = _load_or_split_semi_json(
            args.train_json, args.utt_using_ratio)
    else:
        with open(args.label_train_json, "rb") as f:
            train_labeled_json = json.load(f)["utts"]
        with open(args.unlabel_train_json, "rb") as f:
            train_unlabeled_json = json.load(f)["utts"]
    # valid_json was already loaded above for the dimension lookup; reuse it.
    valid_labeled_json, valid_unlabeled_json = _load_or_split_semi_json(
        args.valid_json, args.utt_using_ratio, utts=valid_json)

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0

    # make minibatch list (variable length); only the labeled training set is
    # ordered shortest-first when sortagrad is enabled
    batch_args = (args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches)
    batch_kwargs = dict(
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    train = make_batchset(
        train_labeled_json, *batch_args, shortest_first=use_sortagrad, **batch_kwargs)
    valid = make_batchset(valid_labeled_json, *batch_args, **batch_kwargs)
    ul_train = make_batchset(train_unlabeled_json, *batch_args, **batch_kwargs)
    ul_valid = make_batchset(valid_unlabeled_json, *batch_args, **batch_kwargs)

    # One loader per subset; "train": True switches preprocessing to train mode.
    load_tr = LoadInputsAndTargets(
        mode="asr", load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True})
    load_cv = LoadInputsAndTargets(
        mode="asr", load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False})
    load_ul_tr = LoadInputsAndTargets(
        mode="asr", load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True})
    load_ul_cv = LoadInputsAndTargets(
        mode="asr", load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False})

    def _make_loader(batches, load_fn, shuffle):
        """Wrap pre-built minibatches in a ChainerDataLoader.

        batch_size=1 because each dataset item already is a whole minibatch;
        the identity-style collate_fn returns it instead of tensor-collating.
        """
        return ChainerDataLoader(
            dataset=TransformDataset(batches, lambda data: converter([load_fn(data)])),
            batch_size=1,
            shuffle=shuffle,
            collate_fn=lambda x: x[0],
            num_workers=args.n_iter_processes,
        )

    train_iter = _make_loader(train, load_tr, not use_sortagrad)
    valid_iter = _make_loader(valid, load_cv, False)
    # Add splited data iteration for training
    ul_train_iter = _make_loader(ul_train, load_ul_tr, True)
    ul_valid_iter = _make_loader(ul_valid, load_ul_cv, False)

    # Set up ICT related arguments
    ict_args = {
        "consistency_rampup_starts": args.consistency_rampup_starts,
        "consistency_rampup_ends": args.consistency_rampup_ends,
        "cosine_rampdown_starts": args.cosine_rampdown_starts,
        "cosine_rampdown_ends": args.cosine_rampdown_ends,
        "ema_pre_decay": args.ema_pre_decay,
        "ema_post_decay": args.ema_post_decay,
    }

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter, "sub": ul_train_iter},
        optimizer,
        device,
        args.ngpu,
        ict_args,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    # TODO: custom evaluator
    evaluator = CustomEvaluator(
        model, {"main": valid_iter, "sub": ul_valid_iter}, reporter, device, args.ngpu)
    if args.save_interval_iters > 0:
        trainer.extend(evaluator, trigger=(args.save_interval_iters, "iteration"))
    else:
        trainer.extend(evaluator)

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ce",
                "validation/main/loss_ce",
                "main/loss_mse",
                "validation/main/loss_mse",
            ],
            "epoch",
            file_name="loss.png",
        ))
    student_acc_keys = ["main/student_acc", "validation/main/student_acc"]
    trainer.extend(
        extensions.PlotReport(
            ["main/teacher_acc", "validation/main/teacher_acc"]
            + (student_acc_keys if args.show_student_model_acc else []),
            "epoch",
            file_name="acc.png",
        ))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/teacher_acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    def _acc_dropped(best_value, current_value):
        return best_value > current_value

    def _loss_rose(best_value, current_value):
        return best_value < current_value

    def _restore_and_decay(snapshot_name, key, is_worse, decay_ext):
        """Restore the best snapshot, then apply ``decay_ext``, whenever ``key`` degrades."""
        trainer.extend(
            restore_snapshot(model, args.outdir + "/" + snapshot_name, load_fn=torch_load),
            trigger=CompareValueTrigger(key, is_worse),
        )
        trainer.extend(decay_ext, trigger=CompareValueTrigger(key, is_worse))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc":
            # NOTE(review): this watches student_acc while model.acc.best is
            # selected on teacher_acc (and rmsprop/sgd watch teacher_acc) —
            # confirm which metric is intended.
            _restore_and_decay("model.acc.best", "validation/main/student_acc",
                               _acc_dropped, adadelta_eps_decay(args.eps_decay))
        elif args.criterion == "loss":
            _restore_and_decay("model.loss.best", "validation/main/loss",
                               _loss_rose, adadelta_eps_decay(args.eps_decay))
    # lr decay in rmsprop / sgd (previously `== "rmsprop" or "sgd"`, which was
    # always true and attached the decay to adam/noam runs as well)
    elif args.opt in ("rmsprop", "sgd"):
        if args.criterion == "acc":
            _restore_and_decay("model.acc.best", "validation/main/teacher_acc",
                               _acc_dropped, rmsprop_lr_decay(args.lr_decay))
        elif args.criterion == "loss":
            _restore_and_decay("model.loss.best", "validation/main/loss",
                               _loss_rose, rmsprop_lr_decay(args.lr_decay))

    # Write a log of evaluation statistics for each epoch
    report_trigger = (args.report_interval_iters, "iteration")
    trainer.extend(extensions.LogReport(trigger=report_trigger))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ce",
        "main/loss_mse",
        "validation/main/loss",
        "validation/main/loss_ce",
        "validation/main/loss_mse",
        "main/teacher_acc",
        "validation/main/teacher_acc",
        "elapsed_time",
    ]
    # Explicit statement: the former `base + extra if cond else []` expression
    # bound as `(base + extra) if cond else []` and dropped every key.
    if args.show_student_model_acc:
        report_keys += student_acc_keys
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0]["eps"],
            ),
            trigger=report_trigger,
        )
        report_keys.append("eps")
    elif args.opt in ("adam", "noam", "rmsprop", "sgd"):
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0]["lr"],
            ),
            trigger=report_trigger,
        )
        report_keys.append("lr")
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), None),
            trigger=report_trigger,
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train an ASR model, optionally initialised from pre-trained ASR/MT models.

    A pre-trained ASR model (``args.asr_model``) and/or MT model
    (``args.mt_model``) may be passed to the model constructor. An SLU head is
    attached when ``args.slu_model`` and ``args.slu_loss`` are set, and an
    RNNLM when ``args.rnnlm`` is set.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    asr_model, mt_model = None, None
    # Initialize encoder with pre-trained ASR encoder
    if args.asr_model:
        asr_model, _ = load_trained_model(args.asr_model)
        assert isinstance(asr_model, ASRInterface)
    # Initialize decoder with pre-trained MT decoder
    if args.mt_model:
        mt_model, _ = load_trained_model(args.mt_model)
        assert isinstance(mt_model, MTInterface)

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    # TODO(hirofumi0810) better to simplify the E2E model interface by only allowing idim, odim, and args
    # the pre-trained ASR and MT model arguments should be removed here and we should implement an additional method
    # to attach these models
    if asr_model is None and mt_model is None:
        model = model_class(idim, odim, args)
    elif mt_model is None:
        model = asr_model
    else:
        model = model_class(idim, odim, args, asr_model=asr_model, mt_model=mt_model)
    assert isinstance(model, ASRInterface)
    subsampling_factor = model.subsample[0]

    # delete pre-trained models
    if args.asr_model:
        del asr_model
    if args.mt_model:
        del mt_model

    if args.slu_model and args.slu_loss:
        model.add_slu(args.slu_model, args.slu_loss, args.slu_tune_weights, args.slu_pooling)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # Use the project's torch_load(path, model) helper (also used below as
        # restore_snapshot's load_fn). The previous `torch.load(path, rnnlm)`
        # passed the module as map_location and discarded the loaded object.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning('batch size is automatically increased (%d -> %d)' %
                            (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps,
                                args.transformer_lr)
    elif args.opt == 'adamw':
        # Biases and LayerNorm weights are excluded from weight decay; all other
        # parameters use args.weight_decay (previously both groups used 0.0).
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; the
        # unused WarmupLinearSchedule import no longer exists in transformers.
        no_decay = ('bias', 'LayerNorm.weight')
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': args.weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8,
            maxtasksperchild=20, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, converter,
                            device, args.ngpu, args.grad_noise, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn, data, args.outdir + "/att_ws",
                                  converter=converter, transform=load_cv, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))
    trainer.extend(extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                                         'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(snapshot_object(model, 'model.acc.best'),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot())

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
                   'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def _build_mt_optimizer(model, args):
    """Return the optimizer selected by ``args.opt`` over all of ``model``'s parameters.

    Raises:
        NotImplementedError: If ``args.opt`` is not "adadelta", "adam" or "noam".

    """
    opt = args.opt
    if opt == "adadelta":
        return torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay)
    if opt == "adam":
        return torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        return get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    raise NotImplementedError("unknown optimizer: " + opt)


def train(args):
    """Train a machine-translation (MT) model.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # Dimensions come from the first validation utterance: idim is read from
    # output[1] and odim from output[0].
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    first_utt = valid_json[list(valid_json)[0]]
    idim = int(first_utt["output"][1]["shape"][1])
    odim = int(first_utt["output"][0]["shape"][1])
    logging.info(f"#input dims : {idim}")
    logging.info(f"#output dims: {odim}")

    model = dynamic_import(args.model_module)(idim, odim, args)
    assert isinstance(model, MTInterface)

    # Persist (idim, odim, args) so the model can be rebuilt later.
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        config_text = json.dumps(
            (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True)
        f.write(config_text.encode("utf_8"))
    for name, value in sorted(vars(args).items()):
        logging.info("ARGS: " + name + ": " + str(value))

    reporter = model.reporter

    # With several GPUs the requested batch size is per GPU.
    if args.ngpu > 1 and args.batch_size != 0:
        logging.warning(
            "batch size is automatically increased (%d -> %d)"
            % (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    all_params = list(model.parameters())
    n_total = sum(p.numel() for p in all_params)
    n_trained = sum(p.numel() for p in all_params if p.requires_grad)
    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            n_total, n_trained, n_trained * 100.0 / n_total))

    optimizer = _build_mt_optimizer(model, args)

    # apex.amp mixed precision is selected through the O0..O3 train dtypes.
    use_apex = args.train_dtype in ("O0", "O1", "O2", "O3")
    if use_apex:
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            # the noam wrapper keeps the real torch optimizer in `.optimizer`
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype)

    # FIXME: TOO DIRTY HACK
    optimizer.target = reporter
    optimizer.serialize = lambda s: reporter.serialize(s)

    converter = CustomConverter()

    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0

    def _batches(utt_dict, **extra):
        """Build variable-length minibatches for the MT task."""
        return make_batchset(
            utt_dict,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            mt=True,
            iaxis=1,
            oaxis=0,
            **extra,
        )

    train_batches = _batches(train_json, shortest_first=use_sortagrad)
    valid_batches = _batches(valid_json)

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)

    # Each dataset item is already a full minibatch, so the loader uses
    # batch_size=1 and a collate_fn that unwraps the singleton list.
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train_batches, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid_batches, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        shuffle_epoch = args.epochs if args.sortagrad == -1 else args.sortagrad
        trainer.extend(ShufflingEnabler([train_iter]), trigger=(shuffle_epoch, "epoch"))

    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    evaluator = CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
    if args.save_interval_iters > 0:
        trainer.extend(evaluator, trigger=(args.save_interval_iters, "iteration"))
    else:
        trainer.extend(evaluator)

    att_reporter = None
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        core = model.module if hasattr(model, "module") else model
        att_vis_fn = core.calculate_all_attentions
        plot_class = core.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))

    for metric in ("loss", "acc", "ppl", "bleu"):
        trainer.extend(
            extensions.PlotReport(
                ["main/" + metric, "validation/main/" + metric],
                "epoch",
                file_name=metric + ".png",
            ))

    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # snapshot with model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # When the watched validation metric gets worse: restore the best snapshot
    # and then decay adadelta's eps / adam's lr.
    decay_factories = {
        "adadelta": lambda: adadelta_eps_decay(args.eps_decay),
        "adam": lambda: adam_lr_decay(args.lr_decay),
    }
    criteria = {
        "acc": ("/model.acc.best", "validation/main/acc",
                lambda best_value, current_value: best_value > current_value),
        "loss": ("/model.loss.best", "validation/main/loss",
                 lambda best_value, current_value: best_value < current_value),
    }
    make_decay = decay_factories.get(args.opt)
    criterion = criteria.get(args.criterion)
    if make_decay is not None and criterion is not None:
        snapshot_suffix, watched_key, got_worse = criterion
        trainer.extend(
            restore_snapshot(model, args.outdir + snapshot_suffix, load_fn=torch_load),
            trigger=CompareValueTrigger(watched_key, got_worse),
        )
        trainer.extend(make_decay(), trigger=CompareValueTrigger(watched_key, got_worse))

    report_trigger = (args.report_interval_iters, "iteration")
    trainer.extend(extensions.LogReport(trigger=report_trigger))
    report_keys = ["epoch", "iteration"]
    for metric in ("loss", "acc", "ppl"):
        report_keys += ["main/" + metric, "validation/main/" + metric]
    report_keys.append("elapsed_time")

    hyper_name = {"adadelta": "eps", "adam": "lr", "noam": "lr"}.get(args.opt)
    if hyper_name is not None:
        trainer.extend(
            extensions.observe_value(
                hyper_name,
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][hyper_name],
            ),
            trigger=report_trigger,
        )
        report_keys.append(hyper_name)

    if args.report_bleu:
        report_keys += ["main/bleu", "validation/main/bleu"]
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=report_trigger,
        )

    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train E2E-TTS model.

    Infers the model dimensions from the validation JSON, writes
    ``<outdir>/model.json``, builds the model, optimizer and data iterators,
    registers the trainer extensions (evaluation, snapshots, plots, logging,
    early stopping, optional TensorBoard and SortaGrad shuffling) and runs
    training.

    :param Namespace args: The program arguments. This function mutates it:
        ``spk_embed_dim`` and ``spc_dim`` are always (re)set, ``batch_size``
        is multiplied by ``ngpu`` when ``ngpu > 1``, and ``batch_sort_key``
        is forced to ``"input"`` when SortaGrad is enabled.
    :raises NotImplementedError: If ``args.opt`` is neither ``"adam"`` nor
        ``"noam"``.
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension: the data JSON's "output" entry gives
    # the model input dim and its "input" entry gives the model output dim
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    else:
        args.spc_dim = None

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)), indent=4,
                           ensure_ascii=False, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' % (
                    args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(), args.lr, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # (presumably lets the Chainer trainer/snapshot machinery treat the torch
    # optimizer like a Chainer optimizer whose target is the reporter)
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # read json data (valid_json was already loaded above to infer the dims)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json, args.batch_size,
        args.maxlen_in, args.maxlen_out, args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True, iaxis=0, oaxis=0)
    valid_batchset = make_batchset(
        valid_json, args.batch_size,
        args.maxlen_in, args.maxlen_out, args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True, iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )
    load_cv = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = {
        'main': ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])),
            batch_size=1, num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad, collate_fn=lambda x: x[0])
    }
    valid_iter = {
        'main': ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])),
            batch_size=1, shuffle=False, collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes)
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, 'epoch')
    save_interval = (args.save_interval_epochs, 'epoch')
    report_interval = (args.report_interval_iters, 'iteration')

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger(
                       'validation/main/loss', trigger=eval_interval))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        # NOTE(review): ['input'][0]['shape'][1] is the feature dim (the same
        # field odim is read from above), presumably identical for every
        # utterance, so this sort likely keeps the original order. Sorting by
        # length (shape[0]) may have been intended -- confirm before changing.
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn, data, args.outdir + '/att_ws',
                                  converter=converter, transform=load_cv,
                                  device=device, reverse=True)
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ['main/' + key, 'validation/main/' + key]
        trainer.extend(extensions.PlotReport(plot_key, 'epoch',
                                             file_name=key + '.png'),
                       trigger=eval_interval)
        plot_keys += plot_key
    trainer.extend(extensions.PlotReport(plot_keys, 'epoch',
                                         file_name='all_loss.png'),
                   trigger=eval_interval)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    # update_interval must match the trigger; with the default (100) the bar
    # would only refresh on common multiples of both intervals
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters),
        trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        # ShufflingEnabler expects a list of iterator objects; train_iter is a
        # dict of them, so pass its values rather than the dict itself
        trainer.extend(
            ShufflingEnabler(list(train_iter.values())),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def train(args):
    """Train an RNNLM using the Chainer backend.

    Loads the training/validation token files, trains with BPTT while
    registering evaluation, reporting, snapshot and early-stopping extensions,
    and, if ``args.test_label`` is given, reports perplexity of the best
    model on the test set.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.model_module`` is not ``"default"``.
    """
    # TODO(karita): support this
    if args.model_module != "default":
        raise NotImplementedError(
            "chainer backend does not support --model-module")

    logging.info('chainer version = ' + chainer.__version__)
    set_deterministic_chainer(args)

    # Warn (but continue) when the GPU stack is missing.
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    def log_split_stats(split, sentences, n_tokens, n_oovs):
        # Sentence count, token count and OOV percentage for one data split.
        logging.info('#sentences in the ' + split + ' data = ' + str(len(sentences)))
        logging.info('#tokens in the ' + split + ' data = ' + str(n_tokens))
        logging.info('oov rate in the %s data = %.2f %%'
                     % (split, n_oovs / n_tokens * 100))

    vocab = args.char_list_dict
    unk_id, eos_id = vocab['<unk>'], vocab['<eos>']

    # Each corpus is read as a sequence of sentences of token ids.
    train_data = read_tokens(args.train_label, vocab)
    valid_data = read_tokens(args.valid_label, vocab)
    train_tok, train_oov = count_tokens(train_data, unk_id)
    valid_tok, valid_oov = count_tokens(valid_data, unk_id)

    logging.info('#vocab = ' + str(args.n_vocab))
    log_split_stats('training', train_data, train_tok, train_oov)
    log_split_stats('validation', valid_data, valid_tok, valid_oov)

    sortagrad_on = args.sortagrad == -1 or args.sortagrad > 0

    # The same BOS/EOS id and length cap are used for every iterator.
    iter_kwargs = dict(max_length=args.maxlen, sos=eos_id, eos=eos_id)
    train_iter = ParallelSentenceIterator(
        train_data, args.batchsize, shuffle=not sortagrad_on, **iter_kwargs)
    val_iter = ParallelSentenceIterator(
        valid_data, args.batchsize, repeat=False, **iter_kwargs)

    iters_per_epoch = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info('#iterations per epoch = %d' % iters_per_epoch)
    logging.info('#total iterations = ' + str(args.epoch * iters_per_epoch))

    # Build the RNNLM wrapped in a stateful classifier.
    model = ClassifierWithState(
        RNNLM(args.n_vocab, args.layer, args.unit, args.type))
    if args.ngpu > 1:
        logging.warning(
            "currently, multi-gpu is not supported. use single gpu.")
    gpu_id = 0 if args.ngpu > 0 else -1
    if gpu_id >= 0:
        # Make the specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()

    # Persist the full argument namespace next to the model.
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as conf_file:
        logging.info('writing a model config file to ' + model_conf)
        serialized = json.dumps(vars(args), indent=4, ensure_ascii=False,
                                sort_keys=True)
        conf_file.write(serialized.encode('utf_8'))

    # Optimizer, LR schedulers and gradient clipping.
    optimizer = dynamic_import_optimizer(args.opt, args.backend).from_args(
        model, args)
    schedulers = ([] if args.schedulers is None else
                  [dynamic_import_scheduler(v)(k, args)
                   for k, v in args.schedulers])
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))

    updater = BPTTUpdater(train_iter, optimizer, schedulers, gpu_id,
                          args.accum_grad)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'),
                               out=args.outdir)

    report_trigger = (args.report_interval_iters, 'iteration')
    trainer.extend(LMEvaluator(val_iter, model, device=gpu_id))
    trainer.extend(extensions.LogReport(postprocess=compute_perplexity,
                                        trigger=report_trigger))
    trainer.extend(
        extensions.PrintReport(['epoch', 'iteration', 'perplexity',
                                'val_perplexity', 'elapsed_time']),
        trigger=report_trigger)
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    trainer.extend(extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(
        extensions.snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # MEMO(Hori): wants to use MinValueTrigger, but it seems to fail in resuming
    trainer.extend(MakeSymlinkToBestModel('validation/main/loss',
                                          'rnnlm.model'))

    if sortagrad_on:
        shuffle_epoch = args.epoch if args.sortagrad == -1 else args.sortagrad
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(shuffle_epoch, 'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        chainer.serializers.load_npz(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir not in (None, ""):
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer), trigger=report_trigger)

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # Optional: perplexity of the best checkpoint on the test set.
    if not args.test_label:
        return
    logging.info('test the best model')
    chainer.serializers.load_npz(args.outdir + '/rnnlm.model.best', model)
    test_data = read_tokens(args.test_label, vocab)
    test_tok, test_oov = count_tokens(test_data, unk_id)
    log_split_stats('test', test_data, test_tok, test_oov)
    test_iter = ParallelSentenceIterator(
        test_data, args.batchsize, repeat=False, **iter_kwargs)
    evaluator = LMEvaluator(test_iter, model, device=gpu_id)
    with chainer.using_config('train', False):
        result = evaluator()
    logging.info('test perplexity: '
                 + str(np.exp(float(result['main/loss']))))
def train(args):
    """Train with the given args.

    Trains an LM with the PyTorch backend and, if ``args.test_label`` is
    given, reports perplexity of the best checkpoint on the test set.

    :param Namespace args: The program arguments. ``args.model_module`` and
        ``args.backend`` select the LM class, which must subclass
        ``LMInterface``.
    :raises NotImplementedError: If ``args.opt`` is neither ``'sgd'`` nor
        ``'adam'``.
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path)
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%'
                 % (n_train_oovs / n_train_tokens * 100))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%'
                 % (n_val_oovs / n_val_tokens * 100))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators. The batch is scaled by the GPU count,
    # presumably so each DataParallel replica below still sees args.batchsize.
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size > args.batchsize:
        logging.info(f'batch size is automatically increased ({args.batchsize} -> {batch_size})')
    train_iter = ParallelSentenceIterator(train, batch_size,
                                          max_length=args.maxlen, sos=eos,
                                          eos=eos, shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val, batch_size,
                                        max_length=args.maxlen, sos=eos,
                                        eos=eos, repeat=False)
    logging.info('#iterations per epoch = ' + str(len(train_iter.batch_indices)))
    logging.info('#total iterations = '
                 + str(args.epoch * len(train_iter.batch_indices)))

    # Prepare an RNNLM model
    model = model_class(args.n_vocab, args)
    reporter = Reporter()
    if args.ngpu > 0:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.ngpu))).cuda()
        gpu_id = 0
    else:
        gpu_id = -1
    setattr(model, "reporter", reporter)

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps(vars(args), indent=4, ensure_ascii=False,
                           sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # Without this branch ``optimizer`` would be unbound below and fail
        # with a confusing UnboundLocalError.
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # (presumably lets the Chainer trainer/snapshot machinery treat the torch
    # optimizer like a Chainer optimizer whose target is the reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(train_iter, model, optimizer, gpu_id,
                          gradclip=args.gradclip)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(extensions.LogReport(
        postprocess=compute_perplexity,
        trigger=(args.report_interval_iters, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'perplexity', 'val_perplexity',
         'elapsed_time']
    ), trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch,
                     'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer),
                       trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%'
                     % (n_test_oovs / n_test_tokens * 100))
        test_iter = ParallelSentenceIterator(test, batch_size,
                                             max_length=args.maxlen, sos=eos,
                                             eos=eos, repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        # compute_perplexity adds the 'perplexity' entry read below
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")