示例#1
0
def train(args):
    """Train an ASR model with the given args.

    Reads the train/valid json manifests, builds (or initializes) the model,
    sets up the optimizer, data iterators and a chainer ``Trainer`` with
    evaluation, plotting, snapshot and logging extensions, then runs training.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of
            ``'adadelta'``, ``'adam'`` or ``'noam'``.
        ImportError: If an apex ``--train-dtype`` (O0-O3) is requested but
            apex is not installed.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode (mtlalpha weights the CTC branch)
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # build the model: partial init from pretrained enc/dec, full init from
    # a trained ASR model, or a fresh instance of the configured class
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # NOTE: torch.load(path, module) would pass the module as
        # ``map_location`` and never copy the weights into it; torch_load
        # restores the saved parameters into ``rnnlm``.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    # NOTE(review): the model is not wrapped in DataParallel here; the
    # multi-GPU split presumably happens in CustomUpdater (it gets args.ngpu).
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device; apex opt levels (O0-O3) keep float32 weights here
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.rnn.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam wrapper holds the real torch optimizer in .optimizer
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # chainer's trainer expects these attributes on the optimizer.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data; valid_json was already loaded above and is reused
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # sortagrad: sorted (no shuffle) batches until the given epoch, then
    # shuffling is turned on (sortagrad == -1 means for all epochs)
    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device,
                        args.ngpu))

    # Save attention weight each epoch (no attention in pure CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the validation criterion gets
    # worse, roll back to the best model and decay adadelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
示例#2
0
def train(args):
    """Train a Tacotron2 TTS model with the given args.

    Input/output dimensions are swapped relative to ASR manifests: the text
    side (``output`` in the json) is the model input and the acoustic
    features (``input`` in the json) are the model output.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    if args.use_cbhg:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # ensure_ascii=False keeps non-ASCII args (e.g. char lists) readable;
        # the text is UTF-8 encoded explicitly before writing.
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    tacotron2 = Tacotron2(idim, odim, args)
    logging.info(tacotron2)

    # check the use of multi-gpu
    if args.ngpu > 1:
        tacotron2 = torch.nn.DataParallel(tacotron2,
                                          device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # define loss (the loss wrapper is what the trainer optimizes)
    model = Tacotron2Loss(tacotron2, args.use_masking, args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    # chainer's trainer expects these attributes on the optimizer.
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(
        return_targets=True,
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_cbhg,
        preprocess_conf=args.preprocess_conf)

    # read json data; valid_json was already loaded above and is reused
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models (the bare Tacotron2, not the loss wrapper)
    trainer.extend(
        extensions.snapshot_object(tacotron2,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(tacotron2, "module"):
            att_vis_fn = tacotron2.module.calculate_all_attentions
        else:
            att_vis_fn = tacotron2.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn,
            data,
            args.outdir + '/att_ws',
            converter=CustomConverter(
                return_targets=False,
                use_speaker_embedding=args.use_speaker_embedding,
                preprocess_conf=args.preprocess_conf),
            device=device,
            reverse=True)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch',
                              file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch',
                              file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch',
                              file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch',
                file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch',
                file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = plot_keys[:]
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    # sortagrad: sorted (no shuffle) batches until the given epoch, then
    # shuffling is turned on (sortagrad == -1 means for all epochs)
    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
示例#3
0
文件: tts.py 项目: jingyonghou/espnet
def train(args):
    """Train E2E-TTS model.

    Input/output dimensions are swapped relative to ASR manifests: the text
    side (``output`` in the json) is the model input and the acoustic
    features (``input`` in the json) are the model output.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``'adam'`` nor
            ``'noam'``.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)),
                           indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' % (
                args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(), args.lr, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # chainer's trainer expects these attributes on the optimizer.
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # read json data; valid_json was already loaded above and is reused
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length); swap_io because the json
    # "input" side holds the TTS targets
    train_batchset = make_batchset(train_json, args.batch_size,
                                   args.maxlen_in, args.maxlen_out, args.minibatches,
                                   batch_sort_key=args.batch_sort_key,
                                   min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                                   shortest_first=use_sortagrad,
                                   count=args.batch_count,
                                   batch_bins=args.batch_bins,
                                   batch_frames_in=args.batch_frames_in,
                                   batch_frames_out=args.batch_frames_out,
                                   batch_frames_inout=args.batch_frames_inout,
                                   swap_io=True)
    valid_batchset = make_batchset(valid_json, args.batch_size,
                                   args.maxlen_in, args.maxlen_out, args.minibatches,
                                   batch_sort_key=args.batch_sort_key,
                                   min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                                   count=args.batch_count,
                                   batch_bins=args.batch_bins,
                                   batch_frames_in=args.batch_frames_in,
                                   batch_frames_out=args.batch_frames_out,
                                   batch_frames_inout=args.batch_frames_inout,
                                   swap_io=True)

    load_tr = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.num_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, load_tr),
            batch_size=1, n_processes=args.num_iter_processes, n_prefetch=8, maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, load_cv),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.num_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, load_tr),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, load_cv),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    converter = CustomConverter()
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, converter, device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # set intervals
    save_interval = (args.save_interval_epochs, 'epoch')
    report_interval = (args.report_interval_iters, 'iteration')

    # Save snapshot every save_interval epochs
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss', trigger=save_interval))

    # Save attention figure every save_interval epochs
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn, data, args.outdir + '/att_ws',
            converter=converter,
            transform=load_cv,
            device=device, reverse=True)
        trainer.extend(att_reporter, trigger=save_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values; the keys come from the
    # model (unwrapping DataParallel if needed)
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ['main/' + key, 'validation/main/' + key]
        trainer.extend(extensions.PlotReport(plot_key, 'epoch', file_name=key + '.png'))
        plot_keys += plot_key
    trainer.extend(extensions.PlotReport(plot_keys, 'epoch', file_name='all_loss.png'))

    # Write a log of evaluation statistics
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    # sortagrad: sorted (no shuffle) batches until the given epoch, then
    # shuffling is turned on (sortagrad == -1 means for all epochs)
    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
示例#4
0
文件: asr.py 项目: FrontierDK/espnet
def train(args):
    """Train an attention/CTC E2E ASR model with the PyTorch backend.

    The input/output dimensions are taken from the first utterance of
    ``args.valid_json``. An ``E2E`` model is built (optionally with an RNNLM
    attached) and its config is written to ``<args.outdir>/model.json``.
    The optimizer, data iterators and a ``training.Trainer`` with evaluation,
    attention plotting, best-model snapshots, AdaDelta epsilon decay and
    logging extensions are then set up, and training is run.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither ``'adadelta'``
        nor ``'adam'``
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # ``torch.load(path, rnnlm)`` would treat ``rnnlm`` as ``map_location``
        # and throw the loaded state away. ``torch_load(path, model)`` is the
        # same loader used as ``load_fn`` for snapshot restoring below.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # grab the reporter before the model may be wrapped in DataParallel
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # fail explicitly instead of hitting an unbound ``optimizer`` below
        raise NotImplementedError('args.opt={}'.format(args.opt))

    # FIXME: TOO DIRTY HACK
    # NOTE(review): presumably lets Chainer-style trainer/snapshot code treat
    # the reporter as the optimizer's target and serializer - confirm.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data (valid_json was already loaded above for the dims)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # sortagrad: keep shortest-first, unshuffled batches until ShufflingEnabler
    # fires (after args.sortagrad epochs, or args.epochs when it is -1)
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch (no attention in pure CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # accuracy is only meaningful when an attention decoder is trained;
    # compare by value (``is not`` on a str literal is identity, not equality)
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the validation criterion gets worse,
    # restore the best model and decay AdaDelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))

    # Write a log of evaluation statistics every REPORT_INTERVAL iterations
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').
                param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
示例#5
0
def train(args):
    """Train an E2E ASR model with the Chainer backend.

    The input/output dimensions are read from the first utterance of
    ``args.valid_json``. The model class named by ``args.model_module`` is
    built and its config is written to ``<args.outdir>/model.json``. The
    optimizer, a single-device or multi-GPU updater, the validation iterator
    and the trainer extensions are then set up, and training is run.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not ``'adadelta'``, ``'adam'``
            or ``'noam'``, or if multi-GPU training is requested with a
            ``batch_count`` other than ``'auto'``/``'seq'`` and
            ``batch_size == 0``.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu: gpu_id is the main device (-1 = CPU); ``devices`` is only
    # defined (and only used) for multi-GPU training
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        # each GPU iterates its own minibatches of args.batch_size (see the
        # per-GPU subsets below), so the effective batch size grows with ngpu
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha starts at 0 and is driven by the VaswaniRule extension below
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data (valid_json was already loaded above for the dims)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    # sortagrad: keep shortest-first, unshuffled batches until ShufflingEnabler
    # fires (after args.sortagrad epochs, or args.epochs when it is -1)
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches: utterances are dealt round-robin to the GPUs
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length); honour sortagrad here
            # too, since the iterators below are unshuffled when it is on
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches,
                              shortest_first=use_sortagrad)
            ]

        # each subset must have same length for MultiprocessParallelUpdater:
        # pad shorter subsets by re-appending their own leading minibatches
        maxlen = max(len(train_subset) for train_subset in train_subsets)
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch (no attention in pure CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ], 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models (accuracy is only tracked with an attention decoder)
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the validation criterion gets worse,
    # restore the best model and decay AdaDelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value:
                               best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value:
                               best_value < current_value))

    # Write a log of evaluation statistics every report_interval_iters
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').eps),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
示例#6
0
def train(args):
    """Train an end-to-end ASR model with the given args.

    Reads the input/output dimensions from the first utterance of
    ``args.valid_json`` and builds the model. The model is one of:

    - a fresh ``args.model_module`` instance;
    - the pre-trained ASR model from ``args.asr_model``, used as-is;
    - a model combining pre-trained ASR and MT models.

    It optionally attaches an SLU head and an RNNLM, then writes
    ``model.json`` to ``args.outdir``. Next it sets up the optimizer, batch
    iterators and a ``training.Trainer`` with evaluation, plotting, snapshot,
    early-stopping and reporting extensions. Finally it runs training.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of
            'adadelta', 'adam', 'noam' or 'adamw'.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance;
    # the parsed dict is reused below to build the validation batches
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    asr_model, mt_model = None, None
    # Initialize encoder with pre-trained ASR encoder
    if args.asr_model:
        asr_model, _ = load_trained_model(args.asr_model)
        assert isinstance(asr_model, ASRInterface)

    # Initialize decoder with pre-trained MT decoder
    if args.mt_model:
        mt_model, _ = load_trained_model(args.mt_model)
        assert isinstance(mt_model, MTInterface)

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    # TODO(hirofumi0810) better to simplify the E2E model interface by only allowing idim, odim, and args
    # the pre-trained ASR and MT model arguments should be removed here and we should implement an additional method
    # to attach these models
    if asr_model is None and mt_model is None:
        model = model_class(idim, odim, args)
    elif mt_model is None:
        # only a pre-trained ASR model was given: continue training it as-is
        model = asr_model
    else:
        model = model_class(idim,
                            odim,
                            args,
                            asr_model=asr_model,
                            mt_model=mt_model)
    assert isinstance(model, ASRInterface)
    subsampling_factor = model.subsample[0]

    # drop the local references to the pre-trained models
    # (when only asr_model was given, ``model`` still refers to it)
    if args.asr_model:
        del asr_model
    if args.mt_model:
        del mt_model

    if args.slu_model and args.slu_loss:
        model.add_slu(args.slu_model, args.slu_loss, args.slu_tune_weights,
                      args.slu_pooling)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # Load the saved parameters into ``rnnlm`` with the same (path, model)
        # helper used as restore_snapshot's load_fn below. The former
        # ``torch.load(args.rnnlm, rnnlm)`` passed the module as torch.load's
        # ``map_location`` argument and discarded the loaded result.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # grab the reporter before a possible DataParallel wrap hides the attribute
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    elif args.opt == 'adamw':
        from transformers import AdamW
        # Exclude biases and LayerNorm weights from weight decay; every other
        # parameter decays with args.weight_decay (like the adam/adadelta
        # branches). Previously both groups used 0.0, which made the split moot.
        no_decay = ['bias', 'LayerNorm.weight']
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': args.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        # NOTE: learning rate and epsilon are hard-coded for this optimizer
        optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor)

    # read json data (valid_json was already loaded above)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # sortagrad: -1 -> for all epochs, >0 -> for that many epochs, 0 -> off
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batch size is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu, args.grad_noise,
                            args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # switch shuffling on once the sortagrad epochs are over
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # unwrap DataParallel to reach the model's own methods
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot())

    # epsilon decay in the optimizer: when the monitored validation metric
    # gets worse, restore the best model and decay Adadelta's eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)