Beispiel #1
0
    def gpu_decode(feat_list, gpu):
        """Decode the given feature files on one GPU and write wav files.

        Names that are not defined here (``args``, ``config``,
        ``wav_transform``, ``feat_transform``) are resolved from the
        surrounding scope.

        Args:
            feat_list: Feature files handed to ``decode_generator``.
            gpu: Index of the CUDA device to make the default device.
        """
        # pin the default device; generation does not need gradients
        torch.cuda.set_device(gpu)
        torch.set_grad_enabled(False)

        # the model gets an upsampling factor of 0 when the layer is unused
        upsampling_factor = (
            config.upsampling_factor if config.use_upsampling_layer else 0)

        # build the network and restore its weights (loaded onto CPU first)
        model = WaveNet(
            n_quantize=config.n_quantize,
            n_aux=config.n_aux,
            n_resch=config.n_resch,
            n_skipch=config.n_skipch,
            dilation_depth=config.dilation_depth,
            dilation_repeat=config.dilation_repeat,
            kernel_size=config.kernel_size,
            upsampling_factor=upsampling_factor)
        state = torch.load(
            args.checkpoint, map_location=lambda storage, loc: storage)
        model.load_state_dict(state["model"])
        model.eval()
        model.cuda()

        # source of the inputs to decode
        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            feature_type=config.feature_type,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            upsampling_factor=config.upsampling_factor,
            use_upsampling_layer=config.use_upsampling_layer,
            use_speaker_code=config.use_speaker_code)

        def write_wav(utt_id, quantized):
            """Convert mu-law samples to a waveform and save it as <utt_id>.wav."""
            waveform = decode_mu_law(quantized, config.n_quantize)
            sf.write(args.outdir + "/" + utt_id + ".wav", waveform,
                     args.fs, "PCM_16")
            logging.info("wrote %s.wav in %s." % (utt_id, args.outdir))

        # single-utterance decoding
        if args.batch_size <= 1:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" %
                             (feat_id, n_samples))
                write_wav(
                    feat_id,
                    model.fast_generate(x, h, n_samples, args.intervals))
            return

        # batch decoding
        for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
            logging.info("decoding start")
            samples_list = model.batch_fast_generate(
                batch_x, batch_h, n_samples_list, args.intervals)
            for feat_id, samples in zip(feat_ids, samples_list):
                write_wav(feat_id, samples)
def test_assert_different_length_batch_generation():
    """Check batched generation against one-by-one generation.

    Every utterance in the batch is generated for a different number of
    samples; ``batch_fast_generate`` must return exactly the same arrays as
    calling ``fast_generate`` on each utterance separately.
    """
    # prepare batch
    n_batch = 4
    max_length = 32
    x = np.random.randint(0, 256, size=(n_batch, 1))
    h = np.random.randn(n_batch, 28, max_length)
    length_list = sorted(
        np.random.randint(max_length // 2, max_length - 1, n_batch))

    with torch.no_grad():
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # reference: generate every utterance on its own
        singles = [
            net.fast_generate(
                torch.from_numpy(np.expand_dims(x_i, 0)).long(),
                torch.from_numpy(np.expand_dims(h_i, 0)).float(),
                n_samples, 1, "argmax")
            for x_i, h_i, n_samples in zip(x, h, length_list)
        ]

        # all utterances at once
        batched = net.batch_fast_generate(
            torch.from_numpy(x).long(),
            torch.from_numpy(h).float(),
            length_list, 1, "argmax")

        # both ways must agree utterance by utterance
        for single, from_batch in zip(singles, batched):
            np.testing.assert_array_equal(single, from_batch)
def main():
    """RUN TRAINING.

    Parse the command-line options, build the WaveNet model, the optimizer
    and the minibatch generator, optionally resume from a checkpoint, and
    run the training loop on GPU while periodically logging progress and
    saving checkpoints.

    Exits with status 1 when ``--waveforms`` is neither a directory nor a
    file, when the numbers of waveform and feature files differ, or when no
    GPU is available.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    parser.add_argument("--feature_type", default="world", choices=["world", "melspc"],
                        type=str, help="feature type")
    # network structure setting
    parser.add_argument("--n_quantize", default=256,
                        type=int, help="number of quantization")
    parser.add_argument("--n_aux", default=28,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--n_resch", default=512,
                        type=int, help="number of channels of residual output")
    parser.add_argument("--n_skipch", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--dilation_depth", default=10,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=1,
                        type=int, help="number of repeating of dilation")
    parser.add_argument("--kernel_size", default=2,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=80,
                        type=int, help="upsampling factor of aux features")
    parser.add_argument("--use_upsampling_layer", default=True,
                        type=strtobool, help="flag to use upsampling layer")
    parser.add_argument("--use_speaker_code", default=False,
                        type=strtobool, help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0.0,
                        type=float, help="weight decay coefficient")
    parser.add_argument("--batch_length", default=20000,
                        type=int, help="batch length (if set 0, utterance batch will be used)")
    parser.add_argument("--batch_size", default=1,
                        type=int, help="batch size (if use utterance batch, batch_size will be 1.")
    parser.add_argument("--iters", default=200000,
                        type=int, help="number of iterations")
    # other setting
    parser.add_argument("--checkpoint_interval", default=10000,
                        type=int, help="how frequent saving model")
    parser.add_argument("--intervals", default=100,
                        type=int, help="log interval")
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None, nargs="?",
                        type=str, help="model path to restart training")
    parser.add_argument("--n_gpus", default=1,
                        type=int, help="number of gpus")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    # set log level (1: INFO, larger than 1: DEBUG, otherwise WARNING)
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S')
    if log_level == logging.WARNING:
        logging.warning("logging is disabled.")

    # show arguments
    for key, value in vars(args).items():
        logging.info("%s = %s" % (key, str(value)))

    # make experimental directory (exist_ok avoids a check-then-create race)
    os.makedirs(args.expdir, exist_ok=True)

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # fix slow computation of dilated conv
    # https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
    torch.backends.cudnn.benchmark = True

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network (an upsampling factor of 0 is passed when the layer is unused)
    if args.use_upsampling_layer:
        upsampling_factor = args.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(
        n_quantize=args.n_quantize,
        n_aux=args.n_aux,
        n_resch=args.n_resch,
        n_skipch=args.n_skipch,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        upsampling_factor=upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    if args.n_gpus > 1:
        device_ids = range(args.n_gpus)
        model = torch.nn.DataParallel(model, device_ids)
        # mirror the attribute on the wrapper so that the code below can
        # read model.receptive_field in both the single and multi gpu case
        model.receptive_field = model.module.receptive_field
        if args.n_gpus > args.batch_size:
            logging.warning("batch size is less than number of gpus.")

    # define optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/" + args.feature_type + "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/" + args.feature_type + "/scale")
    wav_transform = transforms.Compose([
        lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([
        lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    # explicit check instead of assert so that it also runs under "python -O"
    if len(wav_list) != len(feat_list):
        logging.error(
            "number of wav files (%d) and feat files (%d) should be the same."
            % (len(wav_list), len(feat_list)))
        sys.exit(1)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(
        wav_list, feat_list,
        receptive_field=model.receptive_field,
        batch_length=args.batch_length,
        batch_size=args.batch_size,
        feature_type=args.feature_type,
        wav_transform=wav_transform,
        feat_transform=feat_transform,
        shuffle=True,
        upsampling_factor=args.upsampling_factor,
        use_upsampling_layer=args.use_upsampling_layer,
        use_speaker_code=args.use_speaker_code)

    # charge minibatch in queue
    while not generator.queue.full():
        time.sleep(0.1)

    # resume model and optimizer
    if args.resume is not None and len(args.resume) != 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        iterations = checkpoint["iterations"]
        if args.n_gpus > 1:
            model.module.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # check gpu and then send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        # optimizer state restored from a checkpoint was loaded onto CPU
        for state in optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)
        # the first receptive_field positions are excluded from the loss
        batch_loss = criterion(
            batch_output[:, model.receptive_field:].contiguous().view(-1, args.n_quantize),
            batch_t[:, model.receptive_field:].contiguous().view(-1))
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" % (
            batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info("(iter:%d) average loss = %.6f (%.3f sec / batch)" % (
                i + 1, loss / args.intervals, total / args.intervals))
            logging.info("estimated required time = "
                         "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}"
                         .format(relativedelta(
                             seconds=int((args.iters - (i + 1)) * (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model (the unwrapped module when using DataParallel)
        if (i + 1) % args.checkpoint_interval == 0:
            if args.n_gpus > 1:
                save_checkpoint(args.expdir, model.module, optimizer, i + 1)
            else:
                save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model (only the model weights are stored)
    if args.n_gpus > 1:
        torch.save({"model": model.module.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    else:
        torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
def main():
    """RUN DECODING.

    Pick a random excerpt from one of the seed waveforms, feed it to a
    WaveNet restored from ``--checkpoint`` and write the seed excerpt and
    the generated samples as 16-bit PCM wav files into ``--outdir``.

    Exits with status 1 when ``--seed_waveforms`` is neither a directory nor
    a file, or when no seed waveform is left to choose from.

    Raises:
        FileNotFoundError: If the model config file does not exist.
        ValueError: If the chosen seed waveform does not have the expected
            sampling rate or is shorter than the required seed length.
    """
    parser = argparse.ArgumentParser()
    # decode setting
    parser.add_argument("--seed_waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--checkpoint",
                        required=True,
                        type=str,
                        help="model file")
    parser.add_argument("--outdir",
                        required=True,
                        type=str,
                        help="directory to save generated samples")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configure file")
    parser.add_argument("--fs", default=16000, type=int, help="sampling rate")
    parser.add_argument("--batch_size",
                        default=1,
                        type=int,
                        help="number of batches to decode of batch_length")
    parser.add_argument("--speaker_code", type=int, help="speaker code")
    # other setting
    parser.add_argument("--intervals",
                        default=1000,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # set log level (1: INFO, larger than 1: DEBUG, otherwise WARNING)
    # NOTE: testing "> 0" first would make the DEBUG branch unreachable
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose > 1:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S')
    if log_level == logging.WARNING:
        logging.warning("logging is disabled.")

    # show arguments
    for key, value in vars(args).items():
        logging.info("%s = %s" % (key, str(value)))

    # check arguments (the config defaults to model.conf next to the checkpoint)
    if args.config is None:
        args.config = os.path.dirname(args.checkpoint) + "/model.conf"
    if not os.path.exists(args.config):
        raise FileNotFoundError("config file is missing (%s)." % (args.config))

    # make output directory (exist_ok avoids a check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)

    if os.path.isdir(args.seed_waveforms):
        filenames = sorted(
            find_files(args.seed_waveforms, "*.wav", use_dir_name=False))
        wav_list = [
            args.seed_waveforms + "/" + filename for filename in filenames
        ]
    elif os.path.isfile(args.seed_waveforms):
        wav_list = read_txt(args.seed_waveforms)
    else:
        logging.error("--seed_waveforms should be directory or list.")
        sys.exit(1)

    # if speaker code is specified then filter out all other speakers if files conform to wave set pattern
    if args.speaker_code is not None:
        try:
            wav_list = [
                path for path in wav_list
                if parse_wave_file_name(path)[1] == args.speaker_code
            ]
        except ValueError:
            # files not in wave set format, then original list will be kept
            pass

    # without this check the random choice below fails with an unrelated
    # "low >= high" error from numpy
    if not wav_list:
        logging.error("no seed waveform found in %s." % args.seed_waveforms)
        sys.exit(1)

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # fix slow computation of dilated conv
    # https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
    torch.backends.cudnn.benchmark = True

    # load config
    # NOTE(review): torch.load unpickles the file; only load trusted configs
    config = torch.load(args.config)

    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, config.n_quantize)])

    # do not track gradient
    torch.set_grad_enabled(False)

    # define model and load parameters
    model = WaveNet(n_quantize=config.n_quantize,
                    n_aux=1,
                    n_resch=config.n_resch,
                    n_skipch=config.n_skipch,
                    dilation_depth=config.dilation_depth,
                    dilation_repeat=config.dilation_repeat,
                    kernel_size=config.kernel_size)

    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    if torch.cuda.is_available():
        model.cuda()

    # extract train iterations
    train_iterations = checkpoint["iterations"]

    # choose seed sample
    wav_id = np.random.randint(0, high=len(wav_list))
    wav_file = wav_list[wav_id]
    seed_x, _rate = sf.read(wav_file, dtype=np.float32)
    if args.fs != _rate:
        raise ValueError(
            f"expected sample rate is {args.fs}, {wav_file} is {_rate}")
    tot_sample_length = config.batch_length + model.receptive_field
    max_sample_id = seed_x.shape[0] - tot_sample_length
    if max_sample_id < 0:
        raise ValueError(
            f"{wav_file} has {seed_x.shape[0]} samples, "
            f"at least {tot_sample_length} are required")
    # randint needs low < high; a waveform of exactly the required length
    # leaves offset 0 as the only choice
    if max_sample_id > 0:
        sample_id = np.random.randint(0, max_sample_id)
    else:
        sample_id = 0
    seed_sample = seed_x[sample_id:sample_id + tot_sample_length]

    # take code from filename if none
    speaker_code = parse_wave_file_name(
        wav_file)[1] if args.speaker_code is None else args.speaker_code

    # log decoding info
    output_fn_info = f"I{train_iterations}-R{args.seed}-S{speaker_code}-W{Path(wav_file).stem}"
    length_sec = args.batch_size * config.batch_length / args.fs
    logging.info(
        f"generating {args.batch_size} batches of {config.batch_length} samples = {length_sec} seconds"
        f" into {output_fn_info} set")

    # write seed sample
    sf.write(args.outdir + f"/seed_sample-{output_fn_info}.wav", seed_sample,
             args.fs, "PCM_16")
    logging.info("wrote seed_sample.wav in %s." % args.outdir)

    # decode
    new_samples = args.batch_size * config.batch_length
    x, h = decode_from_wav(seed_sample,
                           new_samples,
                           wav_transform=wav_transform,
                           speaker_code=speaker_code)

    def save_samples(samples, file_name):
        """Convert mu-law samples to a waveform and write it to file_name."""
        wav_data = decode_mu_law(samples, config.n_quantize)
        sf.write(file_name, wav_data, args.fs, "PCM_16")

    def progress_callback(samples, no_samples, elapsed):
        """Dump intermediate results: all samples and the last no_samples."""
        save_samples(samples,
                     args.outdir + "/" + f"decoded.t.all-{output_fn_info}.wav")
        save_samples(samples[-no_samples:],
                     args.outdir + "/" + f"decoded.t.new-{output_fn_info}.wav")

    logging.info("decoding (length = %d)" % h.shape[2])
    samples = model.fast_generate(x,
                                  h,
                                  new_samples,
                                  args.intervals,
                                  callback=progress_callback)
    logging.info(f"decoded {len(seed_sample)}")
    save_samples(samples, args.outdir + "/" + f"decoded-{output_fn_info}.wav")
    logging.info("wrote decoded.wav in %s." % args.outdir)
def _assert_generation_consistency(net, x, h, length):
    """Assert that all generation methods of ``net`` give the same result.

    For every utterance ``generate`` and ``fast_generate`` are compared, and
    ``batch_fast_generate`` on the whole batch is compared with the stacked
    ``fast_generate`` results. All calls use "argmax" mode.

    Args:
        net: Model providing generate / fast_generate / batch_fast_generate.
        x: Integer numpy array of initial samples, one row per utterance.
        h: Float numpy array of auxiliary features, one entry per utterance.
        length: Number of samples to generate for every utterance.
    """
    # sample-by-sample generation
    gen1_list = []
    gen2_list = []
    for x_, h_ in zip(x, h):
        batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
        batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
        gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
        gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
        np.testing.assert_array_equal(gen1, gen2)
        gen1_list += [gen1]
        gen2_list += [gen2]
    gen1 = np.stack(gen1_list)
    gen2 = np.stack(gen2_list)
    np.testing.assert_array_equal(gen1, gen2)

    # batch generation
    batch_x = torch.from_numpy(x).long()
    batch_h = torch.from_numpy(h).float()
    gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * len(x),
                                        1, "argmax")
    gen3 = np.stack(gen3_list)
    np.testing.assert_array_equal(gen3, gen2)


def test_assert_fast_generation():
    """Check that the generation methods agree for four model variants.

    The variants are kernel size 2 and 3, each without and with the
    upsampling layer.
    """
    # get batch
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    with torch.no_grad():
        # ---------------------------------------------------------------
        # define models without upsampling and with kernel size = 2 and 3
        # ---------------------------------------------------------------
        for kernel_size in (2, 3):
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size)
            net.apply(initialize)
            net.eval()
            _assert_generation_consistency(net, x, h, length)

        # get batch (aux features are upsampling_factor times shorter)
        upsampling_factor = 10
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 3)
        length = h.shape[-1] * upsampling_factor - 1

        # ------------------------------------------------------------
        # define models with upsampling and with kernel size = 2 and 3
        # ------------------------------------------------------------
        for kernel_size in (2, 3):
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size,
                          upsampling_factor)
            net.apply(initialize)
            net.eval()
            _assert_generation_consistency(net, x, h, length)
def test_generate():
    """Smoke-test the generation methods in "sampling" mode.

    The returned samples are not inspected; the test only requires that
    the calls complete.
    """
    n_batch = 2
    x = np.random.randint(0, 256, size=(n_batch, 1))
    h = np.random.randn(n_batch, 28, 10)
    n_samples = h.shape[-1] - 1

    with torch.no_grad():
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # one utterance at a time, kept as a batch of size one
        for idx in range(n_batch):
            single_x = torch.from_numpy(x[idx:idx + 1]).long()
            single_h = torch.from_numpy(h[idx:idx + 1]).float()
            net.generate(single_x, single_h, n_samples, 1, "sampling")
            net.fast_generate(single_x, single_h, n_samples, 1, "sampling")

        # the whole batch at once
        net.batch_fast_generate(
            torch.from_numpy(x).long(),
            torch.from_numpy(h).float(),
            [n_samples] * n_batch, 1, "sampling")
def test_forward():
    """Check the forward output size for four model variants.

    The variants are kernel size 2 and 3, each without and with the
    upsampling layer. For every variant the first output of the batch must
    have one row per input sample and 256 columns.
    """
    # get batch
    generator = sine_generator(100)
    batch = next(generator)
    batch_input = batch.view(1, -1)

    # define models without upsampling with kernel size = 2 and 3
    batch_aux = torch.rand(1, 28, batch_input.size(1)).float()
    for kernel_size in (2, 3):
        net = WaveNet(256, 28, 32, 128, 10, 1, kernel_size)
        net.apply(initialize)
        net.eval()
        y = net(batch_input, batch_aux)[0]
        assert y.size(0) == batch_input.size(1)
        assert y.size(1) == 256

    # define models with upsampling (factor 10) and kernel size = 2 and 3;
    # the aux features are 10 times shorter than the input
    batch_aux = torch.rand(1, 28, batch_input.size(1) // 10).float()
    for kernel_size in (2, 3):
        net = WaveNet(256, 28, 32, 128, 10, 1, kernel_size, 10)
        net.apply(initialize)
        net.eval()
        y = net(batch_input, batch_aux)[0]
        assert y.size(0) == batch_input.size(1)
        assert y.size(1) == 256