with result.epoch("main", train=True) as train_result:
            # Number of iterations: the larger of the two batch-list sizes by
            # default, the smaller one when args.use_smaller_data_size is set.
            n_iter_fun = min if args.use_smaller_data_size else max
            for i in range(n_iter_fun(n_supervised, n_unsupervised)):
                # Re-shuffle and repeat the smaller list: each time the index
                # wraps around a list's length (including i == 0), that list
                # is re-permuted before being indexed again.
                if i % n_supervised == 0:
                    supervised_train_batch = np.random.permutation(
                        supervised_train_batch)
                if i % n_unsupervised == 0:
                    unsupervised_train_batch = shuffle_pair(
                        unsupervised_train_batch)
                sbatch = supervised_train_batch[i % n_supervised]
                ubatch = unsupervised_train_batch[i % n_unsupervised]

                # Supervised forward: interpolate the CTC and attention
                # losses with weight args.mtlalpha.
                with open_kaldi_feat(sbatch, train_reader) as sx:
                    loss_ctc, loss_att, acc_supervised = model.predictor(
                        sx, supervised=True)
                loss_supervised = args.mtlalpha * loss_ctc + (
                    1.0 - args.mtlalpha) * loss_att

                # Unsupervised forward: interpolate the hidden-state and text
                # losses with weight args.speech_text_ratio.
                with open_kaldi_feat(ubatch, unsupervised_reader) as ux:
                    loss_text, loss_hidden, acc_text = model.predictor(
                        ux, supervised=False, discriminator=discriminator)

                loss_unsupervised = args.speech_text_ratio * loss_hidden + (
                    1.0 - args.speech_text_ratio) * loss_text

                if discriminator:
                    # Negated hidden loss: minimising it (presumably via a
                    # d_optimizer.step() after this point) pushes loss_hidden
                    # up for the discriminator's parameters.
                    loss_discriminator = -loss_hidden
                    d_optimizer.zero_grad()
                    # Keep the autograd graph alive, because loss_hidden is
                    # also part of loss_unsupervised, which is presumably
                    # backpropagated afterwards. `retain_variables` is the
                    # deprecated alias of `retain_graph`, and current PyTorch
                    # rejects it with a TypeError.
                    loss_discriminator.backward(retain_graph=True)
Пример #2
0
    valid_reader = lazy_io.read_dict_scp(args.valid_feat)

    best = dict(loss=float("inf"), acc=-float("inf"))
    opt_key = "eps" if args.opt == "adadelta" else "lr"
    def get_opt_param():
        """Return the tuned hyper-parameter of the optimizer's first group.

        ``opt_key`` names the entry that is read ("eps" for Adadelta, "lr"
        otherwise). The lookup runs on every call, so the result reflects
        any in-place change to ``optimizer.param_groups``.
        """
        first_group = optimizer.param_groups[0]
        return first_group[opt_key]

    # training loop
    result = GlobalResult(args.epochs, args.outdir)
    for epoch in range(args.epochs):
        model.train()
        with result.epoch("main", train=True) as train_result:
            for batch in np.random.permutation(train_batch):
                with open_kaldi_feat(batch, train_reader) as x:
                    # forward
                    loss_ctc, loss_att, acc = model.predictor(x)
                    loss = args.mtlalpha * loss_ctc + (1 - args.mtlalpha) * loss_att
                    # backward
                    optimizer.zero_grad()  # Clear the parameter gradients
                    loss.backward()  # Backprop
                    loss.detach()  # Truncate the graph
                    # compute the gradient norm to check if it is normal or not
                    grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
                    logging.info('grad norm={}'.format(grad_norm))
                    if math.isnan(grad_norm):
                        logging.warning('grad norm is nan. Do not update model.')
                    else:
                        optimizer.step()
                    # print/plot stats to args.outdir/results
                    train_result.report({
                        "loss": loss,