Beispiel #1
0
def train(args):
    """Train an ASR model with the given args.

    The function reads the data description JSON files, builds or loads the
    model, writes ``model.json`` into ``args.outdir``, sets up the optimizer,
    batch sets, iterators and trainer extensions, and finally runs the
    trainer.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not ``adadelta``, ``adam``
            or ``noam``.
        ImportError: If ``args.train_dtype`` is an apex opt level
            (``O0``..``O3``) and apex cannot be imported.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # build the model: partially initialised from trained modules, fully
    # loaded from a trained ASR model, or created from scratch
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # Use the project loader so the stored weights are actually copied
        # into ``rnnlm``; ``torch.load(path, rnnlm)`` would treat the module
        # as ``map_location`` and throw the loaded object away.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu: the batch size is scaled by the GPU count
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt levels (and anything else) start from float32 weights
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam optimizer wraps the real torch optimizer in ``.optimizer``
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # sortagrad == -1 means "keep sorted order for all epochs"
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device,
                        args.ngpu))

    # Save attention weight each epoch (skipped in pure CTC mode)
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the validation criterion gets
    # worse, restore the best model and decay eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #2
0
def train(args):
    """Train E2E-TTS model.

    The function reads the data description JSON files, writes ``model.json``
    into ``args.outdir``, builds the model, optionally freezes selected
    modules, sets up the optimizer, data loaders and trainer extensions, and
    finally runs the trainer.

    Args:
        args (namespace): The program arguments. ``spk_embed_dim`` and
            ``spc_dim`` (and ``batch_sort_key`` when sortagrad is enabled)
            are set on it as a side effect.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``adam`` nor ``noam``.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # freeze modules, if specified
    if args.freeze_mods:
        freeze_prefixes = tuple(args.freeze_mods)
        if hasattr(model, "module"):
            # a wrapped model exposes its parameters under "module.", so
            # accept the user-supplied prefixes with that prefix as well
            freeze_prefixes += tuple("module." + key
                                     for key in args.freeze_mods)
        # iterate the live parameters: tensors returned by state_dict() are
        # detached, so flipping requires_grad on them freezes nothing
        for mod, param in model.named_parameters():
            if mod.startswith(freeze_prefixes):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(train_batchset,
                                     lambda data: converter([load_tr(data)])),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main":
        ChainerDataLoader(
            dataset=TransformDataset(valid_batchset,
                                     lambda data: converter([load_cv(data)])),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss",
                                                  trigger=eval_interval),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
            reduction_factor = model.module.reduction_factor
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
            reduction_factor = model.reduction_factor
        if reduction_factor > 1:
            # fix the length to crop attention weight plot correctly
            # (deep copy so the entries of valid_json are left untouched)
            data = copy.deepcopy(data)
            for idx in range(len(data)):
                ilen = data[idx][1]["input"][0]["shape"][0]
                data[idx][1]["input"][0]["shape"][0] = ilen // reduction_factor
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        # NOTE(review): train_iter is a {"main": loader} dict at this point;
        # confirm ShufflingEnabler handles a dict entry in its iterator list.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #3
0
def train(args):
    """Train a Tacotron2 model with the given args.

    The function reads the data description JSON files, writes ``model.json``
    into ``args.outdir``, builds the Tacotron2 network and its loss wrapper,
    sets up the Adam optimizer, iterators and trainer extensions, and finally
    runs the trainer.

    :param Namespace args: The program arguments; ``spk_embed_dim`` (and
        ``spc_dim`` when ``use_cbhg`` is set, ``batch_sort_key`` when
        sortagrad is enabled) are set on it as a side effect
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    if args.use_cbhg:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    tacotron2 = Tacotron2(idim, odim, args)
    logging.info(tacotron2)

    # check the use of multi-gpu
    if args.ngpu > 1:
        tacotron2 = torch.nn.DataParallel(tacotron2,
                                          device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # define loss
    model = Tacotron2Loss(tacotron2, args.use_masking, args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(
        return_targets=True,
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_cbhg,
        preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train_batchset, converter.transform),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid_batchset, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models (the bare network, not the loss wrapper)
    trainer.extend(
        extensions.snapshot_object(tacotron2,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(tacotron2, "module"):
            att_vis_fn = tacotron2.module.calculate_all_attentions
        else:
            att_vis_fn = tacotron2.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn,
            data,
            args.outdir + '/att_ws',
            converter=CustomConverter(
                return_targets=False,
                use_speaker_embedding=args.use_speaker_embedding,
                preprocess_conf=args.preprocess_conf),
            device=device,
            reverse=True)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch',
                              file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch',
                              file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch',
                              file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch',
                file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch',
                file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    # copy so that prepending the bookkeeping columns leaves plot_keys intact
    report_keys = plot_keys[:]
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))

    if use_sortagrad:
        # sortagrad == -1 means "keep sorted order for all epochs"
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #4
0
def train(args):
    """Train a recurrent language model with the given args.

    Reads the training and validation token files, builds an ``RNNLM``
    wrapped in ``ClassifierWithState``, runs a trainer driven by a
    ``BPTTUpdater`` and, when ``args.test_label`` is set, logs the perplexity
    of ``rnnlm.model.best`` on the test set.

    :param Namespace args: The program arguments
    :raises NotImplementedError: If ``args.opt`` is neither ``'sgd'`` nor
        ``'adam'``
    """
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    train = read_tokens(args.train_label, args.char_list_dict)
    val = read_tokens(args.valid_label, args.char_list_dict)
    # count tokens
    n_train_tokens, n_train_oovs = count_tokens(train, unk)
    n_val_tokens, n_val_oovs = count_tokens(val, unk)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%' %
                 (n_train_oovs / n_train_tokens * 100))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%' %
                 (n_val_oovs / n_val_tokens * 100))

    # sortagrad is active for -1 or any positive value; while it is active
    # the training iterator starts unshuffled (see ShufflingEnabler below)
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators (the <eos> id is passed as both sos and eos)
    train_iter = ParallelSentenceIterator(train,
                                          args.batchsize,
                                          max_length=args.maxlen,
                                          sos=eos,
                                          eos=eos,
                                          shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val,
                                        args.batchsize,
                                        max_length=args.maxlen,
                                        sos=eos,
                                        eos=eos,
                                        repeat=False)
    logging.info('#iterations per epoch = ' +
                 str(len(train_iter.batch_indices)))
    logging.info('#total iterations = ' +
                 str(args.epoch * len(train_iter.batch_indices)))
    # Prepare an RNNLM model
    rnn = RNNLM(args.n_vocab, args.layer, args.unit, args.type)
    model = ClassifierWithState(rnn)
    if args.ngpu > 1:
        logging.warning(
            "currently, multi-gpu is not supported. use single gpu.")
    if args.ngpu > 0:
        # Make the specified GPU current
        gpu_id = 0
        model.cuda(gpu_id)
    else:
        gpu_id = -1

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # Fail with a clear message instead of a NameError on the unbound
        # ``optimizer`` a few lines below.
        raise NotImplementedError("unknown optimizer: " + str(args.opt))

    # FIXME: TOO DIRTY HACK
    # The optimizer is given ``target`` and ``serialize`` attributes that
    # forward to the model's reporter.
    reporter = model.reporter
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(train_iter,
                          model,
                          optimizer,
                          gpu_id,
                          gradclip=args.gradclip)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(REPORT_INTERVAL, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'perplexity', 'val_perplexity',
         'elapsed_time']),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(
        extensions.snapshot_object(model,
                                   'rnnlm.model.{.updater.epoch}',
                                   savefun=torch_save))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(
        MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        # With sortagrad == -1 the trigger is set to the full epoch count.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch,
                     'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%' %
                     (n_test_oovs / n_test_tokens * 100))
        test_iter = ParallelSentenceIterator(test,
                                             args.batchsize,
                                             max_length=args.maxlen,
                                             sos=eos,
                                             eos=eos,
                                             repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        # perplexity is exp of the reported 'main/loss'
        logging.info('test perplexity: ' +
                     str(np.exp(float(result['main/loss']))))
Beispiel #5
0
def train(args):
    """Train E2E-TTS model.

    Reads dimension info from the validation JSON, writes ``model.json`` to
    ``args.outdir``, builds the model named by ``args.model_module`` and runs
    a trainer with evaluation, snapshot, plotting and reporting extensions.

    The function updates ``args`` in place: ``spk_embed_dim`` and ``spc_dim``
    are always set, ``batch_size`` is multiplied by ``ngpu`` when more than
    one GPU is used, and ``batch_sort_key`` is forced to ``"input"`` when
    sortagrad is enabled.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``'adam'`` nor
            ``'noam'``.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())

    # reverse input and output dimension
    # (idim comes from the JSON 'output' entry, odim from the 'input' entry)
    idim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # get extra input and output dimension
    # NOTE(review): both options read from ['input'][1]; confirm they are not
    # expected to be enabled together.
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    # keep a handle on the reporter before the model may be wrapped below
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     eps=args.eps,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # The optimizer is given ``target`` and ``serialize`` attributes that
    # forward to the model's reporter.
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # read json data
    # (valid_json was already parsed above and has not been modified, so it
    # is reused instead of reading the same file a second time)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0)
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode='tts',
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # (each dataset item is a whole minibatch; collate_fn unwraps the
    # single-element list produced by batch_size=1)
    train_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            train_batchset, lambda data: converter([load_tr(data)])),
                          batch_size=1,
                          num_workers=args.num_iter_processes,
                          shuffle=not use_sortagrad,
                          collate_fn=lambda x: x[0])
    }
    valid_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            valid_batchset, lambda data: converter([load_cv(data)])),
                          batch_size=1,
                          shuffle=False,
                          collate_fn=lambda x: x[0],
                          num_workers=args.num_iter_processes)
    }

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            device, args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, 'epoch')
    save_interval = (args.save_interval_epochs, 'epoch')
    report_interval = (args.report_interval_iters, 'iteration')

    # Evaluate the model with the validation dataset at every eval interval
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, device),
                   trigger=eval_interval)

    # Save snapshot at every save interval
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger(
                       'validation/main/loss', trigger=eval_interval))

    # Save attention figure at every eval interval
    if args.num_save_attention > 0:
        # take the first num_save_attention utterances, then order them by
        # descending ['input'][0]['shape'][1]
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # DataParallel hides the real model behind ``.module``
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + '/att_ws',
                                  converter=converter,
                                  transform=load_cv,
                                  device=device,
                                  reverse=True)
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        # one figure per key, plus a combined figure after the loop
        plot_key = ['main/' + key, 'validation/main/' + key]
        trainer.extend(extensions.PlotReport(plot_key,
                                             'epoch',
                                             file_name=key + '.png'),
                       trigger=eval_interval)
        plot_keys += plot_key
    trainer.extend(extensions.PlotReport(plot_keys,
                                         'epoch',
                                         file_name='all_loss.png'),
                   trigger=eval_interval)

    # Write a log of evaluation statistics at every report interval
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        # NOTE(review): train_iter is a dict ({'main': loader}) here, so the
        # extension receives a list holding a dict rather than a loader;
        # confirm ShufflingEnabler accepts that.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #6
0
def train(args):
    """Train a language model with the given args.

    The model class is resolved from ``args.model_module`` and
    ``args.backend`` and must be a subclass of ``LMInterface``. After
    training, when ``args.test_label`` is set, ``rnnlm.model.best`` is loaded
    and its test-set perplexity is logged.

    :param Namespace args: The program arguments
    :raises ImportError: If an apex opt level ("O0".."O3") is requested via
        ``args.train_dtype`` but apex cannot be imported
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class,
                      LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(args.valid_label,
                                                 args.char_list_dict,
                                                 args.dump_hdf5_path)
    train, n_train_tokens, n_train_oovs = load_dataset(args.train_label,
                                                       args.char_list_dict,
                                                       args.dump_hdf5_path)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%' %
                 (n_train_oovs / n_train_tokens * 100))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%' %
                 (n_val_oovs / n_val_tokens * 100))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators
    # The per-iteration batch is scaled by the GPU count (at least 1); the
    # logged size additionally accounts for gradient accumulation.
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            f'batch size is automatically increased ({args.batchsize} -> {batch_size * args.accum_grad})'
        )
    train_iter = ParallelSentenceIterator(train,
                                          batch_size,
                                          max_length=args.maxlen,
                                          sos=eos,
                                          eos=eos,
                                          shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val,
                                        batch_size,
                                        max_length=args.maxlen,
                                        sos=eos,
                                        eos=eos,
                                        repeat=False)
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info('#iterations per epoch = %d' % epoch_iters)
    logging.info('#total iterations = ' + str(args.epoch * epoch_iters))
    # Prepare an RNNLM model
    # Plain float dtypes are applied directly; any other value (including the
    # apex opt levels handled below) builds the model in float32.
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    if args.schedulers is None:
        schedulers = []
    else:
        # args.schedulers is iterated as (key, scheduler name) pairs
        schedulers = [
            dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers
        ]

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # A fresh Reporter is attached to the model, and the optimizer is given
    # ``target`` and ``serialize`` attributes that forward to it.
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
    logging.debug('first device id = %d', gpu_id[0])
    updater = BPTTUpdater(train_iter,
                          model,
                          optimizer,
                          schedulers,
                          gpu_id,
                          gradclip=args.gradclip,
                          use_apex=use_apex,
                          accum_grad=args.accum_grad)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(postprocess=compute_perplexity,
                             trigger=(args.report_interval_iters,
                                      'iteration')))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'perplexity', 'val_perplexity',
        'elapsed_time'
    ]),
                   trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(
        MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        # With sortagrad == -1 the trigger is set to the full epoch count.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch,
                     'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer),
                       trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        torch_load(args.outdir + '/rnnlm.model.best', model)
        # NOTE(review): the test set goes through read_tokens/count_tokens
        # while train/val use load_dataset; confirm this is intentional.
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%' %
                     (n_test_oovs / n_test_tokens * 100))
        test_iter = ParallelSentenceIterator(test,
                                             batch_size,
                                             max_length=args.maxlen,
                                             sos=eos,
                                             eos=eos,
                                             repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        # the 'perplexity' key read below is expected to be filled in by
        # compute_perplexity
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")
Beispiel #7
0
def train(args, teacher_args):
    """Train FCL-taco2 model.

    Builds the student model from ``args`` and a teacher model from
    ``teacher_args``, loads the teacher weights from
    ``teacher_args.amp_checkpoint``, freezes all teacher parameters and runs
    a trainer whose updater and evaluator receive both models.

    The function updates ``args`` in place: ``spk_embed_dim`` and ``spc_dim``
    are always set, ``batch_size`` is multiplied by ``ngpu`` when more than
    one GPU is used, and ``batch_sort_key`` is forced to ``"input"`` when
    sortagrad is enabled.

    Args:
        args (namespace): The program arguments for the model being trained.
        teacher_args (namespace): The arguments describing the teacher model.

    Raises:
        ValueError: If ``teacher_args.amp_checkpoint`` is None.
        NotImplementedError: If ``args.opt`` is not one of ``'adam'``,
            ``'noam'`` or ``'lamb'``.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    # (idim comes from the JSON "output" entry, odim from the "input" entry)
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args, args, teacher_args=teacher_args)

    # build the teacher model and restore its weights
    teacher_model_class = dynamic_import(teacher_args.model_module)
    teacher_model = teacher_model_class(idim, odim, teacher_args, teacher_args)
    if teacher_args.amp_checkpoint is None:
        raise ValueError('please provide the teacher-model-amp-checkpoint')
    logging.info("teacher-model resumed from %s" % teacher_args.amp_checkpoint)
    teacher_checkpoint = torch.load(teacher_args.amp_checkpoint)
    teacher_model.load_state_dict(teacher_checkpoint['model'])

    assert isinstance(model, TTSInterface)
    logging.info(model)
    # keep a handle on the reporter before the model may be wrapped below
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)
    teacher_model = teacher_model.to(device)
    for param in teacher_model.parameters():  # fix teacher model params
        param.requires_grad = False

    # freeze modules, if specified
    if args.freeze_mods:
        # DataParallel prefixes parameter names with "module."
        if hasattr(model, "module"):
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods

        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

        model_params = filter(lambda x: x.requires_grad, model.parameters())
    else:
        model_params = model.parameters()

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == 'lamb':
        # NOTE(review): this branch uses fixed hyper-parameters and
        # model.parameters() rather than model_params / args.lr. It is left
        # unchanged so that optimizer state saved by earlier runs still
        # matches the parameter groups; confirm whether that is intended.
        kw = dict(lr=0.1, betas=(0.9, 0.98), eps=1e-9,
                  weight_decay=1e-6)
        from apex.optimizers import FusedLAMB
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # ``amp`` is not imported in this function; it is presumably a
    # module-level import from apex - TODO confirm
    if args.use_amp:
        opt_level = 'O1'
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    if args.amp_checkpoint is not None:
        logging.info("resumed from %s" % args.amp_checkpoint)
        checkpoint = torch.load(args.amp_checkpoint)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        amp.load_state_dict(checkpoint['amp'])

    # FIXME: TOO DIRTY HACK
    # The optimizer is given ``target`` and ``serialize`` attributes that
    # forward to the model's reporter.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # NOTE(review): raises ZeroDivisionError when args.batch_size is 0, a
    # value the multi-GPU branch above explicitly tolerates.
    num_batches = len(train_json) // args.batch_size

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"

    logging.info(f'batch_sort_key: {args.batch_sort_key}')

    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    # the loader from io_utils_fcl is used here (it is passed pad_eos below)
    from io_utils_fcl import LoadInputsAndTargets

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    converter = CustomConverter(reduction_factor=args.reduction_factor,
                                use_fe_condition=args.use_fe_condition,
                                append_position=args.append_position,
                                )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # (each dataset item is a whole minibatch; collate_fn unwraps the
    # single-element list produced by batch_size=1)
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        teacher_model, model, args.grad_clip, train_iter, optimizer, device, args.accum_grad, args.use_amp, num_batches, args.outdir
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the validation dataset at every eval interval
    trainer.extend(
        CustomEvaluator(teacher_model, model, valid_iter, reporter, device), trigger=eval_interval
    )

    # Save snapshot at every save interval
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        # one figure per key, plus a combined figure after the loop
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics at every report interval
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    # Tensorboard logging is not set up in this function.

    if use_sortagrad:
        # NOTE(review): train_iter is a dict ({"main": loader}) here, so the
        # extension receives a list holding a dict rather than a loader;
        # confirm ShufflingEnabler accepts that.
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #8
0
def train(args):
    """Train a (possibly one-to-many multilingual) speech-translation model.

    The function performs, in order:

    1. Parse ``args.lang_pairs`` (comma-separated, each split on ``'-'``) and
       collect the training / validation json paths. With more than one pair,
       ``args.train_json`` and ``args.valid_json`` are treated as directories
       whose ``*.json`` files are listed in sorted order.
    2. Read the input / output dimensions from the first utterance of every
       validation json (``idim`` must agree across files, ``odim`` is the
       maximum).
    3. Build the model (optionally from pre-trained modules), dump the model
       config to ``<args.outdir>/model.json`` and set up optimizer / apex.
    4. Build minibatches per language pair, merge them so that every merged
       minibatch contains one sub-batch of each pair, and wrap them in data
       loaders.
    5. Register evaluation, plotting, snapshot, decay and reporting
       extensions on a chainer ``Trainer`` and run it.

    Note that ``args`` is mutated: ``args.one_to_many``, ``args.langs_dict``
    and (for ``args.ngpu > 1``) ``args.batch_size`` are overwritten.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of ``'adadelta'``,
            ``'adam'`` or ``'noam'``.
        ImportError: If an apex opt level is requested via
            ``args.train_dtype`` but ``apex`` cannot be imported.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get paths to data
    # Each pair is split on '-': the first field of the first pair is taken
    # as the source language, the last field of every pair as a target.
    lang_pairs = sorted(args.lang_pairs.split(','))
    args.one_to_many = True if len(lang_pairs) > 1 else False
    tgt_langs = sorted([p.split('-')[-1] for p in lang_pairs])
    src_lang = lang_pairs[0].split('-')[0]
    if args.one_to_many:
        # In one-to-many mode train_json / valid_json are directories.
        train_jpaths = [
            os.path.join(args.train_json, fname)
            for fname in sorted(os.listdir(args.train_json))
            if fname.endswith('.json')
        ]
        valid_jpaths = [
            os.path.join(args.valid_json, fname)
            for fname in sorted(os.listdir(args.valid_json))
            if fname.endswith('.json')
        ]

        # Map a language tag '<2xx>' to an integer id for every language
        # (source and targets) appearing in the pairs.
        all_langs = list(
            sorted(set([l for p in lang_pairs for l in p.split('-')])))
        args.langs_dict = {}
        offset = 2  # for <blank> and <unk>
        for i, lang in enumerate(all_langs):
            args.langs_dict[f'<2{lang}>'] = offset + i

        logging.info(f'| train_jpaths: {train_jpaths}')
        logging.info(f'| valid_jpaths: {valid_jpaths}')
        logging.info(f'| lang_pairs  : {lang_pairs}')
        logging.info(f'| langs_dict : {args.langs_dict}')
    else:
        train_jpaths = [args.train_json]
        valid_jpaths = [args.valid_json]
        args.langs_dict = None

    # get input and output dimension info
    # NOTE(review): json files are paired with lang_pairs purely by index in
    # sorted order (lang_pairs[i] below) -- assumes one json file per pair
    # whose sorted file names line up with the sorted pairs; confirm.
    idim = 0
    odim = 0
    for i, jpath in enumerate(valid_jpaths):
        with open(jpath, 'rb') as f:
            valid_json = json.load(f)['utts']
        utts = list(valid_json.keys())
        idim_tmp = int(valid_json[utts[0]]['input'][0]['shape'][-1])
        odim_tmp = int(valid_json[utts[0]]['output'][0]['shape'][-1])
        logging.info('| pair {}: idim={}, odim={}'.format(
            lang_pairs[i], idim_tmp, odim_tmp))
        # every file must report the same input dimension
        if idim == 0:
            idim = idim_tmp
        else:
            assert idim == idim_tmp
        # the output dimension is the largest one over all files
        if odim < odim_tmp:
            odim = odim_tmp
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # Initialize with pre-trained ASR encoder and MT decoder
    if args.enc_init is not None or args.dec_init is not None:
        logging.info('Loading pretrained ASR encoder and/or MT decoder ...')
        model = load_trained_modules(idim, odim, args, interface=STInterface)
        logging.info(f'*** Model *** \n {model}')
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        logging.info(f'*** Model *** \n {model}')
    assert isinstance(model, STInterface)
    logging.info(
        f'| Number of model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}'
    )

    subsampling_factor = model.subsample[0]
    logging.info(f'subsampling_factor={subsampling_factor}')

    # Optionally attach a pre-trained RNN language model to the model.
    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list),
                rnnlm_args.layer,
                rnnlm_args.unit,
                getattr(rnnlm_args, "embed_unit",
                        None),  # for backward compatibility
            ))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    # The dump contains (idim, odim, vars(args)), i.e. it also includes the
    # one_to_many / langs_dict values assigned to args above.
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    # The batch size is scaled by the number of GPUs (only when it is
    # non-zero); this happens after the config dump above, so model.json
    # keeps the unscaled value.
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                'batch size is automatically increased (%d -> %d)' %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    # train_dtype is either a torch dtype name or something else (e.g. an
    # apex opt level handled below), in which case float32 is used here.
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam optimizer is a wrapper; amp is applied to the wrapped
            # optimizer stored in its ``optimizer`` attribute
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # Gives the torch optimizer the ``target`` / ``serialize`` attributes,
    # both backed by the reporter -- presumably so chainer's trainer and
    # snapshot code can treat it like a chainer optimizer (confirm).
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # sortagrad is enabled for -1 (handled below as "all epochs") or any
    # positive number of epochs
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    logging.info(f'use_sortagrad: {use_sortagrad}')
    # read json data
    num_langs = len(tgt_langs)
    train_all_pairs = [None] * num_langs
    valid_all_pairs = [None] * num_langs
    # Split the batch size evenly over the language pairs, since the
    # per-pair batches are concatenated into one minibatch further below.
    batch_size = args.batch_size // num_langs if num_langs > 1 else args.batch_size
    # NOTE(review): the lists above have num_langs slots but are filled by
    # enumerating the json paths -- assumes len(train_jpaths) ==
    # len(valid_jpaths) == num_langs; confirm.
    for i, jpath in enumerate(train_jpaths):
        with open(jpath, 'rb') as f:
            train_json = json.load(f)['utts']
            train_all_pairs[i] = make_batchset(
                train_json,
                batch_size,
                args.maxlen_in,
                args.maxlen_out,
                args.minibatches,
                min_batch_size=1,
                shortest_first=use_sortagrad,
                count=args.batch_count,
                batch_bins=args.batch_bins,
                batch_frames_in=args.batch_frames_in,
                batch_frames_out=args.batch_frames_out,
                batch_frames_inout=args.batch_frames_inout)

    for i, jpath in enumerate(valid_jpaths):
        with open(jpath, 'rb') as f:
            valid_json = json.load(f)['utts']
            valid_all_pairs[i] = make_batchset(
                valid_json,
                batch_size,
                args.maxlen_in,
                args.maxlen_out,
                args.minibatches,
                min_batch_size=1,
                count=args.batch_count,
                batch_bins=args.batch_bins,
                batch_frames_in=args.batch_frames_in,
                batch_frames_out=args.batch_frames_out,
                batch_frames_inout=args.batch_frames_inout)

    # Merge the per-pair batch lists: the i-th merged minibatch is the
    # concatenation of the i-th batch of every pair. Shorter lists are
    # cycled so that the number of merged minibatches equals the length of
    # the longest per-pair list.
    # NOTE: the local names ``train`` / ``valid`` shadow this function's own
    # name inside its body.
    if num_langs > 1:
        cycle_train = [cycle(x) for x in train_all_pairs]
        cycle_valid = [cycle(x) for x in valid_all_pairs]

        num_batches_train = max(len(i) for i in train_all_pairs)
        num_batches_valid = max(len(i) for i in valid_all_pairs)
        train = [None] * num_batches_train
        valid = [None] * num_batches_valid

        for i, s in enumerate(zip(*cycle_train)):
            x = []
            for y in s:
                x.extend(y)
            train[i] = x
            # the cycled iterators never end on their own; stop explicitly
            if i >= num_batches_train - 1:
                break
        for i, s in enumerate(zip(*cycle_valid)):
            x = []
            for y in s:
                x.extend(y)
            valid[i] = x
            if i >= num_batches_valid - 1:
                break
    else:
        train = train_all_pairs[0]
        valid = valid_all_pairs[0]

    # Loaders for training / validation data; they differ only in the
    # 'train' flag passed through preprocess_args.
    load_tr = LoadInputsAndTargets(mode='asr',
                                   load_output=True,
                                   preprocess_conf=args.preprocess_conf,
                                   preprocess_args={'train': True},
                                   langs_dict=args.langs_dict,
                                   src_lang=src_lang)
    load_cv = LoadInputsAndTargets(mode='asr',
                                   load_output=True,
                                   preprocess_conf=args.preprocess_conf,
                                   preprocess_args={'train': False},
                                   langs_dict=args.langs_dict,
                                   src_lang=src_lang)

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype,
                                asr_task=args.asr_weight > 0)

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    # A negative n_iter_processes means "use all CPUs"; a positive value is
    # capped at the CPU count; zero is kept as is.
    n_iter_processes = args.n_iter_processes
    if n_iter_processes < 0:
        n_iter_processes = multiprocessing.cpu_count()
    elif n_iter_processes > 0:
        n_iter_processes = min(n_iter_processes, multiprocessing.cpu_count())
    # NOTE(review): this is the only print() in the function; everything
    # else reports through logging -- looks like a debugging leftover.
    print(f'n_iter_processes = {n_iter_processes}')

    train_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            train, lambda data: converter([load_tr(data)])),
                          batch_size=1,
                          num_workers=n_iter_processes,
                          shuffle=not use_sortagrad,
                          collate_fn=lambda x: x[0],
                          pin_memory=False)
    }
    valid_iter = {
        'main':
        ChainerDataLoader(dataset=TransformDataset(
            valid, lambda data: converter([load_cv(data)])),
                          batch_size=1,
                          shuffle=False,
                          collate_fn=lambda x: x[0],
                          num_workers=n_iter_processes,
                          pin_memory=False)
    }

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    # The stop trigger is a TimeLimitTrigger built from args instead of a
    # plain (args.epochs, 'epoch') tuple.
    # NOTE(review): TimeLimitTrigger is defined elsewhere; its exact stopping
    # rule is not visible here -- confirm.
    time_limit_trigger = TimeLimitTrigger(args)
    trainer = training.Trainer(updater, time_limit_trigger, out=args.outdir)
    logging.info(f'updater: {updater}')
    logging.info(f'trainer: {trainer}')

    if use_sortagrad:
        logging.info(f'use_sortagrad ...')
        # fires after args.sortagrad epochs, or after args.epochs epochs
        # when args.sortagrad == -1
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Evaluate the model with the test dataset for each epoch
    # (or every args.save_interval_iters iterations when that is positive)
    if args.save_interval_iters > 0:
        trainer.extend(CustomEvaluator(model, valid_iter, reporter, device,
                                       args.ngpu),
                       trigger=(args.save_interval_iters, 'iteration'))
    else:
        trainer.extend(
            CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))

    # Save attention weight each epoch
    # NOTE: ``valid_json`` here is whatever the last iteration of the
    # validation loading loop above left behind, i.e. the contents of the
    # last file in valid_jpaths only.
    if args.num_save_attention > 0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_asr',
            'validation/main/loss_asr', 'main/loss_st',
            'validation/main/loss_st'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport([
            'main/acc', 'validation/main/acc', 'main/acc_asr',
            'validation/main/acc_asr'
        ],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/bleu', 'validation/main/bleu'],
                              'epoch',
                              file_name='bleu.png'))

    # Save best models
    # NOTE(review): these triggers are checked every
    # args.report_interval_iters iterations, whereas the evaluator above is
    # scheduled by args.save_interval_iters -- confirm this is intended.
    if args.report_interval_iters > 0:
        trainer.extend(snapshot_object(model, 'model.loss.best'),
                       trigger=MinValueTrigger(
                           'validation/main/loss',
                           trigger=(args.report_interval_iters, 'iteration'),
                           best_value=None))
        trainer.extend(snapshot_object(model, 'model.acc.best'),
                       trigger=MaxValueTrigger(
                           'validation/main/acc',
                           trigger=(args.report_interval_iters, 'iteration'),
                           best_value=None))
    else:
        trainer.extend(snapshot_object(model, 'model.loss.best'),
                       trigger=MinValueTrigger('validation/main/loss',
                                               best_value=None))
        trainer.extend(snapshot_object(model, 'model.acc.best'),
                       trigger=MaxValueTrigger('validation/main/acc',
                                               best_value=None))

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename='snapshot.iter.{.updater.iteration}'),
            trigger=(args.save_interval_iters, 'iteration'))
    else:
        trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    # For adadelta (eps decay) and adam (lr decay): when the monitored
    # validation value gets worse than the best one seen so far, restore the
    # best model snapshot and then apply the decay. Both extensions use the
    # same comparison. Nothing is registered for 'noam'.
    if args.opt == 'adadelta':
        if args.criterion == 'acc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
    elif args.opt == 'adam':
        if args.criterion == 'acc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adam_lr_decay(args.lr_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adam_lr_decay(args.lr_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    # NOTE(review): unlike the "Save best models" section, the reporting
    # extensions below use args.report_interval_iters without checking that
    # it is positive -- confirm a non-positive value cannot reach here.
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_st', 'main/loss_asr',
        'validation/main/loss', 'validation/main/loss_st',
        'validation/main/loss_asr', 'main/acc', 'validation/main/acc'
    ]
    if args.asr_weight > 0:
        report_keys.append('main/acc_asr')
        report_keys.append('validation/main/acc_asr')
    report_keys += ['elapsed_time']
    # Also observe and report the optimizer hyper-parameter that changes
    # during training (eps for adadelta, lr for adam / noam).
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    elif args.opt in ['adam', 'noam']:
        trainer.extend(extensions.observe_value(
            'lr', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["lr"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('lr')
    if args.asr_weight > 0:
        if args.mtlalpha > 0:
            report_keys.append('main/cer_ctc')
            report_keys.append('validation/main/cer_ctc')
        if args.mtlalpha < 1:
            if args.report_cer:
                report_keys.append('validation/main/cer')
            if args.report_wer:
                report_keys.append('validation/main/wer')
    if args.report_bleu:
        report_keys.append('validation/main/bleu')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    # att_reporter is None when attention plotting is disabled above
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #9
0
def train(args):
    """Run ASR training with an additional ``adv`` output branch.

    Builds an ``E2E`` model wrapped in ``Loss``, creates an optimizer with
    separate learning rates for the ``enc`` / ``dec`` sub-modules
    (``args.asr_lr``) and the ``adv`` sub-module (``args.adv_lr``), and runs
    a chainer ``Trainer`` for ``args.epochs`` epochs. The model config,
    snapshots, plots and logs are written to ``args.outdir``.

    ``args.batch_size`` is multiplied by ``args.ngpu`` in place when more
    than one GPU is used.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``'adadelta'`` nor
            ``'adam'``.

    """
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))
    # the adversarial output dimension is read from the second output entry
    odim_adv = None
    if args.adv:
        odim_adv = int(valid_json[utts[0]]['output'][1]['shape'][1])
        logging.info('#output dims adversarial: ' + str(odim_adv))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args, odim_adv=odim_adv)
    model = Loss(e2e, args.mtlalpha)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        e2e.rnnlm = rnnlm

    # write model config
    # exist_ok avoids the check-then-create race of os.path.exists()
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, odim_adv, vars(args)),
                       indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    # First distinguish between learning rates
    # ``model`` is wrapped in DataParallel exactly when args.ngpu > 1 (see
    # above), in which case the real module sits behind ``.module``.
    predictor = (model.module if args.ngpu > 1 else model).predictor
    param_grp = [
        {'params': predictor.enc.parameters(), 'lr': args.asr_lr},
        {'params': predictor.dec.parameters(), 'lr': args.asr_lr},
        {'params': predictor.adv.parameters(), 'lr': args.adv_lr},
    ]
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(param_grp, rho=0.95, eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(param_grp)
    else:
        # fail with a clear message instead of a NameError on ``optimizer``
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(e2e.subsample[0])

    # read json data
    # (valid_json was already loaded above for the dimension info)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    min_batch_size = args.ngpu if args.ngpu > 1 else 1
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=min_batch_size)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=min_batch_size)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(TransformDataset(
            train, converter.transform),
                                                      batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid, converter.transform),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Prepare adversarial training schedule dictionary
    adv_schedule = get_advsched(args.adv, args.epochs)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            adv_schedule=adv_schedule,
                            max_grlalpha=args.grlalpha)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume,
                     trainer,
                     weight_sharing=args.weight_sharing,
                     reinit_adv=args.reinit_adv)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.predictor.calculate_all_attentions
        else:
            att_vis_fn = model.predictor.calculate_all_attentions
        trainer.extend(PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device),
                       trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att', 'main/loss_adv',
            'validation/main/loss_adv'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport([
            'main/acc', 'validation/main/acc', 'main/acc_adv',
            'validation/main/acc_adv'
        ],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare strings by value: ``is not`` on a literal tests identity and
    # raises a SyntaxWarning on Python >= 3.8
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    # When the monitored validation value gets worse than the best one seen
    # so far, restore the best model snapshot and then decay eps.
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    if args.adv:
        report_keys.extend([
            'main/loss_adv', 'main/acc_adv', 'validation/main/loss_adv',
            'validation/main/acc_adv'
        ])
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
Beispiel #10
0
def train(args):
    """Train with the given args.

    Builds the model, optimizer, data iterators and trainer from ``args``,
    registers the evaluation, plotting, snapshot and reporting extensions,
    and finally runs the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not ``adadelta``, ``adam`` or
            ``noam``, or if more than one GPU is requested together with a
            multi-encoder setup.
        ImportError: If an apex opt level (``O0``-``O3``) is requested through
            ``args.train_dtype`` but apex cannot be imported.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    # one input dimension per encoder stream, read from the first utterance
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1])
        for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    # The transducer branch must be part of the same if/elif chain: it
    # requires mtlalpha == 1.0, so a separate ``if`` would always overwrite
    # "transducer" with "ctc".
    if "transducer" in args.model_module:
        assert args.mtlalpha == 1.0
        mtl_mode = "transducer"
        logging.info("Pure transducer mode")
    elif args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    # pre-trained encoder/decoder initialisation is only used for a single
    # encoder; otherwise the model is built from scratch
    if (args.enc_init is not None
            or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim_list[0] if args.num_encs == 1 else idim_list,
                            odim, args)
    assert isinstance(model, ASRInterface)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim,
                 vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup.")

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # any other value (e.g. an apex opt level) keeps float32 parameters
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            # the noam wrapper keeps the actual torch optimizer in .optimizer
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0],
                                    dtype=dtype)
    else:
        converter = CustomConverterMulEnc([i[0] for i in model.subsample_list],
                                          dtype=dtype)

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        # sortagrad == -1 means the trigger epoch equals the total epochs
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Save attention weight each epoch
    # NOTE(review): as written, attention plots are only enabled in
    # transducer mode with rnnt_mode == "rnnt"; confirm this is the intended
    # set of modes.
    if args.num_save_attention > 0 and (mtl_mode == "transducer" and getattr(
            args, "rnnt_mode", False) == "rnnt"):
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    # per-encoder report keys only exist in the multi-encoder setup
    if args.num_encs > 1:
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + [
            "validation/main/loss_ctc{}".format(i + 1)
            for i in range(model.num_encs)
        ]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + [
            "validation/main/cer_ctc{}".format(i + 1)
            for i in range(model.num_encs)
        ]
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ] + ([] if args.num_encs == 1 else report_keys_loss_ctc),
            "epoch",
            file_name="loss.png",
        ))
    trainer.extend(
        extensions.PlotReport(["main/acc", "validation/main/acc"],
                              "epoch",
                              file_name="acc.png"))
    # the CER plot takes the per-encoder CER keys, not the loss keys
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"] +
            ([] if args.num_encs == 1 else report_keys_cer_ctc),
            "epoch",
            file_name="cer.png",
        ))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode not in ["ctc", "transducer"]:
        trainer.extend(
            snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        # model.acc.best is only written when mtl_mode is neither "ctc" nor
        # "transducer" (see "Save best models" above), so only restore it
        # under the same condition.
        if args.criterion == "acc" and mtl_mode not in ["ctc", "transducer"]:
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "main/cer_ctc",
        "validation/main/cer_ctc",
        "elapsed_time",
    ] + ([] if args.num_encs == 1 else report_keys_cer_ctc +
         report_keys_loss_ctc)
    if args.opt == "adadelta":
        # expose the current adadelta eps so it shows up in the report
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #11
0
def train(args):
    """Run training.

    Builds the E2E model wrapped in ``Loss``, the optimizer, the chainer
    iterators and the trainer from ``args``, registers the evaluation,
    plotting, snapshot and reporting extensions for the adversarial
    loss/accuracy, and runs the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``adadelta`` nor
            ``adam``.

    """
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))
    # the second output entry carries the adversarial target dimension
    odim_adv = int(valid_json[utts[0]]['output'][1]['shape'][1])
    logging.info('#output dims adversarial: ' + str(odim_adv))

    # specify model architecture
    # NOTE(review): the model is built with odim_adv as its output dimension
    # while odim is only written to model.json; confirm this matches the
    # E2E signature.
    e2e = E2E(idim, odim_adv, args)
    model = Loss(e2e)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, odim_adv, vars(args)),
                       indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Log total number of parameters
    pytorch_total_params = sum(p.numel() for p in e2e.parameters())
    logging.info("Total parameters in e2e: " + str(pytorch_total_params))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # fail early with a clear message instead of hitting an unbound
        # ``optimizer`` a few lines below
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(TransformDataset(
            train, converter.transform),
                                                      batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid, converter.transform),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer, weight_sharing=args.weight_sharing)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(['main/loss_adv', 'validation/main/loss_adv'],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc_adv', 'validation/main/acc_adv'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss_adv'))
    trainer.extend(
        extensions.snapshot_object(model, 'model.acc.best',
                                   savefun=torch_save),
        trigger=training.triggers.MaxValueTrigger('validation/main/acc_adv'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc_adv', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc_adv', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss_adv', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss_adv', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = ['epoch', 'iteration', 'elapsed_time']
    if args.opt == 'adadelta':
        # expose the current adadelta eps so it shows up in the report
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    report_keys.extend([
        'main/loss_adv', 'main/acc_adv', 'validation/main/loss_adv',
        'validation/main/acc_adv'
    ])
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
Beispiel #12
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    asr_model, mt_model = None, None
    # Initialize encoder with pre-trained ASR encoder
    if args.asr_model:
        asr_model, _ = load_trained_model(args.asr_model)
        assert isinstance(asr_model, ASRInterface)

    # Initialize decoder with pre-trained MT decoder
    if args.mt_model:
        mt_model, _ = load_trained_model(args.mt_model)
        assert isinstance(mt_model, MTInterface)

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    # TODO(hirofumi0810) better to simplify the E2E model interface by only allowing idim, odim, and args
    # the pre-trained ASR and MT model arguments should be removed here and we should implement an additional method
    # to attach these models
    if asr_model is None and mt_model is None:
        model = model_class(idim, odim, args)
    elif mt_model is None:
        model = asr_model
    else:
        model = model_class(idim,
                            odim,
                            args,
                            asr_model=asr_model,
                            mt_model=mt_model)
    assert isinstance(model, ASRInterface)
    subsampling_factor = model.subsample[0]

    # delete pre-trained models
    if args.asr_model:
        del asr_model
    if args.mt_model:
        del mt_model

    if args.slu_model and args.slu_loss:
        model.add_slu(args.slu_model, args.slu_loss, args.slu_tune_weights,
                      args.slu_pooling)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch.load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    scheduler = None

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    elif args.opt == 'adamw':
        from transformers import AdamW, WarmupLinearSchedule
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(TransformDataset(
            valid, load_cv),
                                                       batch_size=1,
                                                       repeat=False,
                                                       shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu, args.grad_noise,
                            args.accum_grad)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if scheduler:
        trainer.extend(scheduler.step(), name='transformer_warmup')

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot())

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #13
0
def train(args):
    """Train with the given args.

    Reads the dimension info from ``args.valid_json``, instantiates the model
    class named by ``args.model_module``, writes the model config to
    ``<args.outdir>/model.json``, sets up optimizer, converter, data loaders
    and the trainer with its extensions, and finally runs the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of ``adadelta``,
            ``adam``, ``noam`` or ``sgd``.
        ImportError: If ``args.train_dtype`` requests an apex opt level
            (``O0``-``O3``) but ``apex`` cannot be imported.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    # (last entry of "shape" of the first utterance's first input/output)
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    logging.info("#output dims: " + str(odim))

    # dynamic model
    # NOTE(review): the model is built with (idim, idim) while odim is what
    # gets logged and written to model.json -- confirm this is intended.
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, idim, args)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        # batch_size == 0 is left untouched (nothing to scale)
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    # any other train_dtype value (e.g. an apex opt level) starts from float32
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay,
                                     lr=args.lr)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model, args.hdim, args.warmup_steps, args.lr)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            # the noam optimizer is a wrapper; amp gets the wrapped optimizer
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # attach the reporter to the optimizer as "target" and route
    # serialization through it
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(tnum=args.tnum,
                                subsampling_factor=model.subsample[0],
                                dtype=dtype)

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # sortagrad is active for -1 or any positive value
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)
    if use_sortagrad:
        # sortagrad == -1 means the enabler triggers at args.epochs
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    # (or every save_interval_iters iterations when that is positive)
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Plot images for the first num_save_attention validation utterances,
    # ordered by descending value of the input "shape"[1] field
    data = sorted(
        list(valid_json.items())[:args.num_save_attention],
        key=lambda x: int(x[1]["input"][0]["shape"][1]),
        reverse=True,
    )
    vis_fn = model.calculate_images
    plot_class = model.images_plot_class
    img_reporter = plot_class(
        vis_fn,
        data,
        args.outdir + "/images",
        converter=converter,
        transform=load_cv,
        device=device,
    )
    trainer.extend(img_reporter,
                   trigger=(args.report_interval_iters, "iteration"))

    # Make a plot for training and validation loss
    trainer.extend(
        extensions.PlotReport(
            ["main/loss", "validation/main/loss"],
            "epoch",
            file_name="loss.png",
        ))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    # Every key must be its own list element: adjacent string literals
    # without a separating comma are silently concatenated into one key.
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/loss_1",
        "validation/main/loss_1",
        "main/loss_2",
        "validation/main/loss_2",
        "main/loss_3",
        "validation/main/loss_3",
        "main/loss_4",
        "validation/main/loss_4",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        # expose the current eps of the first param group in the report
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    # Run the training
    trainer.run()
Beispiel #14
0
def train(args):
    """Train with the given args

    Reads the dimension info from ``args.valid_json``, builds an ``E2E``
    model (optionally attaching an RNNLM), writes the model config to
    ``<args.outdir>/model.json``, sets up optimizer, converter, iterators and
    the trainer with its extensions, and finally runs the training loop.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither ``adadelta``
        nor ``adam``
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # Use the project's torch_load helper to load the weights into rnnlm.
        # ``torch.load(path, rnnlm)`` would pass the module as ``map_location``
        # and never populate it.
        # NOTE(review): assumes torch_load(path, model) as in ESPnet's
        # asr_utils -- confirm against the helper's definition.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # grab the reporter before the model is possibly wrapped in DataParallel
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # fail early with a clear message instead of an unbound-local error
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # attach the reporter to the optimizer as "target" and route
    # serialization through it
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        # a DataParallel-wrapped model exposes the real model as .module
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare strings by value (!=), not by identity (is not)
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    # when the monitored validation value gets worse than the best one,
    # restore the best model and decay eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        # expose the current eps of the first param group in the report
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #15
0
def train(args):
    """Train with the given args

    Builds the train/validation iterators from a ``DataFeeder``, writes the
    program arguments to ``<args.outdir>/model.json``, builds the
    ``Translacotron`` model wrapped in ``TranslacotronLoss``, sets up an Adam
    optimizer and the trainer with its extensions, and finally runs the
    training loop.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output
    src_path = os.path.join(args.base_dir, args.source)
    trg_path = os.path.join(args.base_dir, args.target)
    feeder = DataFeeder(src_path, trg_path, args)

    # Setup a converter
    converter = CustomConverter()

    train_batchset = feeder.make_train_batches()
    valid_batchset = feeder.make_test_batches()

    # batch_size=1 on the iterators: each dataset item comes from a batch set
    train_iter = SerialIterator(TransformDataset(train_batchset,
                                                 converter.transform),
                                batch_size=1,
                                shuffle=True)
    valid_iter = SerialIterator(TransformDataset(valid_batchset,
                                                 converter.transform),
                                batch_size=1,
                                repeat=False,
                                shuffle=False)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    translacotron = Translacotron(args)
    logging.info(translacotron)
    logging.info("Make Translacotron Model Done")
    # check the use of multi-gpu
    if args.ngpu > 1:
        translacotron = torch.nn.DataParallel(translacotron,
                                              device_ids=list(range(
                                                  args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        # NOTE(review): the batch sets above were already built before this
        # scaling -- confirm the increased batch size is actually picked up.
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    translacotron = translacotron.to(device)

    # define loss
    model = TranslacotronLoss(translacotron, args.use_masking,
                              args.bce_pos_weight)
    reporter = model.reporter

    # Setup an optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 eps=args.eps,
                                 weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    # attach the reporter to the optimizer as 'target' and route
    # serialization through it
    setattr(optimizer, 'target', reporter)
    setattr(optimizer, 'serialize', lambda s: reporter.serialize(s))

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # Save best models (the bare network, not the loss wrapper)
    trainer.extend(
        extensions.snapshot_object(translacotron,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Make a plot for training and validation values
    # plot_keys collects every key for the combined loss.png and the report
    plot_keys = [
        'main/loss', 'validation/main/loss', 'main/l1_loss',
        'validation/main/l1_loss', 'main/mse_loss', 'validation/main/mse_loss',
        'main/bce_loss', 'validation/main/bce_loss'
    ]
    trainer.extend(
        extensions.PlotReport(['main/l1_loss', 'validation/main/l1_loss'],
                              'epoch',
                              file_name='l1_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/mse_loss', 'validation/main/mse_loss'],
                              'epoch',
                              file_name='mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/bce_loss', 'validation/main/bce_loss'],
                              'epoch',
                              file_name='bce_loss.png'))
    if args.use_cbhg:
        plot_keys += [
            'main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss',
            'main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'
        ]
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_l1_loss', 'validation/main/cbhg_l1_loss'],
                'epoch',
                file_name='cbhg_l1_loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/cbhg_mse_loss', 'validation/main/cbhg_mse_loss'],
                'epoch',
                file_name='cbhg_mse_loss.png'))
    trainer.extend(
        extensions.PlotReport(plot_keys, 'epoch', file_name='loss.png'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    # report = bookkeeping columns followed by a copy of all plotted keys
    report_keys = plot_keys[:]
    report_keys[0:0] = ['epoch', 'iteration', 'elapsed_time']
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, None))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #16
0
def train(args):
    """Run machine-translation training configured by ``args``.

    Builds the model named by ``args.model_module``, an optimizer, the
    train/validation minibatch iterators and a trainer with evaluation,
    plotting, snapshot and reporting extensions, then runs the trainer.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not ``adadelta``, ``adam``
            or ``noam``.
        ImportError: If an apex opt level is requested via
            ``args.train_dtype`` but apex cannot be imported.

    """
    set_deterministic_pytorch(args)

    # A missing GPU is only reported; nothing is aborted here.
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # Derive the model dimensions from the first validation utterance.
    with open(args.valid_json, "rb") as json_file:
        valid_json = json.load(json_file)["utts"]
    first_utt = valid_json[list(valid_json.keys())[0]]
    idim = int(first_utt["output"][1]["shape"][1])
    odim = int(first_utt["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # Instantiate the requested model architecture.
    model = dynamic_import(args.model_module)(idim, odim, args)
    assert isinstance(model, MTInterface)

    # Store the model configuration in the output directory.
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as conf_file:
        logging.info("writing a model config file to " + model_conf)
        conf_text = json.dumps(
            (idim, odim, vars(args)),
            indent=4,
            ensure_ascii=False,
            sort_keys=True,
        )
        conf_file.write(conf_text.encode("utf_8"))
    arg_values = vars(args)
    for arg_name in sorted(arg_values):
        logging.info("ARGS: " + arg_name + ": " + str(arg_values[arg_name]))

    reporter = model.reporter

    # With several GPUs the batch size is multiplied by the GPU count.
    if args.ngpu > 1 and args.batch_size != 0:
        logging.warning(
            "batch size is automatically increased (%d -> %d)" %
            (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # Move the model to the target device / floating point type.
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    dtype = (getattr(torch, args.train_dtype)
             if args.train_dtype in ("float16", "float32", "float64")
             else torch.float32)
    model = model.to(device=device, dtype=dtype)

    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            n_params,
            n_trainable,
            n_trainable * 100.0 / n_params,
        ))

    # Build the optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(),
            rho=0.95,
            eps=args.eps,
            weight_decay=args.weight_decay,
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # apex.amp is used when train_dtype names an apex opt level
    use_apex = args.train_dtype in ("O0", "O1", "O2", "O3")
    if use_apex:
        try:
            from apex import amp
        except ImportError as err:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise err
        if args.opt == "noam":
            # the noam object wraps the real optimizer in ``.optimizer``
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype)

    # FIXME: TOO DIRTY HACK (attributes are attached to the optimizer object)
    optimizer.target = reporter
    optimizer.serialize = lambda s: reporter.serialize(s)

    # Converter applied to every loaded minibatch
    converter = CustomConverter()

    # Load the utterance dictionaries
    with open(args.train_json, "rb") as json_file:
        train_json = json.load(json_file)["utts"]
    with open(args.valid_json, "rb") as json_file:
        valid_json = json.load(json_file)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0

    # Minibatch lists (variable length); both sets share these options.
    batchset_options = dict(
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    train_batches = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        shortest_first=use_sortagrad,
        **batchset_options
    )
    valid_batches = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        **batchset_options
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
    # The loaders run with batch_size=1 because every dataset item is
    # already a whole minibatch; collate_fn just unwraps the 1-element list
    # instead of applying the default numpy -> tensor collation.
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(
            train_batches, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(
            valid_batches, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Updater and trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(
        updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        sortagrad_epochs = (
            args.sortagrad if args.sortagrad != -1 else args.epochs)
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(sortagrad_epochs, "epoch"),
        )

    # Optionally continue from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Validation: every save_interval_iters iterations when that is set,
    # otherwise on the extension's default trigger.
    evaluator = CustomEvaluator(
        model, {"main": valid_iter}, reporter, device, args.ngpu)
    if args.save_interval_iters > 0:
        trainer.extend(
            evaluator, trigger=(args.save_interval_iters, "iteration"))
    else:
        trainer.extend(evaluator)

    # Attention weights are saved once per epoch when requested
    att_reporter = None
    if args.num_save_attention > 0:
        # NOTE: longest outputs first
        att_samples = list(valid_json.items())[:args.num_save_attention]
        att_samples.sort(
            key=lambda item: int(item[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        # a wrapped model exposes the real network as ``.module``
        att_owner = model.module if hasattr(model, "module") else model
        att_vis_fn = att_owner.calculate_all_attentions
        plot_class = att_owner.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            att_samples,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))

    # One train-vs-validation plot per metric
    for metric in ("loss", "acc", "ppl", "bleu"):
        trainer.extend(
            extensions.PlotReport(
                ["main/" + metric, "validation/main/" + metric],
                "epoch",
                file_name=metric + ".png",
            ))

    # Keep the models with the best validation loss / accuracy
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # Snapshots holding model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # For adadelta / adam: when the monitored validation value gets worse,
    # restore the best model and decay eps (adadelta) or lr (adam).
    if args.opt in ("adadelta", "adam") and args.criterion in ("acc", "loss"):
        monitored = "validation/main/" + args.criterion
        best_model_path = args.outdir + "/model." + args.criterion + ".best"

        if args.criterion == "acc":
            def got_worse(best_value, current_value):
                return best_value > current_value
        else:
            def got_worse(best_value, current_value):
                return best_value < current_value

        trainer.extend(
            restore_snapshot(model, best_model_path, load_fn=torch_load),
            trigger=CompareValueTrigger(monitored, got_worse),
        )
        if args.opt == "adadelta":
            decay_extension = adadelta_eps_decay(args.eps_decay)
        else:
            decay_extension = adam_lr_decay(args.lr_decay)
        trainer.extend(
            decay_extension,
            trigger=CompareValueTrigger(monitored, got_worse),
        )

    # Periodic log of the collected statistics
    report_trigger = (args.report_interval_iters, "iteration")
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = ["epoch", "iteration"]
    for metric in ("loss", "acc", "ppl"):
        report_keys += ["main/" + metric, "validation/main/" + metric]
    report_keys.append("elapsed_time")

    # Also report the optimizer hyper-parameter that changes over time
    if args.opt == "adadelta":
        observed = "eps"
    elif args.opt in ["adam", "noam"]:
        observed = "lr"
    else:
        observed = None
    if observed is not None:
        trainer.extend(
            extensions.observe_value(
                observed,
                lambda t: t.updater.get_optimizer("main").
                param_groups[0][observed],
            ),
            trigger=report_trigger,
        )
        report_keys.append(observed)
    if args.report_bleu:
        report_keys += ["main/bleu", "validation/main/bleu"]
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=report_trigger,
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        tb_writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(tb_writer, att_reporter),
            trigger=report_trigger,
        )

    # Start training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #17
0
def train(args):
    """Train a language model with the given args.

    The model class is resolved from ``args.model_module`` and
    ``args.backend`` and must implement ``LMInterface``. After training,
    the best model is optionally evaluated on ``args.test_label``.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither ``sgd`` nor
        ``adam``
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"
    # display torch version
    logging.info('torch version = ' + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get special label ids
    unk = args.char_list_dict['<unk>']
    eos = args.char_list_dict['<eos>']
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(args.valid_label, args.char_list_dict, args.dump_hdf5_path)
    train, n_train_tokens, n_train_oovs = load_dataset(args.train_label, args.char_list_dict, args.dump_hdf5_path)
    logging.info('#vocab = ' + str(args.n_vocab))
    logging.info('#sentences in the training data = ' + str(len(train)))
    logging.info('#tokens in the training data = ' + str(n_train_tokens))
    logging.info('oov rate in the training data = %.2f %%' % (n_train_oovs / n_train_tokens * 100))
    logging.info('#sentences in the validation data = ' + str(len(val)))
    logging.info('#tokens in the validation data = ' + str(n_val_tokens))
    logging.info('oov rate in the validation data = %.2f %%' % (n_val_oovs / n_val_tokens * 100))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators; the batch size is scaled by the number
    # of GPUs (at least 1)
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size > args.batchsize:
        logging.info(f'batch size is automatically increased ({args.batchsize} -> {batch_size})')
    # NOTE: <eos> is passed as both the sos and the eos label
    train_iter = ParallelSentenceIterator(train, batch_size,
                                          max_length=args.maxlen, sos=eos, eos=eos, shuffle=not use_sortagrad)
    val_iter = ParallelSentenceIterator(val, batch_size,
                                        max_length=args.maxlen, sos=eos, eos=eos, repeat=False)
    logging.info('#iterations per epoch = ' + str(len(train_iter.batch_indices)))
    logging.info('#total iterations = ' + str(args.epoch * len(train_iter.batch_indices)))
    # Prepare an RNNLM model
    model = model_class(args.n_vocab, args)
    reporter = Reporter()
    if args.ngpu > 0:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))).cuda()
        gpu_id = 0
    else:
        # -1 is handed to the updater/evaluator when no GPU is used
        gpu_id = -1
    setattr(model, "reporter", reporter)

    # Save model conf to json
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))

    # Set up an optimizer
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # Fail with a clear message instead of hitting an unbound
        # ``optimizer`` a few lines below.
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(train_iter, model, optimizer, gpu_id, gradclip=args.gradclip)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(extensions.LogReport(postprocess=compute_perplexity,
                                        trigger=(args.report_interval_iters, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'perplexity', 'val_perplexity', 'elapsed_time']
    ), trigger=(args.report_interval_iters, 'iteration'))
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename='snapshot.ep.{.updater.epoch}'))
    trainer.extend(snapshot_object(model, 'rnnlm.model.{.updater.epoch}'))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(MakeSymlinkToBestModel('validation/main/loss', 'rnnlm.model'))

    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, 'epoch'))
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer), trigger=(args.report_interval_iters, 'iteration'))

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info('test the best model')
        # reload the weights stored under rnnlm.model.best before testing
        torch_load(args.outdir + '/rnnlm.model.best', model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info('#sentences in the test data = ' + str(len(test)))
        logging.info('#tokens in the test data = ' + str(n_test_tokens))
        logging.info('oov rate in the test data = %.2f %%' % (n_test_oovs / n_test_tokens * 100))
        test_iter = ParallelSentenceIterator(test, batch_size,
                                             max_length=args.maxlen, sos=eos, eos=eos, repeat=False)
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        # compute_perplexity fills in result['perplexity'], which is logged below
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")