class TestGeneration(tf.test.TestCase):
    '''Sanity and consistency checks for WaveNet sample generation.

    Exercises the naive generator (`predict_proba`), the incremental
    generator (`predict_proba_incremental`) and the agreement between
    the two on the same input.
    '''

    # Number of quantization channels of the model built in setUp; the
    # generators are expected to return one value per channel.
    QUANTIZATION_CHANNELS = 128

    def setUp(self):
        '''Build the small WaveNet shared by every test in this case.'''
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256],
            filter_width=2,
            residual_channels=16,
            dilation_channels=16,
            quantization_channels=self.QUANTIZATION_CHANNELS,
            skip_channels=32)

    def _assertValidProba(self, proba):
        '''Check that `proba` looks like a distribution over the channels.

        The previous bound (`proba <= 128 - 1`) was far too loose to catch
        anything: the output is used as a probability vector, so every
        entry must lie in [0, 1].
        '''
        self.assertAllEqual(proba.shape, [self.QUANTIZATION_CHANNELS])
        self.assertTrue(np.all((proba >= 0) & (proba <= 1)))

    def testGenerateSimple(self):
        '''Generate a few samples using the naive method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assertValidProba(proba)

    def testGenerateFast(self):
        '''Generate a few samples using the fast method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        # The incremental generator consumes one sample per step.
        data = np.random.randint(self.QUANTIZATION_CHANNELS)
        proba = self.net.predict_proba_incremental(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assertValidProba(proba)

    def testCompareSimpleFast(self):
        '''Check that the naive and the incremental generator agree.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1)
        proba = self.net.predict_proba(waveform)
        proba_fast = self.net.predict_proba_incremental(waveform)
        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            # Prime the incremental generation with all samples
            # except the last one. The result of these runs is not needed,
            # only the side effect of the push ops.
            for x in data[:-1]:
                sess.run([proba_fast, self.net.push_ops],
                         feed_dict={waveform: x})

            # Get the last sample from the incremental generator
            proba_fast_ = sess.run(proba_fast,
                                   feed_dict={waveform: data[-1]})
            # Get the sample from the simple generator
            proba_ = sess.run(proba, feed_dict={waveform: data})
            self.assertAllClose(proba_, proba_fast_)
Beispiel #2
0
class TestGeneration(tf.test.TestCase):
    '''Trains a tiny conditioned WaveNet and checks generation with and
    without conditioning.'''

    def _generation_error(self, session, predict, window_shape, conditioned):
        '''Generate from `predict` and accumulate the deviation from a target.

        For each (p, q) pair in a 3x3 grid, 64 autoregressive steps are run
        starting from a random window, and the absolute difference between
        the average of the last prediction and a target value is added to
        the total. The target starts at 0.0 and grows by 0.1 per pair.

        Args:
            session: Session used to evaluate `predict`.
            predict: Tensor fed through the 'samples:0', 'gc:0' and 'lc:0'
                placeholders.
            window_shape: (receptive_field_size, data_dim) of the sample
                window.
            conditioned: When True the global/local condition ids are set
                to p and q; when False both are held at 0.

        Returns:
            The accumulated absolute error over the nine pairs.
        '''
        window_size = window_shape[0]
        gc_samples = np.zeros(window_size)
        lc_samples = np.zeros((window_size, 4))

        total_error = 0.0
        target = 0.0
        for p in range(3):
            for q in range(3):
                gc_samples[:] = p if conditioned else 0
                lc_samples[:, :] = q if conditioned else 0

                # Every pair starts from a fresh random window.
                data_samples = np.random.random(window_shape)
                for _ in range(64):
                    prediction = session.run(predict,
                                             feed_dict={
                                                 'samples:0': data_samples,
                                                 'gc:0': gc_samples,
                                                 'lc:0': lc_samples
                                             })
                    # Slide the window: drop the oldest row, append the
                    # newest prediction.
                    data_samples = np.append(data_samples[1:, :],
                                             prediction,
                                             axis=0)

                average = np.average(prediction)
                deviation = np.abs(target - average)
                print("G%d L%d - %.2f vs %.2f ERR %.2f" %
                      (p, q, target, average, deviation))
                total_error += deviation
                target += 0.1
        return total_error

    def testGenerateSimple(self):
        '''Train for a fixed number of steps, then compare the generation
        error with conditioning against the error without it.'''
        # Reader config
        with open(TEST_DATA + "/config.json") as json_file:
            self.reader_config = json.load(json_file)

        # Initialize the reader
        receptive_field_size = WaveNetModel.calculate_receptive_field(
            2, [1, 1], False, 8)

        self.reader = CsvReader(
            [
                TEST_DATA + "/test.dat",
                TEST_DATA + "/test.emo",
                TEST_DATA + "/test.pho",
            ],
            batch_size=1,
            receptive_field=receptive_field_size,
            sample_size=SAMPLE_SIZE,
            config=self.reader_config)

        # WaveNet model
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 1],
                                filter_width=2,
                                residual_channels=8,
                                dilation_channels=8,
                                skip_channels=8,
                                quantization_channels=2,
                                use_biases=True,
                                scalar_input=False,
                                initial_filter_width=8,
                                histograms=False,
                                global_channels=GC_CHANNELS,
                                local_channels=LC_CHANNELS)

        loss = self.net.loss(input_batch=self.reader.data_batch,
                             global_condition=self.reader.gc_batch,
                             local_condition=self.reader.lc_batch,
                             l2_regularization_strength=0)

        optimizer = optimizer_factory['adam'](learning_rate=0.003,
                                              momentum=0.9)
        trainable = tf.trainable_variables()
        train_op = optimizer.minimize(loss, var_list=trainable)

        # The placeholders are addressed by name ('samples:0', 'gc:0',
        # 'lc:0') when feeding, because `gc` and `lc` are rebound to their
        # one-hot encodings below.
        samples = tf.placeholder(tf.float32,
                                 shape=(receptive_field_size,
                                        self.reader.data_dim),
                                 name="samples")
        gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
        lc = tf.placeholder(tf.int32,
                            shape=(receptive_field_size, 4),
                            name="lc")

        gc = tf.one_hot(gc, GC_CHANNELS)
        lc = tf.one_hot(lc, int(LC_CHANNELS / 4))

        predict = self.net.predict_proba(samples, gc, lc)

        with self.test_session() as session:
            session.run([
                tf.local_variables_initializer(),
                tf.global_variables_initializer(),
                tf.tables_initializer(),
            ])
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            try:
                for i in range(5000):
                    _, loss_val = session.run([train_op, loss])
                    print("step %d loss %.4f" % (i, loss_val), end='\r')
                    sys.stdout.flush()
                print()

                window_shape = (receptive_field_size, self.reader.data_dim)

                # WITH CONDITIONING.
                error = self._generation_error(session, predict,
                                               window_shape, True)
                print("TOTAL ERROR CONDITIONING: %.5f" % error)

                # WITHOUT CONDITIONING.
                errorNo = self._generation_error(session, predict,
                                                 window_shape, False)
                print("TOTAL ERROR NO CONDITIONING: %.5f" % errorNo)
            finally:
                # Shut the input pipeline down so that no queue-runner
                # thread outlives the session.
                coord.request_stop()
                coord.join(threads)

            self.assertLess(error, 0.5)
            self.assertGreater(errorNo, 0.05)
Beispiel #3
0
def _sample_generator_waveform(sess, generator, args, quantization_channels,
                               num_samples):
    """Sample a waveform autoregressively from the WaveNet generator.

    A new placeholder and prediction op are added to the default graph on
    every call, exactly as the previous inline code did.

    Args:
        sess: Session used to run the prediction ops.
        generator: Model providing `predict_proba`,
            `predict_proba_incremental`, `init_ops`, `push_ops` and
            `receptive_field`.
        args: Parsed arguments; `fast_generation` and `gc_id` are read.
        quantization_channels: Number of classes to sample from.
        num_samples: Number of samples to draw.

    Returns:
        A list with `num_samples` sampled class indices.
    """
    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = generator.predict_proba_incremental(samples, args.gc_id)
        # NOTE(review): this re-runs the initializer of *every* global
        # variable, not only the generator's state buffers - confirm that
        # resetting previously trained weights here is intended.
        sess.run(tf.global_variables_initializer())
        sess.run(generator.init_ops)
    else:
        next_sample = generator.predict_proba(samples, args.gc_id)

    # Dummy first sample to bootstrap generation; removed before returning.
    waveform = [0]

    for _ in range(num_samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(generator.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > generator.receptive_field:
                window = waveform[-generator.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature (fixed at 1).
        # log(0) is expected for zero-probability classes, hence the
        # temporarily silenced divide warning.
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / 1
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

    del waveform[0]
    return waveform


def main():
    """Train a feed-forward discriminator and a WaveNet generator.

    First trains the `ffnn` classifier on audio from './sampleTrue'
    (label 1) versus './sampleFalse' (label 0), then runs a loop that
    alternates discriminator updates on real and generated audio with a
    generator update.
    """
    with tf.Graph().as_default():
        coord = tf.train.Coordinator()
        sess = tf.Session()

        batch_size = 1
        hidden1_units = 5202
        hidden2_units = 2601
        hidden3_units = 1300
        hidden4_units = 650
        hidden5_units = 325
        max_training_steps = 1

        global_step = tf.Variable(0, name='global_step', trainable=False)
        initial_training_learning_rate = 3e-2
        training_learning_rate = tf.train.exponential_decay(
            initial_training_learning_rate,
            global_step,
            100,
            0.9,
            staircase=True)

        inputs_placeholder, labels_placeholder = placeholder_inputs(batch_size)

        logits = ffnn.inference(inputs_placeholder, hidden1_units,
                                hidden2_units, hidden3_units, hidden4_units,
                                hidden5_units)
        loss = ffnn.loss(logits, labels_placeholder)
        train_op = ffnn.training(loss, training_learning_rate, global_step)
        eval_correct = ffnn.evaluation(logits, labels_placeholder)

        summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        summary_writer = tf.summary.FileWriter('./logdir', sess.graph)

        sess.run(init)

        args = get_arguments()

        # Load parameters from wavenet params json file
        with open(args.wavenet_params, 'r') as f:
            wavenet_params = json.load(f)

        quantization_channels = wavenet_params['quantization_channels']

        if args.restore_from is not None:
            restore_from = args.restore_from
            print("Restoring from: ")
            print(restore_from)

        else:
            restore_from = ""

        try:
            saved_global_step = load(saver, sess, restore_from)
            if saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # therefore we put -1 here for new or overwritten trainings.
                saved_global_step = -1
            else:
                # NOTE(review): `label_batch_size` is not defined in this
                # function and `counter` is never read - confirm whether
                # this line is still needed.
                counter = saved_global_step % label_batch_size

        except:
            # Deliberately broad: the message is printed for any failure
            # and the exception is always re-raised.
            print(
                "Something went wrong while restoring checkpoint. "
                "We will terminate training to avoid accidentally overwriting "
                "the previous model.")
            raise

        # TODO: Find a more robust way to find different data sets

        # Training data: clips labelled 1 come from `reader`, clips
        # labelled 0 from `reader2`.
        directory = './sampleTrue'
        reader = AudioReader(directory,
                             coord,
                             sample_rate=16000,
                             gc_enabled=False,
                             receptive_field=5117,
                             sample_size=15117,
                             silence_threshold=0.05)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader.start_threads(sess)

        directory = './sampleFalse'
        reader2 = AudioReader(directory,
                              coord,
                              sample_rate=16000,
                              gc_enabled=False,
                              receptive_field=5117,
                              sample_size=15117,
                              silence_threshold=0.05)
        threads2 = tf.train.start_queue_runners(sess=sess, coord=coord)
        reader2.start_threads(sess)

        total_loss = 0
        for step in range(saved_global_step + 1, max_training_steps):
            start_time = time.time()

            batch_data = []
            label_data = []

            if (step % 100 == 0):
                print('Current learning rate: %6f' %
                      (sess.run(training_learning_rate)))

            for b in range(batch_size):
                label = randint(0, 1)

                # Keep dequeuing until the clip is long enough to fill the
                # discriminator input.
                if label == 1:
                    data = sess.run(reader.dequeue(1))
                    while (len(data[0]) < ffnn.INPUT_SIZE):
                        data = sess.run(reader.dequeue(1))
                else:
                    data = sess.run(reader2.dequeue(1))
                    while (len(data[0]) < ffnn.INPUT_SIZE):
                        data = sess.run(reader2.dequeue(1))
                data = np.array(data[0])

                # Truncate the clip to exactly INPUT_SIZE entries.
                data = [data[i] for i in range(ffnn.INPUT_SIZE)]

                # processing
                samples = process(data, quantization_channels, 1)

                batch_data.append(samples)
                label_data.append(label)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time
            total_loss = total_loss + loss_value

            print('Step %d: loss = %.7f (%.3f sec)' %
                  (step, loss_value, duration))
            # Checkpointing is currently disabled:
            # if step % 100 == 0 or (step + 1) == max_training_steps:
            #     average = total_loss / (step + 1)
            #     print('Cumulative average loss: %6f' % (average))
            #     # TODO: Update train script to add data to new directory
            #     checkpoint_file = os.path.join('./logdir/init-train/', 'model.ckpt')
            #     print("Generating checkpoint file...")
            #     saver.save(sess, checkpoint_file, global_step=step)

        # Lambda for white noise sampler
        gi_sampler = get_generator_input_sampler()

        # Initialize generator WaveNet
        G = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params["dilations"],
            filter_width=wavenet_params["filter_width"],
            residual_channels=wavenet_params["residual_channels"],
            dilation_channels=wavenet_params["dilation_channels"],
            skip_channels=wavenet_params["skip_channels"],
            quantization_channels=wavenet_params["quantization_channels"],
            use_biases=wavenet_params["use_biases"],
            initial_filter_width=wavenet_params["initial_filter_width"])

        # White noise generator params
        white_mean = 0
        white_sigma = 1
        white_length = ffnn.INPUT_SIZE

        white_noise = gi_sampler(white_mean, white_sigma, white_length)
        white_noise = process(white_noise, quantization_channels, 1)
        white_noise_t = tf.convert_to_tensor(white_noise)

        # initialize generator
        w_loss, w_prediction = G.loss(input_batch=white_noise_t,
                                      name='generator')

        G_variables = tf.trainable_variables(scope='wavenet')
        optimizer = optimizer_factory[args.optimizer](learning_rate=3e-2,
                                                      momentum=args.momentum)
        optim = optimizer.minimize(w_loss, var_list=G_variables)

        # NOTE(review): this runs the initializer for all global variables
        # again, including the discriminator trained above - confirm.
        init = tf.global_variables_initializer()
        sess.run(init)

        print(sess.run(tf.shape(w_prediction)))

        # main GAN training loop
        for step in range(NUM_EPOCHS):
            batch_data = []
            label_data = []

            # train D on real
            for _ in range(batch_size):
                data = sess.run(reader.dequeue(1))
                data = data[0]

                d_real_data = process(data, quantization_channels, 1)

                batch_data.append(d_real_data)
                label_data.append(1)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, d_real_loss = sess.run([train_op, loss], feed_dict=feed_dict)

            print("Real loss")
            print(d_real_loss)

            batch_data = []
            label_data = []

            # train D on fake
            for _ in range(batch_size):
                waveform = _sample_generator_waveform(
                    sess, G, args, quantization_channels, ffnn.INPUT_SIZE)

                d_fake_data = process(waveform, quantization_channels, 0)

                batch_data.append(d_fake_data)
                label_data.append(0)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            _, d_fake_loss = sess.run([train_op, loss], feed_dict=feed_dict)

            print("Fake loss")
            print(d_fake_loss)

            batch_data = []
            label_data = []

            # train G, but don't train D: generated audio is labelled 1.
            for _ in range(batch_size):
                waveform = _sample_generator_waveform(
                    sess, G, args, quantization_channels, ffnn.INPUT_SIZE)

                g_data = process(waveform, quantization_channels, 0)

                batch_data.append(g_data)
                label_data.append(1)

            feed_dict = fill_feed_dict(batch_data, label_data,
                                       inputs_placeholder, labels_placeholder)

            # NOTE(review): `optim` minimizes `w_loss` (built from the white
            # noise tensor) while the value reported is the discriminator
            # `loss` - confirm this is the intended generator objective.
            _, g_loss = sess.run([optim, loss], feed_dict=feed_dict)

            print("Generator loss")
            print(g_loss)
        '''
Beispiel #4
0
def run(target,
        is_chief,
        train_steps,
        job_dir,
        train_files,
        reader_config,
        batch_size,
        learning_rate,
        residual_channels,
        dilation_channels,
        skip_channels,
        dilations,
        use_biases,
        gc_channels,
        lc_channels,
        filter_width,
        sample_size,
        initial_filter_width,
        l2_regularization_strength,
        momentum,
        optimizer):
    """Build the WaveNet training graph and train it in a monitored session.

    Args:
        target: Passed as `master` to the MonitoredTrainingSession.
        is_chief: Whether this task is the chief of the cluster.
        train_steps: Stop once the global step reaches this value; None
            trains until the session asks to stop.
        job_dir: Checkpoint directory of the session.
        train_files: Files handed to the CsvReader.
        reader_config: Path to the JSON configuration of the reader.
        batch_size: Batch size of the reader and the network.
        learning_rate: Learning rate of the optimizer.
        residual_channels, dilation_channels, skip_channels, dilations,
        use_biases, filter_width, initial_filter_width: Forwarded to
            WaveNetModel.
        gc_channels: Global conditioning channels; also the one-hot depth
            of the `gc` placeholder.
        lc_channels: Local conditioning channels; also the one-hot depth
            of the `lc` placeholder.
        sample_size: Sample size handed to the reader.
        l2_regularization_strength: L2 strength; 0 disables regularization.
        momentum: Momentum passed to the optimizer factory.
        optimizer: Key into `optimizer_factory`.
    """
    # Run the training and evaluation graph.

    # If the server is chief which is `master`
    # In between graph replication Chief is one node in
    # the cluster with extra responsibility and by default
    # is worker task zero. We have assigned master as the chief.
    #
    # See https://youtu.be/la_M6bCV91M?t=1203 for details on
    # distributed TensorFlow and motivation about chief.
    # TODO: hooks
    hooks = []

    # Create a new graph and specify that as default
    with tf.Graph().as_default():
        # Placement of ops on devices using replica device setter
        # which automatically places the parameters on the `ps` server
        # and the `ops` on the workers
        #
        # See:
        # https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter
        with tf.device(tf.train.replica_device_setter()):

            with open(reader_config) as json_file:
                reader_settings = json.load(json_file)

            # Reader
            receptive_field_size = WaveNetModel.calculate_receptive_field(
                filter_width, dilations, False, initial_filter_width)

            reader = CsvReader(
                train_files,
                batch_size=batch_size,
                receptive_field=receptive_field_size,
                sample_size=sample_size,
                config=reader_settings
            )

            # Create network.
            net = WaveNetModel(
                batch_size=batch_size,
                dilations=dilations,
                filter_width=filter_width,
                residual_channels=residual_channels,
                dilation_channels=dilation_channels,
                skip_channels=skip_channels,
                quantization_channels=reader.data_dim,
                use_biases=use_biases,
                scalar_input=False,
                initial_filter_width=initial_filter_width,
                histograms=False,
                global_channels=gc_channels,
                local_channels=lc_channels)

            global_step_tensor = tf.contrib.framework.get_or_create_global_step()

            # A strength of 0 means "no regularization", which the model
            # expects to be expressed as None.
            if l2_regularization_strength == 0:
                l2_regularization_strength = None

            loss = net.loss(input_batch=reader.data_batch,
                            global_condition=reader.gc_batch,
                            local_condition=reader.lc_batch,
                            l2_regularization_strength=l2_regularization_strength)

            train_optimizer = optimizer_factory[optimizer](
                learning_rate=learning_rate, momentum=momentum)

            trainable = tf.trainable_variables()

            train_op = train_optimizer.minimize(
                loss, var_list=trainable, global_step=global_step_tensor)

            # Add Generation operator to graph for later use in generate.py
            tf.add_to_collection(
                "config", tf.constant(reader.data_dim, name='data_dim'))
            tf.add_to_collection(
                "config",
                tf.constant(receptive_field_size, name='receptive_field_size'))
            tf.add_to_collection(
                "config", tf.constant(sample_size, name='sample_size'))

            samples = tf.placeholder(
                tf.float32,
                shape=(receptive_field_size, reader.data_dim),
                name="samples")
            gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
            lc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="lc")  # TODO set to one

            gc = tf.one_hot(gc, gc_channels)
            # The depth of a one-hot encoding has to be an integer; the
            # former `lc_channels / 1` yields a float under true division.
            lc = tf.one_hot(lc, int(lc_channels))  # TODO set to one...

            tf.add_to_collection("predict_proba", net.predict_proba(samples, gc, lc))

            # TODO: Implement fast generation. Disabled sketch:
            #
            # if filter_width <= 2:
            #     samples_fast = tf.placeholder(tf.float32, shape=(1, reader.data_dim), name="samples_fast")
            #     gc_fast = tf.placeholder(tf.int32, shape=(1), name="gc_fast")
            #     lc_fast = tf.placeholder(tf.int32, shape=(1), name="lc_fast")
            #
            #     gc_fast = tf.one_hot(gc_fast, gc_channels)
            #     lc_fast = tf.one_hot(lc_fast, lc_channels)
            #
            #     tf.add_to_collection("predict_proba_incremental", net.predict_proba_incremental(samples_fast, gc_fast, lc_fast))
            #     tf.add_to_collection("push_ops", net.push_ops)

        # Creates a MonitoredSession for training
        # MonitoredSession is a Session-like object that handles
        # initialization, recovery and hooks
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        with tf.train.MonitoredTrainingSession(master=target,
                                               is_chief=is_chief,
                                               checkpoint_dir=job_dir,
                                               hooks=hooks,
                                               save_checkpoint_secs=120,
                                               save_summaries_steps=20) as session:  # TODO: SUMMARIES HERE

            # Global step to keep track of global number of steps particularly in
            # distributed setting
            step = global_step_tensor.eval(session=session)
            # Run the training graph which returns the step number as tracked by
            # the global step tensor.
            # When train epochs is reached, session.should_stop() will be true.
            try:
                while (train_steps is None or
                       step < train_steps) and not session.should_stop():

                    step, _, loss_val = session.run(
                        [global_step_tensor, train_op, loss])
                    print("step %d loss %.4f" % (step, loss_val), end='\r')
                    sys.stdout.flush()

            except KeyboardInterrupt:
                # Ctrl-C is the expected way to end an open-ended run; leave
                # the session context normally.
                pass
def main():
    """Generate audio from a trained WaveNet checkpoint.

    Restores the model given by the command-line arguments, samples
    `args.samples` values one at a time (optionally seeded from a wav
    file), writes an audio summary under `args.logdir` and, when
    `args.wav_out_path` is set, a wav file.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Variables whose name contains 'state_buffer' or 'pointer' are left
    # out of the restore map. The ':0' suffix is stripped to obtain the
    # checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    # Ops evaluated on every step; the incremental generator additionally
    # has to run its push ops.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        #
        # Every seed sample except the last one is fed here; the last one
        # is the first input of the generation loop below. The previous
        # slice, `waveform[:-(args.window + 1)]`, left out the most recent
        # `args.window` seed samples.
        print('Priming generation...')
        for i, x in enumerate(waveform[:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def main():
    """Generate audio from a restored WaveNet with optional local conditioning.

    Reads the command-line options through ``get_arguments()``, builds a
    ``WaveNetModel`` from the JSON parameter file, restores the checkpoint
    and then draws one sample per step, appending it to the waveform.
    The finished waveform is written as a TensorBoard audio summary and,
    when ``args.wav_out_path`` is set, as a wav file (also periodically if
    ``args.save_every`` is set).

    ``args.samples`` is overwritten: the number of generated samples is
    ``args.lc_duration * len(local_condition)``, i.e. every row of the
    loaded local-condition array is held for ``lc_duration`` steps.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    lc_enabled = args.lc_channels is not None
    lc_channels = args.lc_channels
    lc_duration = args.lc_duration

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        local_condition_channels=args.lc_channels)

    samples = tf.placeholder(tf.int32)
    lc_piece = tf.placeholder(tf.float32, [1, lc_channels])
    # NOTE(review): the conditioning file is loaded (and drives the sample
    # count) even when local conditioning is disabled -- confirm intended.
    local_condition = load_lc(os.path.join(os.getcwd(), args.lc_path))

    args.samples = args.lc_duration * len(local_condition)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id,
                                                    lc_piece)
    else:
        next_sample = net.predict_proba(samples, args.gc_id, lc_piece)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's queue state is not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels, net.receptive_field)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end. Floor division
        # keeps the values integral for the int32 placeholder.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    # The fetch list does not change between steps, so build it once.
    if args.fast_generation:
        outputs = [next_sample]
        outputs.extend(net.push_ops)
    else:
        outputs = [next_sample]

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        else:
            # Slicing a shorter list simply yields the whole list.
            window = waveform[-net.receptive_field:]

        feed = {samples: window}
        if lc_enabled:
            if step % lc_duration == 0:
                # Advance to the next conditioning row every lc_duration
                # steps. The original sliced `local_condition[1:, :]`
                # without assigning it, so row 0 was fed for every step.
                lc_index = step // lc_duration
                lc_window = local_condition[lc_index:lc_index + 1, :]
            feed[lc_piece] = lc_window

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict=feed)[0]

        # Scale prediction distribution using temperature. log(0) is
        # expected for zero-probability bins, hence the silenced warning.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def main(waveform, num_predictions,
         checkpoint="logdir/train/last144kwl2/model.ckpt-27050"):
    """Extend a seed with ``num_predictions`` WaveNet samples and decode it.

    Builds a fresh graph on CPU, restores the model from ``checkpoint`` and
    repeatedly predicts the next sample from the last ``receptive_field``
    samples (the non-incremental generator), drawing each new value at
    random from the predicted distribution.

    Args:
        waveform: Seed passed as the first argument to ``create_seed``
            (presumably a wav file path, as in the CLI scripts -- TODO
            confirm against ``create_seed``).
        num_predictions: Number of samples to generate after the seed.
        checkpoint: Checkpoint prefix handed to ``tf.train.Saver.restore``.
            Defaults to the path that used to be hard-coded here.

    Returns:
        The result of running ``mu_law_decode`` over the seed plus all
        generated samples.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    with open(WAVENET_PARAMS, 'r') as config_file:
        wavenet_params = json.load(config_file)
    tf.reset_default_graph()
    config = tf.ConfigProto(
        device_count={'GPU': 0}  # todo this is only to generate test stuff on nonGPU
    )
    sess = tf.Session(config=config)

    # Everything below runs under try/finally so the session is released
    # even when restoring or generation fails.
    try:
        net = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params['dilations'],
            filter_width=wavenet_params['filter_width'],
            residual_channels=wavenet_params['residual_channels'],
            dilation_channels=wavenet_params['dilation_channels'],
            quantization_channels=wavenet_params['quantization_channels'],
            skip_channels=wavenet_params['skip_channels'],
            use_biases=wavenet_params['use_biases'],
            scalar_input=wavenet_params['scalar_input'],
            initial_filter_width=wavenet_params['initial_filter_width'],
            global_condition_channels=None,
            global_condition_cardinality=None)

        samples = tf.placeholder(tf.int32)

        # Regular (non-incremental) generation. Switching to
        # predict_proba_incremental would additionally require running
        # net.init_ops and priming the queues with the seed sample by sample.
        next_sample = net.predict_proba(samples, None)

        variables_to_restore = {
            var.name[:-2]: var for var in tf.global_variables()
            if not ('state_buffer' in var.name or 'pointer' in var.name)}
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, checkpoint)

        decode = mu_law_decode(samples,
                               wavenet_params['quantization_channels'])

        quantization_channels = wavenet_params['quantization_channels']

        seed = create_seed(waveform,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field)
        waveform = sess.run(seed).tolist()

        outputs = [next_sample]
        for _ in range(num_predictions):
            # Slicing a shorter list simply yields the whole list.
            window = waveform[-net.receptive_field:]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs, feed_dict={samples: window})[0]

            # Renormalise in log space (temperature fixed at 1). log(0) is
            # expected for zero-probability bins, hence the silenced warning.
            with np.errstate(divide='ignore'):
                scaled_prediction = np.log(prediction)
                scaled_prediction = (scaled_prediction -
                                     np.logaddexp.reduce(scaled_prediction))
                scaled_prediction = np.exp(scaled_prediction)

            # TODO consider grabbing only top probability instead of this range
            sample = np.random.choice(np.arange(quantization_channels),
                                      p=scaled_prediction)
            waveform.append(sample)

        return sess.run(decode, feed_dict={samples: waveform})
    finally:
        sess.close()
Beispiel #8
0
def main():
    """Generate AP feature frames with a conditioned WaveNet and save as .npy.

    Builds the model from the JSON parameter file, restores the checkpoint,
    then predicts ``args.samples`` frames one at a time. Each predicted
    row of ``AP_channels`` values is concatenated with the matching MFSC row
    from ``load_lc`` and appended to the running waveform array. When
    ``args.wav_out_path`` is set the AP columns of the result are saved
    with ``np.save``.

    Raises:
        NotImplementedError: if ``args.wav_seed`` is given. Seeding is
            disabled in this script, and continuing would fail later with
            a NameError because no start waveform or conditioning arrays
            exist.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)

    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's queue state is not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    if args.wav_seed:
        # Fail fast: without a seed implementation neither the start
        # waveform nor the conditioning arrays would be defined below.
        # Once seeding is restored, the incremental generator will also need
        # to be primed with the seed frames one by one before generation.
        raise NotImplementedError(
            'Seeding generation from a wav file is not supported.')

    # Silence with a single random sample at the end.
    waveform = np.zeros(
        (net.receptive_field - 1, AP_channels + MFSC_channels))
    waveform = np.append(waveform,
                         np.random.randn(1, AP_channels + MFSC_channels),
                         axis=0)

    lc_array, mfsc_array = load_lc(scramble=True)
    # Prepend receptive_field rows of zeros so the first window is fully
    # defined.
    lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)),
                      'constant',
                      constant_values=((0, 0), (0, 0)))
    mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)),
                        'constant',
                        constant_values=((0, 0), (0, 0)))

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # NOTE(review): this path takes a single 1-D row, but the
            # reshape below indexes window.shape[-2]; fast generation looks
            # unsupported for this model -- confirm before relying on it.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:, :]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        window = window.reshape(1, window.shape[-2], window.shape[-1])
        lc_window = lc_array[step + 1:step + 1 + net.receptive_field, :]

        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  lc: lc_window.reshape(
                                      1, net.receptive_field, -1)
                              })

        # Complete the predicted AP row with the MFSC row for this frame.
        prediction = np.concatenate(
            (prediction,
             mfsc_array[step + net.receptive_field].reshape(1, -1)),
            axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            np.save(args.wav_out_path, waveform)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    if args.wav_out_path:
        # Keep only the leading AP columns; the remaining columns are the
        # MFSC values that were copied in from the conditioning data.
        np.save(args.wav_out_path, waveform[:, :AP_channels])

    print('Finished generating. The result was saved as .npy file.')
Beispiel #9
0
def main():
    """Train a WaveNet on TFRecord data, logging to TensorBoard.

    Validates the directory arguments, builds train/test input pipelines
    and the network, optionally restores a checkpoint, then runs the
    optimisation loop up to ``args.num_steps``. A checkpoint is written
    every ``args.checkpoint_every`` steps and once more on exit (including
    Ctrl-C) if the last completed step has not been saved yet.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Read TFRecords and create network.
    tf.reset_default_graph()

    data_train = get_tfrecord(name='train',
                              sample_size=args.sample_size,
                              batch_size=args.batch_size,
                              seed=None,
                              repeat=None,
                              data_path=args.data_path)
    data_test = get_tfrecord(name='test',
                             sample_size=args.sample_size,
                             batch_size=args.batch_size,
                             seed=None,
                             repeat=None,
                             data_path=args.data_path)

    train_itr = data_train.make_one_shot_iterator()
    test_itr = data_test.make_one_shot_iterator()

    train_batch, train_label = train_itr.get_next()
    test_batch, test_label = test_itr.get_next()

    # Add a trailing channel axis of size 1.
    train_batch = tf.reshape(train_batch, [-1, train_batch.shape[1], 1])
    test_batch = tf.reshape(test_batch, [-1, test_batch.shape[1], 1])

    # Create network.
    net = WaveNetModel(sample_size=args.sample_size,
                       batch_size=args.batch_size,
                       dilations=wavenet_params["dilations"],
                       filter_width=wavenet_params["filter_width"],
                       residual_channels=wavenet_params["residual_channels"],
                       dilation_channels=wavenet_params["dilation_channels"],
                       skip_channels=wavenet_params["skip_channels"],
                       histograms=args.histograms)

    train_loss = net.loss(train_batch, train_label)
    test_loss = net.loss(test_batch, test_label)

    # Optimizer
    # Temporarily set to momentum optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=args.learning_rate,
                                           momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    # Accuracy of test data
    pred_test = net.predict_proba(test_batch, audio_only=True)
    equals = tf.equal(tf.squeeze(test_label), tf.round(pred_test))
    acc = tf.reduce_mean(tf.cast(equals, tf.float32))

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    init2 = tf.local_variables_initializer()
    sess.run([init, init2])

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except BaseException:
        # Always re-raised; the message only explains why we stop here.
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    fetches = [summaries, train_loss, test_loss, acc, optim]
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(
                    fetches,
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary_, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                # Regular step: no tracing. Previously the FULL_TRACE
                # options created for the first step were reused on every
                # step, paying the tracing cost while discarding the result.
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(fetches)
                writer.add_summary(summary_, step)

            duration = time.time() - start_time
            print("step {:d}:  trainloss = {:.3f}, "
                  "testloss = {:.3f}, acc = {:.3f}, ({:.3f} sec/step)".format(
                      step, train_loss_, test_loss_, acc_, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        try:
            # Compare against None explicitly: step 0 is a real step and
            # must not be mistaken for "the loop never ran".
            if step is None:
                print("No training performed during session.")
            elif step > last_saved_step:
                save(saver, sess, logdir, step)
        finally:
            # Flush pending summaries and release the session.
            writer.close()
            sess.close()
Beispiel #10
0
def main():
    """Generate audio from a restored WaveNet with an optional varying temperature.

    Reads the command-line options through ``get_arguments()``, resolves
    the parameter file under ``wavenet_params/``, restores the checkpoint
    and samples ``args.samples`` values, appending each one to the
    waveform. The result is written as a TensorBoard audio summary and,
    when ``args.wav_out_path`` is set, as a wav file.

    Temperature per step:
      * ``args.temperature_change`` is None: constant ``args.temperature``.
      * ``"dynamic"`` with ``args.tform``: taken from the array produced by
        ``dynamic.generate_value``.
      * ``"dynamic"`` without ``args.tform``: ``args.temperature`` times a
        fresh uniform random factor, redrawn every fifth of the run.

    Raises:
        ValueError: if both ``args.tfrequency`` and ``args.tperiod`` are set.
        Exception: if ``args.temperature_change`` has an unknown value.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)

    # Accept a full path, a bare file name, or just the name suffix.
    params_path = args.wavenet_params
    if params_path.startswith('wavenet_params/'):
        pass
    elif params_path.startswith('wavenet_params'):
        params_path = 'wavenet_params/' + params_path
    else:
        params_path = 'wavenet_params/wavenet_params_' + params_path
    with open(params_path, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental generator's queue state is not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field,
                           args.silence_threshold)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end. Floor division
        # keeps the values integral for the int32 placeholder.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    # The fetch list does not change between steps, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()

    # Frequency to period change.
    if args.tfrequency is not None:
        if args.tperiod is not None:
            raise ValueError("Frequency and Period both assigned. Assign only one of them.")
        PERIOD = dynamic.frequency_to_period(args.tfrequency)
    elif args.tperiod is not None:
        PERIOD = args.tperiod
    else:
        PERIOD = 1

    # Generate an array of temperatures, one per step, when
    # temperature_change is "dynamic" and a wave form is given.
    if args.temperature_change == "dynamic" and args.tform is not None:
        temp_array = dynamic.generate_value(0, wavenet_params['sample_rate'], args.tform, args.tmin, args.tmax, PERIOD, args.samples, args.tphaseshift)

    # Random mode redraws the temperature five times over the run. Clamp the
    # interval to at least one step so short runs do not divide by zero.
    random_interval = max(1, int(args.samples / 5))

    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        else:
            # Slicing a shorter list simply yields the whole list.
            window = waveform[-net.receptive_field:]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Pick the temperature for this step.
        if args.temperature_change is None:  # static
            _temp_temperature = args.temperature
        elif args.temperature_change == "dynamic":
            if args.tform is None:  # random
                if step % random_interval == 0:
                    _temp_temperature = args.temperature * np.random.rand()
            else:
                _temp_temperature = temp_array[1][step]
        else:
            raise Exception("wrong temperature_change value")

        # Scale prediction distribution using temperature. log(0) is
        # expected for zero-probability bins, hence the silenced warning.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / _temp_temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0 and args.temperature_change is None:
            np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg = 'Prediction scaling at temperature=1.0 is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        # The drawn sample was previously discarded, so the waveform never
        # grew beyond the seed and every step saw the same window.
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}, temperature {:3<f}'.format(step + 1, args.samples, _temp_temperature),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #11
0
class TestGeneration(tf.test.TestCase):
    '''Smoke test: train a small conditioned WaveNet and sample from it.'''

    def testGenerateSimple(self):
        '''Train on the test data, then generate for each condition pair.

        The network is trained for 1000 steps on the data served by the
        reader. Afterwards, for each of the 3 x 3 (global id, local id)
        combinations, 100 predictions are generated autoregressively and the
        collected outputs are written to a PNG under ``./test``.

        The method contains no assertions; it passes as long as training and
        generation run without raising.
        '''
        # Reader config
        with open(TEST_DATA + "/config.json") as json_file:
            self.reader_config = json.load(json_file)

        # Initialize the reader
        receptive_field_size = WaveNetModel.calculate_receptive_field(2, LAYERS, False, 8)

        self.reader = CsvReader(
            [TEST_DATA + "/test.dat", TEST_DATA + "/test.emo", TEST_DATA + "/test.pho"],
            batch_size=1,
            receptive_field=receptive_field_size,
            sample_size=SAMPLE_SIZE,
            config=self.reader_config
        )

        # WaveNet model
        self.net = WaveNetModel(batch_size=1,
                                dilations=LAYERS,
                                filter_width=2,
                                residual_channels=8,
                                dilation_channels=8,
                                skip_channels=8,
                                quantization_channels=2,
                                use_biases=True,
                                scalar_input=False,
                                initial_filter_width=8,
                                histograms=False,
                                global_channels=GC_CHANNELS,
                                local_channels=LC_CHANNELS)

        loss = self.net.loss(input_batch=self.reader.data_batch,
                             global_condition=self.reader.gc_batch,
                             local_condition=self.reader.lc_batch,
                             l2_regularization_strength=L2)

        optimizer = optimizer_factory['adam'](learning_rate=0.003, momentum=0.9)
        trainable = tf.trainable_variables()
        train_op = optimizer.minimize(loss, var_list=trainable)

        # Generation inputs. `gc` and `lc` are rebound to their one-hot
        # encodings right below, so the placeholders themselves are fed by
        # tensor name ('samples:0', 'gc:0', 'lc:0') in the generation loop.
        samples = tf.placeholder(tf.float32, shape=(receptive_field_size, self.reader.data_dim), name="samples")
        gc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="gc")
        lc = tf.placeholder(tf.int32, shape=(receptive_field_size), name="lc")

        gc = tf.one_hot(gc, GC_CHANNELS)
        lc = tf.one_hot(lc, LC_CHANNELS)

        predict = self.net.predict_proba(samples, gc, lc)

        with self.test_session() as session:
            session.run([
                tf.local_variables_initializer(),
                tf.global_variables_initializer(),
                tf.tables_initializer(),
            ])
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            try:
                for ITER in range(1):

                    # Training phase.
                    for i in range(1000):
                        _, loss_val = session.run([train_op, loss])
                        print("step %d loss %.4f" % (i, loss_val), end='\r')
                        sys.stdout.flush()
                    print()

                    # Generation starts from random data and all-zero
                    # condition ids, one entry per receptive-field position.
                    data_samples = np.random.random((receptive_field_size, self.reader.data_dim))
                    gc_samples = np.zeros((receptive_field_size))
                    lc_samples = np.zeros((receptive_field_size))

                    output = []

                    for EMO in range(3):
                        for PHO in range(3):
                            for _ in range(100):
                                prediction = session.run(predict, feed_dict={'samples:0': data_samples, 'gc:0': gc_samples, 'lc:0': lc_samples})
                                # Slide every window forward by one step:
                                # drop the oldest entry, append the newest.
                                data_samples = data_samples[1:, :]
                                data_samples = np.append(data_samples, prediction, axis=0)

                                gc_samples = gc_samples[1:]
                                gc_samples = np.append(gc_samples, [EMO], axis=0)
                                lc_samples = lc_samples[1:]
                                lc_samples = np.append(lc_samples, [PHO], axis=0)

                                output.append(prediction[0])

                    output = np.array(output)
                    print("ITER %d" % ITER)
                    # np.kron with a 1 x 500 block of ones repeats every
                    # value 500 times along the second axis of the image.
                    plt.imsave("./test/SINE_test_%d.png" % ITER, np.kron(output[:, :], np.ones([1, 500])), vmin=0.0, vmax=1.0)
            finally:
                # Stop and join the input threads so they do not outlive the
                # test, even when training or generation raises.
                coord.request_stop()
                coord.join(threads)
Beispiel #12
0
def main():
    '''Generate samples from a freshly initialised WaveNet and print them.

    The model hyperparameters are read from the JSON file named by
    ``args.wavenet_params``. No checkpoint is restored in this function, so
    the network runs with its initial variable values. ``args.samples``
    values are drawn one at a time from the temperature-scaled predictive
    distribution and the resulting list is printed to stdout.
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    # NOTE: `logdir` is computed but nothing below writes to it.
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    gi_sampler = get_generator_input_sampler()

    # White noise generator params
    white_mean = 0
    white_sigma = 1
    white_length = 20234

    white_noise = gi_sampler(white_mean, white_sigma, white_length)

    # The returned loss tensor is never evaluated below; the call is kept
    # because it adds the 'generator' ops to the graph.
    loss = net.loss(input_batch=tf.convert_to_tensor(white_noise,
                                                     dtype=np.float32),
                    name='generator')

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    # No checkpoint is restored in this function, so the variables must be
    # initialised for both generation modes; otherwise the non-incremental
    # path would read uninitialised variables.
    sess.run(tf.global_variables_initializer())
    if args.fast_generation:
        sess.run(net.init_ops)

    # NOTE: `decode` is built but not run in this function.
    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    # Start from a single zero sample. It only bootstraps the first
    # prediction and is removed again before the result is printed.
    waveform = [0]

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode: feed only the newest sample and run the
            # push ops alongside the prediction.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            # Feed at most the last `receptive_field` samples.
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature. log(0) is
        # expected for zero-probability bins, so the divide warning is
        # suppressed only for this computation and the caller's numpy error
        # settings are restored afterwards.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

    # Drop the bootstrap sample and print what was generated.
    print()
    del waveform[0]
    print(waveform)
    print()

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #13
0
def main():
    '''Generate audio from a trained WaveNet checkpoint.

    Builds the network from the JSON file named by ``args.wavenet_params``,
    restores its weights from ``args.checkpoint`` and draws ``args.samples``
    values from the temperature-scaled predictive distribution. Generation
    can be conditioned on style tags (``args.using_magna``) or a global
    condition id, and optionally on the document vector of ``args.text``.
    The result is written as an audio summary to the log directory and, if
    ``args.wav_out_path`` is set, as a wav file.
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    if args.using_magna:
        # Build a multi-hot condition vector with a 1 at the header index of
        # every style listed in the parameter file.
        styles = wavenet_params['styles']
        header = get_data(N_CLASSES)[0]
        conditions = np.zeros(N_CLASSES)

        for style in styles:
            conditions[header.index(style)] = 1
        cardinality = N_CLASSES
    else:
        conditions = args.gc_id
        cardinality = args.gc_cardinality

    sess = tf.Session()
    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=cardinality,
        glove_channels=args.glove_channels,
        residual_postproc=wavenet_params["residual_postproc"],
        use_magna=args.using_magna)

    samples = tf.placeholder(tf.int32)

    if args.text is not None:
        nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')
        word_vec = nlp(u'%s' % args.text).vector
    else:
        word_vec = None

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(
            samples, conditions, word_vec)
    else:
        next_sample = net.predict_proba(samples, conditions, word_vec)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # Restore every variable except the incremental generator's state
    # buffers and pointers. `var.name[:-2]` strips the last two characters
    # of the variable name (presumably the ':0' output suffix).
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        # Feed every seed sample except the last one; the last one is fed
        # by the first step of the generation loop below. Leaving a gap
        # between the primed samples and that last sample would put the
        # incremental state out of sync with the seed.
        for i, x in enumerate(waveform[:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode: feed only the newest sample and run the
            # push ops alongside the prediction.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            # Feed at most the last `args.window` samples.
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature. log(0) is
        # expected for zero-probability bins, so the divide warning is
        # suppressed only for this computation and the caller's numpy error
        # settings are restored afterwards.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - \
                np.logaddexp.reduce(scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction, scaled_prediction, atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Close the writer so the pending event is flushed to disk before exit.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #14
0
def main():
    '''Generate audio from a trained WaveNet checkpoint and report timings.

    Builds the network from the JSON file named by ``args.wavenet_params``,
    restores its weights from ``args.checkpoint`` and draws ``args.samples``
    values from the temperature-scaled predictive distribution, feeding
    ``args.ID_generation`` as the ``input_ID`` placeholder on every step.
    Each stage (initialisation, restore, seeding, priming, prediction, wav
    writing) is timed and the timings are printed. If ``args.wav_out_path``
    is set the result is written as a wav file.
    '''
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    # NOTE: `logdir` is computed but nothing below writes to it; writing the
    # audio summary is disabled in this variant.
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    # time.clock() was removed in Python 3.8. Use perf_counter() where it
    # exists and fall back to clock() on interpreters that lack it.
    timer = getattr(time, 'perf_counter', None) or time.clock

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)  # placeholder for the fed samples
    input_ID = tf.placeholder(tf.int32)
    startime_fastgeration = timer()
    if args.fast_generation:
        print("#########using_fast_generation")
        next_sample = net.predict_proba_incremental(samples, input_ID)
    else:
        next_sample = net.predict_proba(samples, input_ID)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    endtime_fastgernation = timer()
    time_of_fast = endtime_fastgernation - startime_fastgeration

    # Build the saver. Every variable is restored except the incremental
    # generator's state buffers and pointers. `var.name[:-2]` strips the
    # last two characters of the variable name (presumably the ':0' suffix).
    start_vari_saver = timer()
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    end_vari_saver = timer()
    print('variables_to_restore{}'.format(end_vari_saver - start_vari_saver))

    # Restore from the checkpoint.
    starttime_restore = timer()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)
    endtime_restore = timer()
    time_of_restore = endtime_restore - starttime_restore
    print('%%%%%%%%%%%%{}'.format(time_of_restore))

    decode = mu_law_decode(
        samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    time_of_seed = 0
    if args.wav_seed:
        start_using_create_seed = timer()
        print('#######using_create_seed')
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
        end_using_create_seed = timer()
        time_of_seed = end_using_create_seed - start_using_create_seed
    else:
        print('#######not_using_create_seed')
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()
    predict_of_fast_seed = 0
    if args.fast_generation and args.wav_seed:
        starttime_fast_and_seed = timer()
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)  # push_ops is a list

        print('Priming generation...')
        for i, x in enumerate(waveform[:-1]):
            if i % 1600 == 0:
                print('Priming sample {}'.format(i), end='\r')
                sys.stdout.flush()
            sess.run(outputs,
                     feed_dict={
                         samples: x,
                         input_ID: args.ID_generation
                     })
        print('Done.')
        endtime_fast_seed = timer()
        predict_of_fast_seed = predict_of_fast_seed + (endtime_fast_seed -
                                                       starttime_fast_and_seed)

    last_sample_timestamp = datetime.now()
    predict = 0
    # In incremental mode only the samples generated from here on are
    # written to the final wav file; otherwise the whole waveform is.
    index_begin_generate = 0 if (False
                                 == args.fast_generation) else len(waveform)
    startime_total_predict_sample = timer()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode: feed only the newest sample and run the
            # push ops alongside the prediction.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            # Feed at most the last `args.window` samples.
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        starttime_run_net_predict = timer()
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  input_ID: args.ID_generation
                              })[0]
        endtime_run_net_predict = timer()
        print('run net to predict samples per step{}'.format(
            endtime_run_net_predict - starttime_run_net_predict))
        predict = predict + (endtime_run_net_predict -
                             starttime_run_net_predict)

        # Scale prediction distribution using temperature. log(0) is
        # expected for zero-probability bins, so the divide warning is
        # suppressed only for this computation and the caller's numpy error
        # settings are restored afterwards.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - np.logaddexp.reduce(
                scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg=
                'Prediction scaling at temperature=1.0 is not working as intended.'
            )

        # Draw the next quantized value according to the scaled
        # probabilities.
        sample = np.random.choice(
            np.arange(quantization_channels),
            p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            sys.stdout.flush()
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            print(
                '$$$$$$$$$$If we have partial writing, save the result so far')
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
    endtime_total_predicttime = timer()
    print('total predic time{}'.format(endtime_total_predicttime -
                                       startime_total_predict_sample))
    print('run net predict time{}'.format(predict))

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as a wav file.
    if args.wav_out_path:
        start_save_wav_time = timer()
        out = sess.run(decode,
                       feed_dict={samples: waveform[index_begin_generate:]})

        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
        end_save_wave_time = timer()
        print('write wave time{}'.format(end_save_wave_time -
                                         start_save_wav_time))

    print('time_of_fast_initops{}'.format(time_of_fast))
    # The '{}' placeholders were missing from the next two messages, so the
    # measured values were never printed.
    print('time_of_restore{}'.format(time_of_restore))
    print('time_of_fast_and_seed{}'.format(predict_of_fast_seed))
    print('time_of_seed{}'.format(time_of_seed))
    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #15
0
def main():
    '''Generate a MIDI piano roll from a trained WaveNet checkpoint.

    The latest checkpoint and ``wavenet_params.json`` are read from
    ``args.resdir``. Starting from a seed file (``args.wav_seed``) or from
    silence followed by one random note, ``args.samples`` time steps are
    generated by thresholding the network output at 0.5. The resulting
    array (one row per step, ``midi_dims`` columns) is passed to
    ``midiwrite`` with a file name of the form
    ``<wav_out_path>sample_<gen_num>.mid``.
    '''
    midi_dims = 88
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    # NOTE: `logdir` is computed but nothing below writes to it; writing the
    # audio summary is disabled in this variant.
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    checkpoint = tf.train.latest_checkpoint(args.resdir)
    print('checkpoint: ', checkpoint)
    # os.path.join works whether or not resdir ends with a path separator.
    wavenet_params_fname = os.path.join(args.resdir, 'wavenet_params.json')
    print('wavenet params fname', wavenet_params_fname)
    with open(wavenet_params_fname, 'r') as config_file:
        wavenet_params = json.load(config_file)
        wavenet_params['midi_dims'] = midi_dims

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        midi_dims=wavenet_params['midi_dims'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.float32, shape=(None, midi_dims))

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Restore every variable except the incremental generator's state
    # buffers and pointers. `var.name[:-2]` strips the last two characters
    # of the variable name (presumably the ':0' output suffix).
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    print('vars to restore')
    for key in variables_to_restore.keys():
        print(key, variables_to_restore[key])

    print('Restoring model from {}'.format(checkpoint))
    saver.restore(sess, checkpoint)

    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           midi_dims, net.receptive_field)
        # Keep the seed as an ndarray: the generation loop below uses 2-D
        # slicing, `.shape` and np.concatenate, none of which work on the
        # nested list that `.tolist()` would produce.
        # NOTE(review): assumes the seed evaluates to (steps, midi_dims),
        # matching the `samples` placeholder; confirm against create_seed.
        waveform = np.asarray(sess.run(seed))
    else:
        # Silence with a single random note at the end.
        random_note = np.zeros((1, midi_dims))
        # randint's upper bound is exclusive, so this covers every note
        # index from 0 to midi_dims - 1.
        random_note[0, np.random.randint(midi_dims)] = 1.0
        waveform = np.concatenate((np.zeros(
            (net.receptive_field - 1, midi_dims)), random_note),
                                  axis=0)

    if args.fast_generation and args.wav_seed:
        print('fast gen')
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            # The placeholder has shape (None, midi_dims), so each step is
            # fed as a one-row batch, as in the generation loop below.
            sess.run(outputs, feed_dict={samples: np.expand_dims(x, 0)})
        print('Done.')

    print('receptive field is %d' % net.receptive_field)
    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode: feed only the newest step and run the push
            # ops alongside the prediction.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = np.expand_dims(waveform[-1, :], 0)
        else:
            # Feed at most the last `receptive_field` steps.
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        print(step, 'wave shape', waveform.shape, 'window shape', window.shape)
        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # No temperature scaling or random sampling is applied here: an
        # entry of the next step is set to 1 when its predicted value
        # exceeds 0.5 and to 0 otherwise.
        sample = 1 * (prediction > 0.5)
        print('num notes', np.count_nonzero(sample))
        waveform = np.concatenate((waveform, np.expand_dims(sample, 0)),
                                  axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as a MIDI file.
    if args.wav_out_path is None:
        # Default to the results directory. Joining with '' guarantees a
        # trailing separator, so the plain concatenation below yields a file
        # inside that directory.
        args.wav_out_path = os.path.join(args.resdir, '')

    print(args.wav_out_path)
    # wav_out_path is used as a plain string prefix of the file name.
    filename = args.wav_out_path + ('sample_%d.mid' % int(args.gen_num))
    midiwrite(filename, waveform)
Beispiel #16
0
def main():
    """Generate audio with a trained WaveNet and store the result.

    Parses the command line with ``get_arguments()``, builds a
    ``WaveNetModel`` from the JSON file at ``args.wavenet_params``, restores
    ``args.checkpoint`` and then draws ``args.samples`` new samples one at a
    time from the predicted distribution. The generated sequence is logged
    as an audio summary below ``args.logdir`` and, when
    ``args.wav_out_path`` is set, written with ``write_wav`` (also every
    ``args.save_every`` steps if that option is set).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # Only the incremental mode runs the initializers and net.init_ops
        # before the checkpoint is restored.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except variables whose name mentions
    # 'state_buffer' or 'pointer'. The last two characters of each name are
    # dropped (presumably the ':0' output suffix) to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        # Without a seed file, start from a single random sample.
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    # The fetch list is the same for every step, so build it once. In fast
    # mode the push ops are fetched together with the prediction.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        # NOTE(review): this skips the last args.window + 1 seed samples
        # rather than only the final one -- confirm that is intended.
        print('Priming generation...')
        for i, x in enumerate(waveform[:-(args.window + 1)]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # The incremental generator only needs the newest sample.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Close the writer so the queued event is flushed to disk before the
    # script ends.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #17
0
def main():
    """Score a trained WaveNet on the test part of a series and plot it.

    Builds a ``WaveNetModel`` from the JSON file at ``args.wavenet_params``,
    restores ``args.checkpoint`` and loads the series from the module-level
    ``fname`` with pandas. For every test window returned by
    ``create_test_dataset`` the most probable class is taken as the
    prediction. The RMSE against the expected values is printed and the
    predictions are plotted over the original data.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except variables whose name mentions
    # 'state_buffer' or 'pointer'; the last two characters of each name are
    # dropped (presumably the ':0' output suffix) to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # The series is read exactly once; the earlier hand-rolled float parse
    # of the same file was never used and has been dropped.
    dataframe = pd.read_csv(fname, engine='python')
    dataset = dataframe.values

    # Index where the test part starts; how the windows are cut from there
    # is decided by create_test_dataset.
    train_size = int(len(dataset) * SPLIT)

    testX, testY = create_test_dataset(dataset, train_size)

    print(str(len(testX)) + " steps to go...")

    testP = []
    outputs = [next_sample]
    for step, window in enumerate(testX):
        # Run the WaveNet on this window and keep the most probable class.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.argmax(prediction)
        testP.append(sample)
        print(step, sample)

    # Column vector of predictions, one row per test window.
    testPredict = np.array(testP).reshape((-1, 1))

    testScore = math.sqrt(mean_squared_error(testY, testPredict[:, 0]))
    print('Test Score: %.2f RMSE' % (testScore))

    # Leave NaN wherever there is no prediction so only the test range is
    # drawn on top of the original series.
    testPredictPlot = np.empty_like(dataset, dtype=float)
    testPredictPlot[:, :] = np.nan
    testPredictPlot[train_size + 1:len(dataset) - 1, :] = testPredict

    # plot baseline and predictions
    plt.plot(dataset)
    plt.plot(testPredictPlot)
    plt.show()

    print('Finished generating.')
Beispiel #18
0
class TestGeneration(tf.test.TestCase):
    '''Sanity checks for the naive and the incremental generation paths.'''

    # Number of quantization levels the test network is built with; the
    # predicted distribution must have exactly this many entries.
    QUANTIZATION_CHANNELS = 128

    def setUp(self):
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256],
            filter_width=2,
            residual_channels=16,
            dilation_channels=16,
            quantization_channels=self.QUANTIZATION_CHANNELS,
            skip_channels=32)

    def _assert_valid_distribution(self, proba):
        '''Check that `proba` is a probability vector over all channels.

        The previous upper bound of 127 could never fail for probabilities,
        so the values are now required to lie in [0, 1] and to sum to one.
        '''
        self.assertAllEqual(proba.shape, [self.QUANTIZATION_CHANNELS])
        self.assertTrue(np.all((proba >= 0) & (proba <= 1)))
        self.assertAllClose(np.sum(proba), 1.0, atol=1e-4)

    def testGenerateSimple(self):
        '''Generate a few samples using the naive method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assert_valid_distribution(proba)

    def testGenerateFast(self):
        '''Generate a few samples using the fast method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS)
        proba = self.net.predict_proba_incremental(waveform)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(self.net.init_ops)
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assert_valid_distribution(proba)

    def testCompareSimpleFast(self):
        '''Both generation methods must agree on the final prediction
        for the same input sequence.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)
        proba_fast = self.net.predict_proba_incremental(waveform)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(self.net.init_ops)
            # Prime the incremental generation with all samples
            # except the last one. The intermediate predictions are not
            # needed; the push ops are fetched alongside each one.
            for x in data[:-1]:
                sess.run([proba_fast, self.net.push_ops],
                         feed_dict={waveform: x})

            # Get the last sample from the incremental generator
            proba_fast_ = sess.run(proba_fast, feed_dict={waveform: data[-1]})
            # Get the sample from the simple generator
            proba_ = sess.run(proba, feed_dict={waveform: data})
            self.assertAllClose(proba_, proba_fast_)
Beispiel #19
0
def main():
    """Generate AP feature frames with a trained, locally conditioned WaveNet.

    Builds a ``WaveNetModel`` from the JSON file at ``args.wavenet_params``,
    restores ``args.checkpoint`` and produces ``args.samples`` frames. Each
    predicted frame (``AP_channels`` values) is extended with the matching
    row of the MFSC array returned by ``load_lc`` and appended to the running
    sequence, which is fed back as input for the next step. When
    ``args.wav_out_path`` is set the result is stored with ``np.save``.

    Raises:
        NotImplementedError: if ``args.wav_seed`` is given; seeding is not
            implemented in this variant.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)

    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Restore everything except variables whose name mentions
    # 'state_buffer' or 'pointer'; the last two characters of each name are
    # dropped (presumably the ':0' output suffix) to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    if args.wav_seed:
        # Nothing below can run without a start sequence and conditioning
        # arrays, so fail with a clear message instead of a later NameError.
        raise NotImplementedError(
            'Seeding the generation via wav_seed is not implemented.')
    else:
        # All-zero history with a single random frame at the end.
        waveform = np.zeros((net.receptive_field - 1, AP_channels + MFSC_channels))
        waveform = np.append(waveform, np.random.randn(1, AP_channels + MFSC_channels), axis=0)

        lc_array, mfsc_array = load_lc(scramble=True) # clip_id:[003, 004, 007, 010, 012 ...]
        # Prepend receptive_field zero rows so that a full conditioning
        # window exists for the very first generated frame.
        lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0),(0,0)))
        mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0),(0,0)))

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation. This cannot run until seeding is implemented.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # NOTE(review): this path looks unfinished. `window` is a single
            # row here, so the reshape below (which reads window.shape[-2])
            # cannot work, and sess.run returns a list for this fetch list.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:, :]
            else:
                window = waveform

        # Run the WaveNet to predict the next frame. A leading batch axis of
        # size one is added to both the input and the conditioning window.
        window = window.reshape(1, window.shape[-2], window.shape[-1])
        lc_window = lc_array[step+1: step+1+net.receptive_field, :]
        prediction = sess.run(outputs, feed_dict={samples: window,
                                                lc: lc_window.reshape(1, net.receptive_field, -1)})
        # Append the MFSC values of the same frame so the row has the same
        # width as the rows already in `waveform`.
        prediction = np.concatenate((prediction, mfsc_array[step + net.receptive_field].reshape(1,-1)), axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        # Temperature control is currently handled inside model.py.

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far (all columns).
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            np.save(args.wav_out_path, waveform)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as a numpy file.
    if args.wav_out_path:
        # Keep only the leading AP columns; the remaining columns are the
        # MFSC values that were appended to every generated row.
        np.save(args.wav_out_path, waveform[:, :AP_channels])

    print('Finished generating. The result can be viewed in TensorBoard.')
Beispiel #20
0
def main(checkpoint=None):
    """Generate text with a trained WaveNet and echo it to the terminal.

    Builds a ``WaveNetModel`` from the JSON file at ``args.wavenet_params``,
    restores ``args.checkpoint`` and samples ``args.samples`` character
    codes. Characters are written to stdout as they are produced; everything
    up to and including the first newline is capitalised as a title. The
    result is stored with ``write_text`` together with a header describing
    the model.

    Args:
        checkpoint: When ``None`` the header is printed before restoring.
            Any other value only changes the message that is printed.
    """
    in_title = True

    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except variables whose name mentions
    # 'state_buffer' or 'pointer'; the last two characters of each name are
    # dropped (presumably the ':0' output suffix) to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    # Text after the last '-' of the checkpoint path, as a plain string so
    # the header does not show list brackets.
    md = args.checkpoint.split("-")[-1]

    # The header is built in every case because write_text needs it below;
    # it used to be undefined whenever a checkpoint argument was passed.
    intro = """\n________________________________________________________________________________________________\n\n\t\t~~~~~~******~~~~~~\n________________________________________________________________________________________________\n\n\tDIR: {}\n\tMODEL: {}\n\tLOSS: {}\n\n
dilations={}\t        filter_width={}\t                residual_channels={}
dilation_channels={}\tquantization_channels={}\tskip_channels={}\n________________________________________________________________________________________________\n\n""".format(
        args.checkpoint.split("/")[-2], md, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['quantization_channels'],
        wavenet_params['skip_channels'])

    if checkpoint is None:
        print(intro)
    else:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
    # NOTE(review): the model is restored from args.checkpoint even when a
    # checkpoint argument is passed -- confirm this is intended.
    saver.restore(sess, args.checkpoint)

    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    # Start from code 32, which chr() maps to a space.
    waveform = [32.]

    # The fetch list is the same for every step, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    for step in range(args.samples):
        if args.fast_generation:
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        # Show character by character in the terminal; the title part is
        # capitalised until the first newline (code 10) has been produced.
        if in_title:
            sys.stdout.write(chr(sample).capitalize())
            if sample == 10:
                in_title = False
                sys.stdout.write("\n\n")
        else:
            sys.stdout.write(chr(sample))

        # Derive a default output path once if none was given.
        if args.text_out_path is None:
            args.text_out_path = "GENERATED/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
                datetime.strftime(datetime.now(), '%Y-%m-%d_%H:%M'),
                args.checkpoint.split("/")[-2],
                args.checkpoint.split("-")[-1], args.loss, args.samples)

        # If we have partial writing, save the result so far.
        if (args.text_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path, intro)

    # Finish the line of characters written to the terminal.
    print()

    # Save the result as a text file.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_text(out, args.text_out_path, intro)
Beispiel #21
0
def main(checkpoint=None):
    """Generate text with a trained WaveNet and store it under GENERATED/.

    Creates the folder ``GENERATED/<args.start_time>``, builds a
    ``WaveNetModel`` from the JSON file at ``args.wavenet_params``, restores
    ``args.checkpoint`` and samples ``args.samples`` character codes.
    Everything up to and including the first newline is title-cased. The
    result is handed to ``write_text`` together with a header describing the
    model and a one-line summary.

    Args:
        checkpoint: When not ``None`` a restore message mentioning this
            value is printed; the model itself is restored from
            ``args.checkpoint`` in both cases.
    """
    title_BOOL = True

    args = get_arguments()

    # create folder in which to put txt files of generated poems
    directory = "GENERATED/" + args.start_time
    if not os.path.exists(directory):
        os.makedirs(directory)

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except variables whose name mentions
    # 'state_buffer' or 'pointer'; the last two characters of each name are
    # dropped (presumably the ':0' output suffix) to form the checkpoint key.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    # Text after the last '-' of the checkpoint path.
    md = ''.join(args.checkpoint.split("-")[-1:])

    # Generated characters are collected here and joined once at the end.
    word_parts = ["\n\n"]

    # The header is built in every case because write_text needs it below;
    # it used to be undefined whenever a checkpoint argument was passed.
    intro = """DIR: {}\tMODEL: {}\t\tLOSS: {}\ndilations: {}\t\t\t\tfilter_width: {}\t\tresidual_channels: {}\ndilation_channels: {}\t\t\tskip_channels: {}\tquantization_channels: {}\n_______________________________________________________________________________________________""".format(
        args.checkpoint.split("/")[-2], md, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['skip_channels'],
        wavenet_params['quantization_channels'])

    if checkpoint is not None:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    # Start from code 32, which chr() maps to a space.
    waveform = [32.]

    # The fetch list is the same for every step, so build it once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    print("")

    for step in range(args.samples):
        print("Generating:", step, "/", args.samples, end="\r")

        if args.fast_generation:
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        # Title-case characters until the first newline (code 10), then
        # store them unchanged.
        if title_BOOL:
            word_parts.append(chr(sample).title())
            if sample == 10:
                title_BOOL = False
                word_parts.append("\n\n")
        else:
            word_parts.append(chr(sample))

        # Derive a default output path once if none was given.
        if args.text_out_path is None:
            args.text_out_path = "GENERATED/{}/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
                args.start_time,
                datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M'),
                args.checkpoint.split("/")[-2],
                args.checkpoint.split("-")[-1], args.loss, args.samples)

        # Partial writing is disabled in this variant (its write call was
        # commented out), so no intermediate result is computed here.

    words = "".join(word_parts)

    ml = "Model: {}  |  Loss: {}  |  {}".format(
        args.checkpoint.split("-")[-1], args.loss,
        args.checkpoint.split("/")[-2])

    # Save the result as a text file.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        # Overwrite the progress line before the final output is written.
        print("                                          ", end="\r")
        write_text(out, args.text_out_path, intro, words, ml)
def main():
    """Sample text from a trained WaveNet checkpoint.

    The command line is parsed with ``get_arguments()``, the network is
    configured from the JSON file at ``args.wavenet_params`` and its weights
    are restored from ``args.checkpoint``. ``args.samples`` symbols are then
    drawn one after another from the predicted distribution. If
    ``args.text_out_path`` is set, the sequence is stored with
    ``write_text`` at the end and also every ``args.save_every`` steps.
    """
    args = get_arguments()
    run_stamp = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', run_stamp)
    with open(args.wavenet_params, 'r') as params_file:
        wavenet_params = json.load(params_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Map checkpoint keys to variables, leaving out every variable whose
    # name mentions 'state_buffer' or 'pointer'. The key is the variable
    # name without its last two characters.
    variables_to_restore = {}
    for variable in tf.all_variables():
        if 'state_buffer' in variable.name or 'pointer' in variable.name:
            continue
        variables_to_restore[variable.name[:-2]] = variable
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # The stored sequence is written out as it is, without decoding.
    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    waveform = [32.]

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        outputs = [next_sample]
        if args.fast_generation:
            # Fetch the push ops together with the prediction and feed only
            # the newest symbol.
            outputs.extend(net.push_ops)
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Ask the network for the distribution of the next symbol and draw
        # from it.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        candidates = np.arange(quantization_channels)
        waveform.append(np.random.choice(candidates, p=prediction))

        # Report progress at most once per second.
        now = datetime.now()
        if (now - last_sample_timestamp).total_seconds() > 1.:
            progress = 'Sample {:3<d}/{:3<d}'.format(step + 1, args.samples)
            print(progress, end='\r')
            last_sample_timestamp = now

        # Store an intermediate result when periodic saving is requested.
        partial_save_due = (args.text_out_path and args.save_every
                            and (step + 1) % args.save_every == 0)
        if partial_save_due:
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path)

    # Move past the progress line, which ends in a carriage return.
    print()

    # Store the final result as text.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_text(out, args.text_out_path)

    print('Finished generating.')
Beispiel #23
0
def _scale_prediction(prediction, temperature):
    """Return `prediction` re-normalised after dividing its log by `temperature`."""
    # log(0) produces -inf for zero-probability entries, which is expected
    # here; silence the divide warning for this computation only and restore
    # the caller's NumPy error state afterwards.
    with np.errstate(divide='ignore'):
        scaled = np.log(prediction) / temperature
        scaled = scaled - np.logaddexp.reduce(scaled)
        return np.exp(scaled)


def main():
    """Generate a waveform with a label-conditioned WaveNet and save it.

    Reads the command-line arguments and the WaveNet parameter JSON file,
    loads the first (audio, labels, filename) item yielded by
    `audio_reader.load_generic_audio_label`, restores the model from
    `args.checkpoint` and then draws one sample per audio frame read
    (`audio_test.shape[0]` iterations). The result is written as a
    TensorBoard audio summary under `args.logdir` and, when
    `args.wav_out_path` is set, as a wav file (optionally also every
    `args.save_every` steps).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    labels = tf.placeholder(tf.float32)

    data_dir = DATA_DIRECTORY
    file_list = FILE_LIST
    # os.path.join gives the same result as plain concatenation when
    # data_dir already ends with a separator, and a valid path when it
    # does not.
    label_dir = os.path.join(data_dir, 'binary_label_norm/')
    audio_dir = os.path.join(data_dir, 'wav/')
    label_dim = 425

    iterator = audio_reader.load_generic_audio_label(file_list, audio_dir,
                                                     label_dir, label_dim)
    # The built-in next() works on both Python 2 and Python 3 iterators.
    audio_test, labels_test, _filename = next(iterator)
    n_samples_read = audio_test.shape[0]
    # Add a leading axis of size 1 in front of the two label axes.
    labels_test = labels_test.reshape(
        (1, labels_test.shape[0], labels_test.shape[1]))

    sess = tf.Session()
    try:
        net = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params['dilations'],
            filter_width=wavenet_params['filter_width'],
            label_dim=label_dim,
            residual_channels=wavenet_params['residual_channels'],
            dilation_channels=wavenet_params['dilation_channels'],
            skip_channels=wavenet_params['skip_channels'],
            quantization_channels=wavenet_params['quantization_channels'],
            use_biases=wavenet_params['use_biases'],
            scalar_input=wavenet_params['scalar_input'],
            initial_filter_width=wavenet_params['initial_filter_width'],
            histograms=False)

        samples = tf.placeholder(tf.int32)

        if args.fast_generation:
            next_sample = net.predict_proba_incremental(samples, labels)
        else:
            next_sample = net.predict_proba(samples, labels)

        if args.fast_generation:
            sess.run(tf.initialize_all_variables())
            sess.run(net.init_ops)

        # Restore every variable except the incremental-generation state
        # (names containing 'state_buffer' or 'pointer'). The last two
        # characters of each variable name are stripped to form the
        # checkpoint key.
        variables_to_restore = {
            var.name[:-2]: var
            for var in tf.all_variables()
            if not ('state_buffer' in var.name or 'pointer' in var.name)
        }
        saver = tf.train.Saver(variables_to_restore)

        print('Restoring model from {}'.format(args.checkpoint))
        saver.restore(sess, args.checkpoint)

        decode = mu_law_decode(samples,
                               wavenet_params['quantization_channels'])

        quantization_channels = wavenet_params['quantization_channels']
        if args.wav_seed:
            seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                               quantization_channels)
            waveform = sess.run(seed).tolist()
        else:
            waveform = np.random.randint(quantization_channels,
                                         size=(1, )).tolist()

        # The fetch list does not change between steps, so build it once.
        # In fast mode the push ops are fetched together with the prediction.
        outputs = [next_sample]
        if args.fast_generation:
            outputs.extend(net.push_ops)

        if args.fast_generation and args.wav_seed:
            # When using the incremental generation, we need to
            # feed in all priming samples one by one before starting the
            # actual generation.
            # TODO This could be done much more efficiently by passing the
            # waveform to the incremental generator as an optional argument,
            # which would be used to fill the queues initially.
            print('Priming generation...')
            for i, x in enumerate(waveform[-args.window:-1]):
                if i % 100 == 0:
                    print('Priming sample {}'.format(i))
                sess.run(outputs,
                         feed_dict={
                             samples: x,
                             labels: labels_test[:, i:i + 1, :]
                         })
            print('Done.')

        last_sample_timestamp = datetime.now()
        for step in range(n_samples_read):
            if args.fast_generation:
                window = waveform[-1]
                labels_window = labels_test[:, step:step + 1, :]
            else:
                if len(waveform) > args.window:
                    window = waveform[-args.window:]
                else:
                    window = waveform
                # NOTE(review): the slice cannot raise if labels_test is a
                # NumPy array (slices are clamped), but near the end of the
                # labels it may be shorter than `window`. It also starts at
                # `step` while `window` holds the most recent samples -
                # confirm this alignment is intended.
                labels_window = labels_test[:, step:step + min(
                    len(window), args.window), :]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs,
                                  feed_dict={
                                      samples: window,
                                      labels: labels_window
                                  })[0]

            # Scale prediction distribution using temperature.
            scaled_prediction = _scale_prediction(prediction,
                                                  args.temperature)

            # Prediction distribution at temperature=1.0 should be unchanged
            # after scaling.
            if args.temperature == 1.0:
                np.testing.assert_allclose(
                    prediction,
                    scaled_prediction,
                    atol=1e-5,
                    err_msg=
                    'Prediction scaling at temperature=1.0 is not working as intended.'
                )

            sample = np.random.choice(np.arange(quantization_channels),
                                      p=scaled_prediction)
            waveform.append(sample)

            # Show progress only once per second. The total is the number of
            # iterations of this loop (n_samples_read), not args.samples.
            # NOTE(review): '{:3<d}' uses '3' as a fill character with no
            # width, so no padding is applied; kept as in the original.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                print('Sample {:3<d}/{:3<d}'.format(step + 1, n_samples_read),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

            # If we have partial writing, save the result so far.
            if (args.wav_out_path and args.save_every
                    and (step + 1) % args.save_every == 0):
                out = sess.run(decode, feed_dict={samples: waveform})
                write_wav(out, wavenet_params['sample_rate'],
                          args.wav_out_path)

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Save the result as an audio summary.
        writer = tf.train.SummaryWriter(logdir)
        tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
        summaries = tf.merge_all_summaries()
        summary_out = sess.run(
            summaries, feed_dict={samples: np.reshape(waveform, [-1, 1])})
        writer.add_summary(summary_out)
        # Flush the pending event to disk and release the event file.
        writer.close()

        # Save the result as a wav file.
        if args.wav_out_path:
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)
    finally:
        # Release the session's resources even if generation fails part-way.
        sess.close()

    print('Finished generating. The result can be viewed in TensorBoard.')