예제 #1
0
def test_generate():
    """Smoke-test the three generation entry points in "sampling" mode.

    Nothing is asserted about the generated samples; the test only checks that
    ``generate``, ``fast_generate`` (one utterance at a time) and
    ``batch_fast_generate`` (whole batch) run to completion on random inputs.
    """
    n_utts = 2
    # One random initial sample per utterance, drawn from [0, 256).
    init_samples = np.random.randint(0, 256, size=(n_utts, 1))
    # Random auxiliary features; axis 1 (28) matches the model's second
    # constructor argument, the last axis is the time axis used for n_steps.
    feats = np.random.randn(n_utts, 28, 10)
    n_steps = feats.shape[-1] - 1

    with torch.no_grad():
        model = WaveNet(256, 28, 4, 4, 10, 3, 2)
        model.apply(initialize)
        model.eval()

        # Single-utterance paths: add a leading batch axis of size 1.
        for sample, feat in zip(init_samples, feats):
            one_x = torch.from_numpy(sample[np.newaxis]).long()
            one_h = torch.from_numpy(feat[np.newaxis]).float()
            model.generate(one_x, one_h, n_steps, 1, "sampling")
            model.fast_generate(one_x, one_h, n_steps, 1, "sampling")

        # Batched path: the whole batch at once, one target length per utterance.
        all_x = torch.from_numpy(init_samples).long()
        all_h = torch.from_numpy(feats).float()
        model.batch_fast_generate(all_x, all_h, [n_steps] * n_utts, 1, "sampling")
def _assert_generation_paths_agree(net, x, h, length):
    """Initialize ``net`` and check that its three argmax generation paths agree.

    ``generate`` and ``fast_generate`` must return identical samples for every
    utterance, and ``batch_fast_generate`` on the whole batch must match the
    stacked per-utterance ``fast_generate`` results.

    Args:
        net: model exposing ``generate``, ``fast_generate`` and
            ``batch_fast_generate``.
        x: integer array with one row of initial samples per utterance.
        h: float array of auxiliary features, one entry per utterance.
        length: number of samples to generate for each utterance.
    """
    net.apply(initialize)
    net.eval()

    # sample-by-sample generation (leading batch axis of size 1)
    fast_list = []
    for x_, h_ in zip(x, h):
        batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
        batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
        gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
        gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
        np.testing.assert_array_equal(gen1, gen2)
        fast_list.append(gen2)
    gen2 = np.stack(fast_list)

    # batch generation: one target length per utterance in the batch
    batch_x = torch.from_numpy(x).long()
    batch_h = torch.from_numpy(h).float()
    gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * len(x),
                                        1, "argmax")
    gen3 = np.stack(gen3_list)
    np.testing.assert_array_equal(gen3, gen2)


def test_assert_fast_generation():
    """Check that naive, fast and batched argmax generation give identical output.

    Four model configurations are covered: kernel size 2 and 3, each both
    without and with upsampling of the auxiliary features.
    """
    batch = 2
    upsampling_factor = 10

    # inputs for the models without upsampling
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    # inputs for the models with upsampling: h has fewer time steps, and the
    # target length is scaled by the upsampling factor
    x_up = np.random.randint(0, 256, size=(batch, 1))
    h_up = np.random.randn(batch, 28, 3)
    length_up = h_up.shape[-1] * upsampling_factor - 1

    with torch.no_grad():
        # The 7th positional WaveNet argument is the kernel size; the
        # optional 8th is the upsampling factor. Looping over both kernel
        # sizes guarantees each configuration really uses the size it claims.
        for kernel_size in (2, 3):
            # model without upsampling
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size)
            _assert_generation_paths_agree(net, x, h, length)

        for kernel_size in (2, 3):
            # model with upsampling
            net = WaveNet(256, 28, 4, 4, 10, 3, kernel_size, upsampling_factor)
            _assert_generation_paths_agree(net, x_up, h_up, length_up)