def test_tensorboard_evaluator():
    """Check that training runs with and without a TensorboardLogger.

    A tiny linear classifier is trained on random data. The first run only
    attaches a BaseEvaluator; the following runs additionally attach a
    TensorboardLogger backed by a DummyWriter and verify how many entries
    were written for the training and validation losses.
    """
    n_samples = 6
    batch_size = 2
    epoch = 10

    # model and optimizer are shared by every trainer built below
    model = chainer.links.Classifier(chainer.links.Linear(3, 2))
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # random features and binary labels
    features = numpy.random.randn(n_samples, 3).astype(numpy.float32)
    labels = (numpy.random.randn(n_samples) > 1).astype(numpy.int32)
    data = chainer.datasets.TupleDataset(features, labels)

    def build_trainer():
        # a fresh trainer with its own iterators and an attached evaluator
        train_iter = chainer.iterators.SerialIterator(data, batch_size)
        updater = chainer.training.StandardUpdater(train_iter, optimizer)
        new_trainer = chainer.training.Trainer(updater, (epoch, "epoch"))
        eval_iter = chainer.iterators.SerialIterator(
            data, batch_size, repeat=False)
        new_trainer.extend(BaseEvaluator(eval_iter, model))
        return new_trainer

    # must be runnable without a tensorboard logger
    build_trainer().run()

    # must be runnable with a tensorboard logger
    for log_interval in (1, 3):
        trainer = build_trainer()
        writer = DummyWriter()
        logger = TensorboardLogger(writer)
        trainer.extend(logger, trigger=(log_interval, "iteration"))
        trainer.run()

        # the logger is registered under its default name
        assert TensorboardLogger.default_name in trainer._extensions

        # one training entry per logging trigger, one validation entry per epoch
        n_train_entries = len(writer.data["main/loss"])
        n_valid_entries = len(writer.data["validation/main/loss"])
        assert n_train_entries == trainer.updater.iteration // log_interval
        assert n_valid_entries == epoch
def train(args):
    """Train with the given args.

    The function runs the following steps in order:

    1. Reads ``args.valid_json`` to obtain the input/output dimensions
       (``shape[1]`` of the first input/output of the first utterance).
    2. Picks the multitask mode from ``args.mtlalpha`` and instantiates the
       model class imported from ``args.model_module``.
    3. Writes ``model.json`` (dims and all arguments) into ``args.outdir``.
    4. Selects CPU / single GPU / multi GPU from ``args.ngpu``.
    5. Builds the optimizer, converter, minibatch iterators and updater.
    6. Builds the trainer, registers all extensions (evaluation, attention
       plots, snapshots, plots, best-model saving, AdaDelta eps decay,
       logging, early stopping, tensorboard) and runs the training.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of ``'adadelta'``,
            ``'adam'`` or ``'noam'``, or if more than one GPU is requested
            while ``args.batch_count`` is neither ``"auto"`` nor ``"seq"``
            and ``args.batch_size`` is 0.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    # (valid_json is kept and reused below for the validation batches and
    # the attention plots, so the file is parsed only once)
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # 'main' is GPU 0, the remaining GPUs are registered as 'sub_<id>'
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha starts at 0; it is driven by the VaswaniRule extension
        # registered on the trainer further below
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions
    # (the converter/updater implementations depend on the model family)
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data (valid_json has already been loaded above)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    # sortagrad is active for -1 or any positive value; while active the
    # batches are built shortest-first and iterator shuffling starts disabled
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train_batchset = make_batchset(
            train_json,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            shortest_first=use_sortagrad,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            iaxis=0,
            oaxis=0)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_batchset, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_batchset, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: utterances are dealt round-robin to the GPUs
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater
        # (shorter subsets are padded by repeating their leading minibatches)
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # fires at epoch args.sortagrad, or at the last epoch when it is -1
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    # (skipped in pure CTC mode, i.e. when mtlalpha == 1.0)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # take the first num_save_attention utterances, ordered by
        # descending shape[1] of their first input
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    # Each extension gets its own CompareValueTrigger instance on purpose;
    # NOTE(review): the trigger objects look stateful, do not share them.
    # NOTE(review): with criterion == 'acc' in pure CTC mode neither branch
    # is taken, so no eps decay is registered - confirm this is intended.
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.acc.best'),
                trigger=CompareValueTrigger(
                    'validation/main/acc', lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc', lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model, args.outdir + '/model.loss.best'),
                trigger=CompareValueTrigger(
                    'validation/main/loss', lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss', lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # expose the current AdaDelta eps so it shows up in the reports
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)