Ejemplo n.º 1
0
 def setUp(self):
     """Create the WaveNetModel instance used by each test."""
     # Single dilation stack: 1, 2, 4, ..., 256.
     stack = [2 ** exponent for exponent in range(9)]
     self.net = WaveNetModel(
         batch_size=1,
         dilations=stack,
         filter_width=2,
         residual_channels=16,
         dilation_channels=16,
         quantization_channels=128,
         skip_channels=32)
Ejemplo n.º 2
0
class TestNet(tf.test.TestCase):
    """End-to-end smoke test: a small WaveNet should overfit a sine chord."""

    def setUp(self):
        # Two identical dilation stacks: 1, 2, 4, ..., 256, repeated twice.
        single_stack = [2 ** exponent for exponent in range(9)]
        self.net = WaveNetModel(
            batch_size=1,
            dilations=single_stack * 2,
            filter_width=2,
            residual_channels=16,
            dilation_channels=16,
            quantization_channels=256,
            skip_channels=32)

    def testEndToEndTraining(self):
        """Train on three superimposed sine waves (an e-flat chord).

        The signal is simple enough that the net should be able to overfit
        it, so this is a smoke test: training runs end-to-end and the loss
        drops by at least two orders of magnitude.
        """
        audio = MakeSineWaves()
        np.random.seed(42)

        loss = self.net.loss(tf.convert_to_tensor(audio, dtype=tf.float32))
        train_op = tf.train.AdamOptimizer(learning_rate=0.02).minimize(
            loss, var_list=tf.trainable_variables())
        init_op = tf.initialize_all_variables()

        threshold = 0.1
        final_loss = threshold
        with self.test_session() as sess:
            sess.run(init_op)
            starting_loss = sess.run(loss)
            for _iteration in range(50):
                final_loss = sess.run([loss, train_op])[0]

        # Sanity check: the untrained loss must start above the threshold.
        self.assertGreater(starting_loss, threshold)

        # After training the loss should be small ...
        self.assertLess(final_loss, threshold)

        # ... and at least two orders of magnitude below where it started.
        self.assertLess(final_loss / starting_loss, 0.01)
Ejemplo n.º 3
0
    def setUp(self):
        """Store training settings on the instance and build a two-stack WaveNet."""
        print('TestNet setup.')
        sys.stdout.flush()

        # Optimisation settings.
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.momentum = MOMENTUM
        self.train_iters = TRAIN_ITERATIONS

        # Feature toggles (both disabled here).
        self.generate = False
        self.global_conditioning = False

        # Two dilation stacks of 1, 2, 4, ..., 64.
        dilation_stack = [2 ** level for level in range(7)]
        self.net = WaveNetModel(
            batch_size=1,
            dilations=dilation_stack + dilation_stack,
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)
Ejemplo n.º 4
0
def main(checkpoint=None):
    """Generate text with a trained WaveNet checkpoint and save it to a file.

    Draws ``args.samples`` samples one at a time and converts each to a
    character with ``chr``. Everything up to and including the first
    newline (sample value 10) is title-cased and followed by a blank line.
    The result is written with ``write_text`` to ``args.text_out_path``
    (defaulting to a file under ``GENERATED/<args.start_time>/``).

    Args:
        checkpoint: Optional checkpoint path. When given, it overrides
            ``args.checkpoint`` both for restoring the model and for the
            labels written into the output. When ``None`` the command-line
            value is used.
    """
    # True until the first newline has been generated; that first line is
    # treated as the title and title-cased.
    in_title = True

    args = get_arguments()
    # Before this fix, an explicit ``checkpoint`` was only printed. The model
    # was still restored from args.checkpoint, and ``intro`` was never
    # defined, so the final write_text call raised NameError.
    checkpoint_path = args.checkpoint if checkpoint is None else checkpoint

    # Folder that receives the generated .txt files.
    directory = "GENERATED/" + args.start_time
    if not os.path.exists(directory):
        os.makedirs(directory)

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # Initialize everything (including the incremental generator's
        # state); the restore below then overwrites the saved weights.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore all variables except the incremental-generation state
    # ('state_buffer' / 'pointer'); name[:-2] drops the ':0' output suffix.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    # Labels for the output header and file name.
    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    model_step = checkpoint_path.split("-")[-1]
    run_dir = checkpoint_path.split("/")[-2]

    intro = """DIR: {}\tMODEL: {}\t\tLOSS: {}\ndilations: {}\t\t\t\tfilter_width: {}\t\tresidual_channels: {}\ndilation_channels: {}\t\t\tskip_channels: {}\tquantization_channels: {}\n_______________________________________________________________________________________________""".format(
        run_dir, model_step, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['skip_channels'],
        wavenet_params['quantization_channels'])

    if checkpoint is not None:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
    saver.restore(sess, checkpoint_path)

    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    # Seed sequence: a single space character (code point 32).
    waveform = [32.]

    # The default output path does not depend on the loop, so compute it once.
    if args.text_out_path is None:
        args.text_out_path = "GENERATED/{}/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
            args.start_time,
            datetime.now().strftime('%Y-%m-%d_%H-%M'),
            run_dir, model_step, args.loss, args.samples)

    # The fetch list is the same for every step.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    # Collected text pieces, joined once at the end (avoids quadratic +=).
    pieces = ["\n\n"]

    print("")

    for step in range(args.samples):
        print("Generating:", step, "/", args.samples, end="\r")

        if args.fast_generation:
            # The incremental generator only needs the most recent sample.
            window = waveform[-1]
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        char = chr(sample)
        if in_title:
            pieces.append(char.title())
            # The first newline ends the title.
            if sample == 10:
                in_title = False
                pieces.append("\n\n")
        else:
            pieces.append(char)

    words = "".join(pieces)

    ml = "Model: {}  |  Loss: {}  |  {}".format(model_step, args.loss, run_dir)

    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        # Blank out the "Generating: ..." progress line.
        print("                                          ", end="\r")
        write_text(out, args.text_out_path, intro, words, ml)

    sess.close()
Ejemplo n.º 5
0
def main():
    """Generate audio with a trained WaveNet checkpoint.

    Restores the model from ``args.checkpoint`` and optionally primes it
    with ``args.wav_seed``. It then draws ``args.samples`` samples from the
    predicted distribution rescaled by ``args.temperature``. The result is
    written as an audio summary under ``<args.logdir>/generate/<timestamp>``
    and, if ``args.wav_out_path`` is set, as a WAV file. The WAV file is also
    rewritten every ``args.save_every`` steps during generation.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # Initialize everything (including the incremental generator's
        # state); the restore below then overwrites the saved weights.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore all variables except the incremental-generation state
    # ('state_buffer' / 'pointer'); name[:-2] drops the ':0' output suffix.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    quantization_channels = wavenet_params['quantization_channels']
    decode = mu_law_decode(samples, quantization_channels)

    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1, )).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to feed in the
        # priming samples one by one before starting the actual generation.
        # Prime with the most recent window *except* the final sample, which
        # the first generation step below feeds itself. (The old slice
        # waveform[:-(args.window + 1)] skipped the most recent samples,
        # leaving a gap between the primed state and the first step.)
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-args.window:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale the prediction distribution by temperature in log space.
        # log(0) = -inf is expected here, so silence only divide-by-zero
        # warnings; errstate restores the caller's previous error settings.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - np.logaddexp.reduce(
                scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg=
                'Prediction scaling at temperature=1.0 is not working as intended.'
            )

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Flush the event file so the summary is on disk before the process exits.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
    sess.close()
Ejemplo n.º 6
0
def main():
    """Train a WaveNetModel on CSV data read by CSVReader.

    Restores from ``restore_from`` when possible, then runs training steps
    up to ``args.num_steps``. Checkpoints are written to ``logdir`` every
    ``args.checkpoint_every`` steps, and once more on exit (including Ctrl-C)
    if the last completed step has not been saved yet.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw data.
    with tf.name_scope('create_inputs'):
        reader = CSVReader(
            args.data_dir,
            coord,
            sample_size=args.sample_size)
        data_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    # A strength of 0 is passed on as None (presumably "no L2 term" in
    # net.loss -- confirm against WaveNetModel.loss).
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(data_batch, args.l2_regularization_strength)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # ``step`` stays None when the loop never runs (e.g. the restored step is
    # already past num_steps). Without this, the ``finally`` clause would hit
    # an UnboundLocalError.
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            loss_value, _ = sess.run([loss, optim])
            print("fin step", step)
            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 7
0
    ]

    args = get_arguments()

    # Load parameters from the WaveNet params JSON file.
    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # NOTE(review): not used in the visible part of this function.
    quantization_channels = wavenet_params['quantization_channels']

    # Initialize the generator WaveNet (batch size 1) from the JSON parameters.
    G = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        initial_filter_width=wavenet_params["initial_filter_width"])

    # NOTE(review): semantics of this sampler are defined elsewhere -- confirm.
    gi_sampler = get_generator_input_sampler()

    # White-noise generator parameters (presumably mean / standard deviation
    # of the noise and its length in samples -- confirm where they are used).
    white_mean = 0
    white_sigma = 1
    white_length = 27117

    # Placeholder for a batch (any size) of white_length-long float32 vectors.
    Z = tf.placeholder(tf.float32, shape=[None, white_length], name='Z')

    # initialize generator
Ejemplo n.º 8
0
class TestGeneration(tf.test.TestCase):
    """Sanity checks for WaveNetModel's naive and incremental generation."""

    # Length of each predicted distribution; input sample values are drawn
    # from [0, QUANTIZATION_CHANNELS).
    QUANTIZATION_CHANNELS = 128

    def setUp(self):
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                                filter_width=2,
                                residual_channels=16,
                                dilation_channels=16,
                                quantization_channels=self.QUANTIZATION_CHANNELS,
                                skip_channels=32)

    def _assertIsDistribution(self, proba):
        """Assert that ``proba`` is a probability vector over all channels.

        The previous check only bounded entries by ``channels - 1`` (127),
        which any probability satisfies trivially. Probabilities must lie in
        [0, 1] and sum to 1 (the generation scripts feed them to
        ``np.random.choice(p=...)``).
        """
        self.assertAllEqual(proba.shape, [self.QUANTIZATION_CHANNELS])
        self.assertTrue(np.all((proba >= 0) & (proba <= 1)))
        self.assertAlmostEqual(float(np.sum(proba)), 1.0, delta=1e-3)

    def testGenerateSimple(self):
        '''Generate a few samples using the naive method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assertIsDistribution(proba)

    def testGenerateFast(self):
        '''Generate a few samples using the fast method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS)
        proba = self.net.predict_proba_incremental(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assertIsDistribution(proba)

    def testCompareSimpleFast(self):
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        # NOTE(review): with size=1 the priming loop below never runs, so only
        # a single-sample comparison is made; a longer sequence would also
        # exercise the incremental state.
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1)
        proba = self.net.predict_proba(waveform)
        proba_fast = self.net.predict_proba_incremental(waveform)
        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            # Prime the incremental generation with all samples
            # except the last one
            for x in data[:-1]:
                sess.run([proba_fast, self.net.push_ops],
                         feed_dict={waveform: x})

            # Get the last sample from the incremental generator
            proba_fast_ = sess.run(
                proba_fast,
                feed_dict={waveform: data[-1]})
            # Get the sample from the simple generator
            proba_ = sess.run(proba, feed_dict={waveform: data})
            self.assertAllClose(proba_, proba_fast_)
Ejemplo n.º 9
0
def _average_tower_gradients(tower_gradients):
    """Average per-tower gradients of shared variables.

    Args:
        tower_gradients: One list of ``(gradient, variable)`` pairs per
            tower, as returned by ``optimizer.compute_gradients``. All towers
            share the same variables, so the i-th pair of every list is
            expected to refer to the same variable.

    Returns:
        A list of ``(average_gradient, variable)`` pairs for
        ``optimizer.apply_gradients``. A variable with no gradient on any
        tower is kept as ``(None, variable)``.

    Also adds a ``<variable name>/gradient`` norm summary for every averaged
    gradient.
    """
    average_gradients = []
    for grouped_gradients in zip(*tower_gradients):
        expanded_gradients = [tf.expand_dims(gradient, 0)
                              for gradient, _ in grouped_gradients
                              if gradient is not None]

        # Since all towers share the same variable, take the one from tower 0.
        _, variable = grouped_gradients[0]
        if not expanded_gradients:
            print('No gradient for %s' % variable.name)
            average_gradients.append((None, variable))
            continue

        merged_gradients = tf.concat(expanded_gradients, 0)
        average_gradient = tf.reduce_mean(merged_gradients, 0)
        average_gradients.append((average_gradient, variable))

        tf.summary.scalar(variable.name + '/gradient',
                          tf.norm(average_gradient))
    return average_gradients


def main():
    """Train a WaveNetModel on audio read from ``args.data_dir``.

    Builds either a single model or one tower per GPU (``args.num_gpus``)
    with averaged gradients, and restores from ``restore_from`` when
    possible. Summaries are written to ``logdir``. Checkpoints are saved
    every ``args.checkpoint_every`` steps, and once more on exit if the last
    step has not been saved yet.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold,
            normalize_peak=args.normalize_peak,
            queue_size=32 * max(args.num_gpus, 1))
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    if args.num_gpus <= 1:
        print("Falling back to single computation unit.")
        audio_batch = reader.dequeue(args.batch_size)
        net = make_model(args, wavenet_params, reader)
        loss = net.loss(
            input_batch=audio_batch,
            global_condition_batch=gc_id_batch,
            l2_regularization_strength=args.l2_regularization_strength)
        optimizer = optimizer_factory[args.optimizer](
            learning_rate=args.learning_rate, momentum=args.momentum)
        trainable = tf.trainable_variables()
        gradients = optimizer.compute_gradients(loss, var_list=trainable)
        for gradient, variable in gradients:
            if gradient is not None:
                tf.summary.scalar(variable.name + '/gradient',
                                  tf.norm(gradient))
        optim = optimizer.apply_gradients(gradients)
    else:
        print("Using {} GPUs for compuation.".format(args.num_gpus))
        with tf.device('/gpu:0'), tf.name_scope('tower_0'):
            optimizer = optimizer_factory[args.optimizer](
                learning_rate=args.learning_rate, momentum=args.momentum)
        losses = []
        tower_gradients = []
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for i in range(args.num_gpus):
                with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
                    audio_batch = reader.dequeue(args.batch_size)
                    net = make_model(args, wavenet_params, reader, i)
                    loss = net.loss(input_batch=audio_batch,
                                    global_condition_batch=gc_id_batch,
                                    l2_regularization_strength=args.
                                    l2_regularization_strength)
                    trainable = tf.trainable_variables()
                    losses.append(loss)
                    tower_gradients.append(
                        optimizer.compute_gradients(loss, var_list=trainable))
                    # Later towers reuse the variables created by the first.
                    scope.reuse_variables()

        with tf.device('/gpu:0'), tf.name_scope('tower_0'):
            loss = tf.reduce_mean(losses)
            tf.summary.scalar('mean_total_loss', loss)
            average_gradients = _average_tower_gradients(tower_gradients)
            optim = optimizer.apply_gradients(average_gradients)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    # Workaround for avoiding allocating memory on all GPUs due to tensorflow#8021
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # ``step`` stays None if the loop never runs, so ``finally`` can tell.
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 10
0
def main():
    """Train the transcription WaveNet and periodically evaluate it.

    Loads model and training hyper-parameters from the JSON files named by
    ``args.model_params`` / ``args.training_params``, builds the input
    pipeline, model, optimizer and TensorBoard summaries, optionally
    restores a checkpoint, and runs the training loop up to
    ``train_params['num_steps']``.

    Runtime switches (metadata tracing, checkpoint / evaluation frequency,
    histogram / image / audio logging) are re-read from ``RUNTIME_SWITCHES``
    on every step so they can be changed while training runs.

    On exit (normal, Ctrl-C or error) the latest unsaved step is
    checkpointed, the input threads are stopped and the writer/session are
    closed.

    Raises:
        ValueError: if a validation pass over ``args.data_dir_valid``
            yields no data at all (otherwise evaluation would loop forever).
    """
    args = get_arguments()

    with open(args.model_params, 'r') as f:
        model_params = json.load(f)

    with open(args.training_params, 'r') as f:
        train_params = json.load(f)

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print('Some arguments are wrong:')
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    receptive_field = WaveNetModel.calculate_receptive_field(
        model_params['filter_width'],
        model_params['dilations'],
        model_params['initial_filter_width'])
    # Save arguments and model params into file
    save_run_config(args, receptive_field, STARTED_DATESTRING, logdir)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Create data loader.
    with tf.name_scope('create_inputs'):
        reader = WavMidReader(data_dir=args.data_dir_train,
                              coord=coord,
                              audio_sample_rate=model_params['audio_sr'],
                              receptive_field=receptive_field,
                              velocity=args.velocity,
                              sample_size=args.sample_size,
                              queues_size=(10, 10*args.batch_size))
        data_batch = reader.dequeue(args.batch_size)

    # Create model.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=model_params['dilations'],
        filter_width=model_params['filter_width'],
        residual_channels=model_params['residual_channels'],
        dilation_channels=model_params['dilation_channels'],
        skip_channels=model_params['skip_channels'],
        output_channels=model_params['output_channels'],
        use_biases=model_params['use_biases'],
        initial_filter_width=model_params['initial_filter_width'])

    # Training batches are pulled from the reader queue and fed through
    # these placeholders, so the same graph can also be run on validation
    # data fed by hand.
    input_data = tf.placeholder(dtype=tf.float32,
                                shape=(args.batch_size, None, 1))
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=(args.batch_size, None,
                                         model_params['output_channels']))

    loss, probs = net.loss(input_data=input_data,
                           input_labels=input_labels,
                           pos_weight=train_params['pos_weight'],
                           l2_reg_str=train_params['l2_reg_str'])
    optimizer = optimizer_factory[args.optimizer](
                    learning_rate=train_params['learning_rate'],
                    momentum=train_params['momentum'])
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()
    histograms = tf.summary.merge_all(key=HKEY)

    # Separate summary ops for validation, since they are
    # calculated only once per evaluation cycle.
    with tf.name_scope('validation_summaries'):

        metric_summaries = metrics_empty_dict()
        metric_value = tf.placeholder(tf.float32)
        for name in metric_summaries.keys():
            metric_summaries[name] = tf.summary.scalar(name, metric_value)

        images_buffer = tf.placeholder(tf.string)
        images_batch = tf.stack(
            [tf.image.decode_png(images_buffer[0], channels=4),
             tf.image.decode_png(images_buffer[1], channels=4),
             tf.image.decode_png(images_buffer[2], channels=4)])
        images_summary = tf.summary.image('estim', images_batch)

        audio_data = tf.placeholder(tf.float32)
        audio_summary = tf.summary.audio('input', audio_data,
                                         model_params['audio_sr'])

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    # Trainer for keeping best validation-performing model
    # and optional early stopping.
    trainer = Trainer(sess, logdir, train_params['early_stop_limit'], 0.999)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except Exception:
        print('Something went wrong while restoring checkpoint. '
              'Training will be terminated to avoid accidentally '
              'overwriting the previous model.')
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    def empty_batch():
        """Return fresh (data, labels, remaining_count) batch accumulators."""
        return (np.empty((0, args.sample_size + receptive_field - 1, 1)),
                np.empty((0, model_params['output_channels'])),
                args.batch_size)

    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, train_params['num_steps']):
            waveform, pianoroll = sess.run([data_batch[0], data_batch[1]])
            feed_dict = {input_data: waveform, input_labels: pianoroll}
            # Reload switches from file on each step
            with open(RUNTIME_SWITCHES, 'r') as f:
                switch = json.load(f)

            start_time = time.time()
            if switch['store_meta'] and step % switch['store_every'] == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    feed_dict=feed_dict,
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  feed_dict=feed_dict)
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % switch['checkpoint_every'] == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

            # Evaluate model performance on validation data
            if step % switch['evaluate_every'] == 0:
                if switch['histograms']:
                    hist_summary = sess.run(histograms)
                    writer.add_summary(hist_summary, step)
                print('evaluating...')
                stats = 0, 0, 0, 0, 0, 0
                est = np.empty([0, model_params['output_channels']])
                ref = np.empty([0, model_params['output_channels']])

                b_data, b_labels, b_cntr = empty_batch()

                # If batch_size * sample_size exceeds the validation data, a
                # single pass may not complete a batch: keep making passes
                # (partial batches carry over) until something is predicted.
                while est.size == 0:
                    chunks_seen = 0
                    for data, labels in reader.single_pass(
                            sess, args.data_dir_valid):
                        chunks_seen += 1

                        # cumulate batch
                        if b_cntr > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            b_cntr -= decr
                            continue
                        elif args.batch_size > 1:
                            b_data, b_labels, decr = cumulateBatch(
                                data, labels, b_data, b_labels)
                            if not decr:
                                continue
                            data = b_data
                            labels = b_labels
                            # reset batch cumulation variables
                            b_data, b_labels, b_cntr = empty_batch()

                        predictions = sess.run(
                            probs, feed_dict={input_data: data})
                        # Aggregate sums for metrics calculation
                        stats_chunk = calc_stats(
                            predictions, labels, args.threshold)
                        stats = tuple(sum(x) for x in zip(stats, stats_chunk))
                        est = np.append(est, predictions, axis=0)
                        ref = np.append(ref, labels, axis=0)

                    if chunks_seen == 0:
                        # An empty validation set would otherwise make this
                        # while loop spin forever.
                        raise ValueError(
                            'No validation data found in {}'.format(
                                args.data_dir_valid))

                metrics = calc_metrics(None, None, None, stats=stats)
                write_metrics(metrics, metric_summaries, metric_value,
                              writer, step, sess)
                trainer.check(metrics['f1_measure'])

                # Render evaluation results
                if switch['log_image'] or switch['log_sound']:
                    sub_fac = int(model_params['audio_sr']/switch['midi_sr'])
                    est = roll_subsample(est.T, sub_fac)
                    ref = roll_subsample(ref.T, sub_fac)
                if switch['log_image']:
                    write_images(est, ref, switch['midi_sr'], args.threshold,
                                 (8, 6), images_summary, images_buffer,
                                 writer, step, sess)
                if switch['log_sound']:
                    write_audio(est, ref, switch['midi_sr'],
                                model_params['audio_sr'], 0.007,
                                audio_summary, audio_data,
                                writer, step, sess)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        # step is still None if the loop never started (e.g. num_steps was
        # already reached, or the first batch failed); comparing None with
        # an int would raise TypeError here and hide the original error.
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
        flush_n_close(writer, sess)
Ejemplo n.º 11
0
    def testGenerateSimple(self):
        """Smoke-test training plus autoregressive generation.

        Trains a small conditioned WaveNet on the CSV fixtures under
        TEST_DATA for 1000 steps. It then generates 100 frames for each of
        the 3x3 (global, local) condition id combinations, feeding each
        prediction back as input. The generated frames are saved as an
        image under ./test/. Nothing is asserted: the test only checks that
        the pipeline runs end to end.
        """
        # Reader config
        with open(TEST_DATA + "/config.json") as json_file:
            self.reader_config = json.load(json_file)

        # Initialize the reader
        receptive_field_size = WaveNetModel.calculate_receptive_field(2, LAYERS, False, 8)

        self.reader = CsvReader(
            [TEST_DATA + "/test.dat", TEST_DATA + "/test.emo", TEST_DATA + "/test.pho"],
            batch_size=1,
            receptive_field=receptive_field_size,
            sample_size=SAMPLE_SIZE,
            config=self.reader_config
        )

        # WaveNet model
        self.net = WaveNetModel(batch_size=1,
                                dilations=LAYERS,
                                filter_width=2,
                                residual_channels=8,
                                dilation_channels=8,
                                skip_channels=8,
                                quantization_channels=2,
                                use_biases=True,
                                scalar_input=False,
                                initial_filter_width=8,
                                histograms=False,
                                global_channels=GC_CHANNELS,
                                local_channels=LC_CHANNELS)

        loss = self.net.loss(input_batch=self.reader.data_batch,
                             global_condition=self.reader.gc_batch,
                             local_condition=self.reader.lc_batch,
                             l2_regularization_strength=L2)

        optimizer = optimizer_factory['adam'](learning_rate=0.003, momentum=0.9)
        trainable = tf.trainable_variables()
        train_op = optimizer.minimize(loss, var_list=trainable)

        # Generation inputs: one receptive field of frames plus integer
        # condition ids, which are one-hot encoded inside the graph.
        samples = tf.placeholder(tf.float32, shape=(receptive_field_size, self.reader.data_dim), name="samples")
        gc_ids = tf.placeholder(tf.int32, shape=(receptive_field_size,), name="gc")
        lc_ids = tf.placeholder(tf.int32, shape=(receptive_field_size,), name="lc")

        gc = tf.one_hot(gc_ids, GC_CHANNELS)
        lc = tf.one_hot(lc_ids, LC_CHANNELS)

        predict = self.net.predict_proba(samples, gc, lc)

        with self.test_session() as session:
            session.run([
                tf.local_variables_initializer(),
                tf.global_variables_initializer(),
                tf.tables_initializer(),
            ])
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            # Always stop and join the queue-runner threads, even if
            # training or generation fails, so they do not outlive the test.
            try:
                for ITER in range(1):

                    for i in range(1000):
                        _, loss_val = session.run([train_op, loss])
                        print("step %d loss %.4f" % (i, loss_val), end='\r')
                        sys.stdout.flush()
                    print()

                    data_samples = np.random.random((receptive_field_size, self.reader.data_dim))
                    gc_samples = np.zeros((receptive_field_size))
                    lc_samples = np.zeros((receptive_field_size))

                    output = []

                    for EMO in range(3):
                        for PHO in range(3):
                            for _ in range(100):
                                # Feed by placeholder object rather than by
                                # tensor-name string, which breaks if the
                                # graph renames ops (e.g. "samples_1").
                                prediction = session.run(predict, feed_dict={
                                    samples: data_samples,
                                    gc_ids: gc_samples,
                                    lc_ids: lc_samples})
                                # Slide each window by one: drop the oldest
                                # frame / id and append the newest.
                                data_samples = data_samples[1:, :]
                                data_samples = np.append(data_samples, prediction, axis=0)

                                gc_samples = gc_samples[1:]
                                gc_samples = np.append(gc_samples, [EMO], axis=0)
                                lc_samples = lc_samples[1:]
                                lc_samples = np.append(lc_samples, [PHO], axis=0)

                                output.append(prediction[0])

                    output = np.array(output)
                    print("ITER %d" % ITER)
                    plt.imsave("./test/SINE_test_%d.png" % ITER, np.kron(output[:, :], np.ones([1, 500])), vmin=0.0, vmax=1.0)
            finally:
                coord.request_stop()
                coord.join(threads)
Ejemplo n.º 12
0
def main():
    """Sample a sequence from a WaveNet whose loss is fed white noise.

    Builds the model from the JSON file named by ``args.wavenet_params``
    and wires its loss to a white-noise input from the generator input
    sampler. It then draws ``args.samples`` values one at a time from the
    predicted distribution, rescaled by ``args.temperature``, and prints
    the resulting sequence.
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    gi_sampler = get_generator_input_sampler()

    # White noise generator params
    white_mean = 0
    white_sigma = 1
    white_length = 20234

    white_noise = gi_sampler(white_mean, white_sigma, white_length)

    # NOTE(review): the loss tensor is never evaluated below; the call is
    # kept in case building it adds graph state the sampler depends on.
    net.loss(input_batch=tf.convert_to_tensor(white_noise,
                                              dtype=np.float32),
             name='generator')

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    # No checkpoint is restored in this script, so the variables must be
    # initialised for BOTH generation modes; previously only the fast path
    # ran the initializer and the default path read uninitialised variables.
    sess.run(tf.global_variables_initializer())
    if args.fast_generation:
        sess.run(net.init_ops)

    quantization_channels = wavenet_params['quantization_channels']
    # Alternative seed (unused): silence with a single random sample at the
    # end, i.e. [quantization_channels / 2] * (net.receptive_field - 1)
    # followed by np.random.randint(quantization_channels).
    waveform = [0]

    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Scale prediction distribution using temperature. log(0) is -inf,
        # which is fine here, so silence only that warning, and restore the
        # previous numpy error state even if something below raises.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction,
                scaled_prediction,
                atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                'is not working as intended.')

        sample = np.random.choice(np.arange(quantization_channels),
                                  p=scaled_prediction)
        waveform.append(sample)

    # Print the generated values without the initial seed value.
    print()
    del waveform[0]
    print(waveform)
    print()

    print('Finished generating. The result can be viewed in TensorBoard.')
def main(checkpoint=None):
    """Generate text byte by byte from a trained WaveNet checkpoint.

    Restores the model weights, then samples ``args.samples`` values,
    starting from the single seed value 32.0, and echoes each one to
    stdout with ``chr``. Characters up to and including the first newline
    (value 10) are shown capitalised. The sequence is written with
    ``write_text`` every ``args.save_every`` steps and once at the end,
    prefixed by a header describing the run.

    Args:
        checkpoint: Optional checkpoint path. When given, weights are
            restored from it instead of from ``args.checkpoint``, and the
            header is not printed (it is still written to the output file).
    """
    title_BOOL = True
    title = ""

    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
    else:
        next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # Restore every variable except the incremental-generation state
    # ('state_buffer' / 'pointer' variables).
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    powr = int((len(wavenet_params['dilations']) / 2) - 1)
    md = args.checkpoint.split("-")[-1:]

    # Header for the output file. It is built on both paths because
    # write_text() below always needs it (previously it was only defined
    # when checkpoint was None, so the other path hit a NameError).
    intro = """\n_______________________________________________________________________________________________________________\n\nDIR: {}\tMODEL: {}\t\tLOSS: {}\n
dilations={}\t        filter_width={}\t                residual_channels={}
dilation_channels={}\tquantization_channels={}\tskip_channels={}\n_______________________________________________________________________________________________________________\n\n""".format(
        args.checkpoint.split("/")[-2], md, args.loss, "2^" + str(powr),
        wavenet_params['filter_width'],
        wavenet_params['residual_channels'],
        wavenet_params['dilation_channels'],
        wavenet_params['quantization_channels'],
        wavenet_params['skip_channels'])

    if checkpoint is None:
        print(intro)
        saver.restore(sess, args.checkpoint)
    else:
        print('Restoring model from PARAMETER {}'.format(checkpoint))
        # Restore from the checkpoint that was announced, not args.checkpoint.
        saver.restore(sess, checkpoint)

    # Identity "decoder": running it echoes the fed values back as int32.
    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    waveform = [32.]

    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)

        # CAPITALIZE TITLE
        if title_BOOL:
            # SHOW character by character in terminal
            sys.stdout.write(chr(sample).capitalize())
            title += chr(sample).capitalize()
            # A newline (value 10) ends the title.
            if sample == 10:
                title_BOOL = False
                sys.stdout.write("\n\n")
        else:
            # SHOW character by character in terminal
            sys.stdout.write(chr(sample))

        if args.text_out_path is None:
            args.text_out_path = "GENERATED/{}_DIR-{}_Model-{}_Loss-{}_Chars-{}.txt".format(
                datetime.strftime(datetime.now(), '%Y-%m-%d_%H:%M'),
                args.checkpoint.split("/")[-2],
                args.checkpoint.split("-")[-1], args.loss, args.samples)

        # If we have partial writing, save the result so far.
        if (args.text_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path, intro)

    # Finish the line of streamed characters.
    print()

    # Save the final result as a text file.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_text(out, args.text_out_path, intro)
Ejemplo n.º 14
0
def main():
    """Generate aperiodicity (AP) frames with a trained conditioned WaveNet.

    Restores the model from ``args.checkpoint`` and loads local-conditioning
    and MFSC features for ``args.clip_id`` via ``load_lc``. It then predicts
    ``args.frames`` AP frames autoregressively. Each predicted AP frame is
    concatenated with the MFSC frame loaded for that step and appended to
    the running ``waveform`` buffer, whose first ``AP_channels`` columns
    are AP values. Only those AP columns are saved to ``args.wav_out_path``
    with ``np.save``.

    Raises:
        NotImplementedError: if ``args.wav_seed`` is set; seeding from a
            wav file is not implemented (previously this crashed later with
            a NameError).
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)

    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    # NOTE(review): the fast-generation path looks incomplete. waveform is
    # a 2-D array, so waveform[-1] below is 1-D and the reshape using
    # window.shape[-2] would fail; sess.run(outputs) would also return a
    # list rather than an array. Confirm before relying on it.
    if args.fast_generation:
        next_sample = net.predict_proba_incremental(
            samples, args.gc_id)  # TODO: confirm the shape of next_sample
    else:
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Restore every variable except the incremental-generation state.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    if args.wav_seed:
        # Seeding was never implemented; without this the code below would
        # fail with a NameError on waveform / lc_array / mfsc_array.
        raise NotImplementedError(
            'Seeding generation from --wav_seed is not supported.')

    # Silence with a single random frame at the end.
    waveform = np.zeros(
        (net.receptive_field - 1, AP_channels + MFSC_channels))
    waveform = np.append(waveform,
                         np.random.randn(1, AP_channels + MFSC_channels),
                         axis=0)

    lc_array, mfsc_array = load_lc(
        clip_id=args.clip_id,
        scramble=args.scramble)  # clip_id:[003, 004, 007, 010, 012 ...]
    # Left-pad both feature arrays by one receptive field so the slices
    # taken per step below line up with the zero-filled start of waveform.
    lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)),
                      'constant',
                      constant_values=((0, 0), (0, 0)))
    mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)),
                        'constant',
                        constant_values=((0, 0), (0, 0)))

    last_sample_timestamp = datetime.now()
    for step in range(args.frames):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:

            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:, :]
            else:
                window = waveform

        # Run the WaveNet to predict the next sample.
        window = window.reshape(1, window.shape[-2], window.shape[-1])

        prediction = sess.run(outputs,
                              feed_dict={
                                  samples:
                                  window,
                                  lc:
                                  lc_array[step + 1:step + 1 +
                                           net.receptive_field, :].reshape(
                                               1, net.receptive_field, -1)
                              })

        # Append the MFSC frame for this step after the predicted AP values.
        prediction = np.concatenate(
            (prediction, mfsc_array[step + net.receptive_field].reshape(1,
                                                                        -1)),
            axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.frames),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            np.save(args.wav_out_path, waveform)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    if args.wav_out_path:
        # Keep only the AP columns: each frame is [AP..., MFSC...], so slice
        # by AP_channels instead of a hard-coded column count.
        np.save(args.wav_out_path, waveform[:, :AP_channels])

    print('Finished generating. The result was saved as .npy file.')
Ejemplo n.º 15
0
def main():
    """Train a WaveNetModel on audio read by AudioReader.

    Loads hyperparameters from a wavenet_params JSON file, builds the input
    pipeline, network, loss and optimizer, optionally restores a checkpoint,
    and runs training steps up to ``args.num_steps``, writing TensorBoard
    summaries and checkpoints (every ``args.checkpoint_every`` steps and once
    more on exit if the last step was not saved).

    Returns early, after printing the reason, if ``validate_directories``
    rejects the command-line arguments.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    # Resolve the wavenet_params file. Accepted forms: a path under
    # "wavenet_params/", a bare name starting with "wavenet_params", or only
    # the suffix that follows "wavenet_params_".
    params_path = args.wavenet_params
    if not params_path.startswith('wavenet_params/'):
        if params_path.startswith('wavenet_params'):
            params_path = 'wavenet_params/' + params_path
        else:
            params_path = 'wavenet_params/wavenet_params_' + params_path
    with open(params_path, 'r') as config_file:
        wavenet_params = json.load(config_file)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=args.histograms,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=reader.gc_category_cardinality)

    # A strength of 0 means "no L2 regularization".
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=audio_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # Stays None if the range below is empty (nothing left to train).
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        # `None > int` raises TypeError on Python 3, so guard the case where
        # no step was run at all.
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 16
0
def main():
    """Build the generator WaveNet and make sure it has initial weights.

    A batch-1 WaveNetModel is created and its loss (scoped as 'generator')
    is computed on a sampled white-noise input, optionally plotted first.
    If ``load`` finds no checkpoint in ``restore_from_init``, one optimizer
    step is run to set up dummy weights and its loss is printed; otherwise
    the restored weights are kept as they are.
    """
    args = get_arguments()

    # Validate the log/restore directories derived from the arguments.
    try:
        directories = validate_directories(args)
    except ValueError as err:
        print("Some arguments are wrong:")
        print(str(err))
        return

    # Only restore_from_init is consumed further down.
    logdir, logdir_init, restore_from, restore_from_init = (
        directories[key]
        for key in ('logdir', 'logdir_init', 'restore_from',
                    'restore_from_init'))

    sample_noise = get_generator_input_sampler()

    session = tf.Session()
    coordinator = tf.train.Coordinator()  # created but not used further here

    # White-noise input, sampled with (mean=0, sigma=1, length=20234).
    noise = sample_noise(0, 1, 20234)
    if args.view_initial_white:
        plt.plot(noise)
        plt.ylabel('Amplitude')
        plt.xlabel('Time')
        plt.show()

    with open(args.wavenet_params, 'r') as params_file:
        params = json.load(params_file)

    generator = WaveNetModel(
        batch_size=1,
        dilations=params["dilations"],
        filter_width=params["filter_width"],
        residual_channels=params["residual_channels"],
        dilation_channels=params["dilation_channels"],
        skip_channels=params["skip_channels"],
        quantization_channels=params["quantization_channels"],
        use_biases=params["use_biases"],
        initial_filter_width=params["initial_filter_width"])

    # The generator's loss call returns a dict with 'loss' and 'output'.
    noise_tensor = tf.convert_to_tensor(noise, dtype=np.float32)
    loss_terms = generator.loss(input_batch=noise_tensor, name='generator')
    loss = loss_terms['loss']
    output = loss_terms['output']

    make_optimizer = optimizer_factory[args.optimizer]
    optimizer = make_optimizer(learning_rate=args.learning_rate,
                               momentum=args.momentum)
    train_op = optimizer.minimize(loss, var_list=tf.trainable_variables())

    # Checkpoint saver over the trainable variables.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    session.run(tf.global_variables_initializer())

    try:
        restored_step = load(saver, session, restore_from_init)
    except BaseException:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise
    if restored_step is None:
        restored_step = -1

    # A checkpoint was found: keep its weights.
    if restored_step != -1:
        print('---------- Loading initial weight ----------')
        print('... Done')
        return

    # No checkpoint: run one optimization step to set up dummy weights.
    print('--------- Begin dummy weight setup ---------')
    t0 = time.time()
    loss_value, _, _ = session.run([loss, train_op, output])
    elapsed = time.time() - t0
    print('loss = {:.3f}, ({:.3f} sec)'.format(loss_value, elapsed))
Ejemplo n.º 17
0
class TestGeneration(tf.test.TestCase):
    """Sanity checks for WaveNetModel's naive and incremental prediction."""

    # Quantization levels of the model under test; also the exclusive upper
    # bound of the random integer inputs fed to it.
    QUANTIZATION_CHANNELS = 128

    def setUp(self):
        # The base setUp resets the default graph (and seeds), so each test
        # builds its model in a fresh graph instead of an ever-growing one.
        super(TestGeneration, self).setUp()
        self.net = WaveNetModel(batch_size=1,
                                dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                                filter_width=2,
                                residual_channels=16,
                                dilation_channels=16,
                                quantization_channels=self.QUANTIZATION_CHANNELS,
                                skip_channels=32)

    def _assert_is_distribution(self, proba):
        """Assert ``proba`` is a probability vector over all channels."""
        self.assertAllEqual(proba.shape, [self.QUANTIZATION_CHANNELS])
        # These are probabilities, not class indices: each lies in [0, 1] ...
        self.assertTrue(np.all((proba >= 0) & (proba <= 1)))
        # ... and together they sum to one.
        self.assertNear(float(np.sum(proba)), 1.0, 1e-4)

    def testGenerateSimple(self):
        '''Generate a few samples using the naive method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assert_is_distribution(proba)

    def testGenerateFast(self):
        '''Generate a few samples using the fast method and
        perform sanity checks on the output.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS)
        proba = self.net.predict_proba_incremental(waveform)

        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            proba = sess.run(proba, feed_dict={waveform: data})

        self._assert_is_distribution(proba)

    def testCompareSimpleFast(self):
        '''Check that incremental and naive prediction agree on the
        distribution for the last sample of a random sequence.'''
        waveform = tf.placeholder(tf.int32)
        np.random.seed(0)
        data = np.random.randint(self.QUANTIZATION_CHANNELS, size=1000)
        proba = self.net.predict_proba(waveform)
        proba_fast = self.net.predict_proba_incremental(waveform)
        with self.test_session() as sess:
            sess.run(tf.initialize_all_variables())
            sess.run(self.net.init_ops)
            # Prime the incremental generation with all samples
            # except the last one; only the pushed state matters here.
            for x in data[:-1]:
                sess.run([proba_fast, self.net.push_ops],
                         feed_dict={waveform: x})

            # Get the last sample from the incremental generator
            proba_fast_ = sess.run(
                proba_fast,
                feed_dict={waveform: data[-1]})
            # Get the sample from the simple generator
            proba_ = sess.run(proba, feed_dict={waveform: data})
            self.assertAllClose(proba_, proba_fast_)
Ejemplo n.º 18
0
def main():
    """Sample audio from a restored WaveNetModel and save the result.

    Restores weights from ``args.checkpoint``. The waveform is optionally
    seeded from ``args.wav_seed``. The model then draws ``args.samples`` new
    codes one at a time from its predicted distribution, using incremental
    generation when ``args.fast_generation`` is set. The mu-law-decoded result
    is written as an audio summary under ``<args.logdir>/generate/<timestamp>``.
    If ``args.wav_out_path`` is given, it is also written as a wav file, and
    additionally every ``args.save_every`` samples.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # Variables excluded from the restore below (names containing
        # 'state_buffer' or 'pointer') still need initial values.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore everything except the incremental-generation state; keys are
    # variable names with their last two characters (presumably ':0') cut.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to feed in the
        # priming samples one by one before starting the actual generation.
        # The generation loop below feeds waveform[-1] itself, so prime with
        # the window of samples immediately preceding it.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-args.window:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode: feed only the newest sample and push state.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            # Naive mode: feed at most the last `args.window` samples.
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.train.SummaryWriter(logdir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    # Flush the event file so the summary is actually on disk.
    writer.close()

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Ejemplo n.º 19
0
def main():
    """Train a WaveNetModel on TFRecord data and report test accuracy.

    Builds train/test pipelines with ``get_tfrecord``, a WaveNetModel, a
    momentum optimizer and a test-accuracy op, optionally restores a
    checkpoint, then runs steps up to ``args.num_steps``. Train/test loss and
    accuracy are printed every step. A checkpoint is written every
    ``args.checkpoint_every`` steps and once more on exit if the last step
    was not saved.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Read TFRecords and create network.
    tf.reset_default_graph()

    data_train = get_tfrecord(name='train',
                              sample_size=args.sample_size,
                              batch_size=args.batch_size,
                              seed=None,
                              repeat=None,
                              data_path=args.data_path)
    data_test = get_tfrecord(name='test',
                             sample_size=args.sample_size,
                             batch_size=args.batch_size,
                             seed=None,
                             repeat=None,
                             data_path=args.data_path)

    train_itr = data_train.make_one_shot_iterator()
    test_itr = data_test.make_one_shot_iterator()

    train_batch, train_label = train_itr.get_next()
    test_batch, test_label = test_itr.get_next()

    # Add a trailing channel axis of size 1.
    train_batch = tf.reshape(train_batch, [-1, train_batch.shape[1], 1])
    test_batch = tf.reshape(test_batch, [-1, test_batch.shape[1], 1])

    # Create network.
    net = WaveNetModel(sample_size=args.sample_size,
                       batch_size=args.batch_size,
                       dilations=wavenet_params["dilations"],
                       filter_width=wavenet_params["filter_width"],
                       residual_channels=wavenet_params["residual_channels"],
                       dilation_channels=wavenet_params["dilation_channels"],
                       skip_channels=wavenet_params["skip_channels"],
                       histograms=args.histograms)

    train_loss = net.loss(train_batch, train_label)
    test_loss = net.loss(test_batch, test_label)

    # Optimizer
    # Temporarily set to momentum optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=args.learning_rate,
                                           momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    # Accuracy of test data: fraction of rounded predictions equal to labels.
    pred_test = net.predict_proba(test_batch, audio_only=True)
    equals = tf.equal(tf.squeeze(test_label), tf.round(pred_test))
    acc = tf.reduce_mean(tf.cast(equals, tf.float32))

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    init2 = tf.local_variables_initializer()
    sess.run([init, init2])

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    fetches = [summaries, train_loss, test_loss, acc, optim]
    # Stays None if the range below is empty (nothing left to train).
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(
                    fetches,
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary_, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                # Regular step: no tracing, so it runs at full speed.
                summary_, train_loss_, test_loss_, acc_, _ = sess.run(fetches)
                writer.add_summary(summary_, step)

            duration = time.time() - start_time
            print("step {:d}:  trainloss = {:.3f}, "
                  "testloss = {:.3f}, acc = {:.3f}, ({:.3f} sec/step)".format(
                      step, train_loss_, test_loss_, acc_, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        # Compare against None explicitly: step 0 is a real training step.
        if step is None:
            print("No training performed during session.")
        elif step > last_saved_step:
            save(saver, sess, logdir, step)
Ejemplo n.º 20
0
def main():
    """Evaluate a restored WaveNetModel on the test split of a CSV series.

    For every test window produced by ``create_test_dataset`` the model's most
    likely next value (argmax over the predicted distribution) is taken as
    the prediction. The RMSE against ``testY`` is printed and the series is
    plotted with the predictions overlaid.

    Relies on the module-level names ``fname`` (CSV path) and ``SPLIT``
    (fraction of rows used for training).
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # Variables excluded from the restore below still need values.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # Load the series once, through pandas (which treats the first line of the
    # file as a header row). The former raw line-by-line float() parse was
    # never used and crashed on a non-numeric header line.
    dataframe = pd.read_csv(fname, engine='python')
    dataset = dataframe.values

    # Split into train and test sets; only the test windows are evaluated.
    train_size = int(len(dataset) * SPLIT)
    testX, testY = create_test_dataset(dataset, train_size)

    print(str(len(testX)) + " steps to go...")

    outputs = [next_sample]
    testP = []
    for step in range(len(testX)):
        window = testX[step]

        # Run the WaveNet and take the most likely next value.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = int(np.argmax(prediction))
        testP.append(sample)
        print(step, sample)

    # One prediction per row, as a column vector.
    testPredict = np.array(testP).reshape((-1, 1))

    testScore = math.sqrt(mean_squared_error(testY, testPredict[:, 0]))
    print('Test Score: %.2f RMSE' % (testScore))

    # Align the predictions with their position in the full series; all other
    # rows stay NaN so they are not drawn.
    testPredictPlot = np.empty_like(dataset, dtype=float)
    testPredictPlot[:, :] = np.nan
    testPredictPlot[train_size + 1:len(dataset) - 1, :] = testPredict

    # plot baseline and predictions
    plt.plot(dataset)
    plt.plot(testPredictPlot)
    plt.show()

    print('Finished generating.')
Ejemplo n.º 21
0
def main():
    """Train a WaveNet vocoder from the command line.

    Command-line options cover the data/log/restore directories and the
    checkpoint interval. All model and training hyperparameters come from the
    module-level ``hparams``. Each ``--data_dir`` entry counts as one speaker,
    and speaker (global) conditioning is enabled when there is more than one.
    Training runs until the global step reaches ``hparams.num_steps`` (the
    final weights are then saved) or the coordinator is asked to stop.
    """
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError(
                'Argument needs to be a boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    parser = argparse.ArgumentParser(description='WaveNet example network')

    # Comma-separated list of data directories (split after parsing).
    DATA_DIRECTORY = '/home/jeon/Desktop/Speech_project/Tacotron-Wavenet-Vocoder/data/TY,/home/jeon/Desktop/Speech_project/Tacotron-Wavenet-Vocoder/data/kss'
    parser.add_argument('--data_dir',
                        type=str,
                        default=DATA_DIRECTORY,
                        help='The directory containing the VCTK corpus.')

    LOGDIR = '/home/jeon/Desktop/Speech_project/Tacotron-Wavenet-Vocoder/logdir-wavenet/train/2019-01-09T16-27-25'

    parser.add_argument(
        '--logdir',
        type=str,
        default=LOGDIR,
        help=
        'Directory in which to store the logging information for TensorBoard. If the model already exists, it will restore the state and will continue training. Cannot use with --logdir_root and --restore_from.'
    )

    parser.add_argument(
        '--logdir_root',
        type=str,
        default=None,
        help=
        'Root directory to place the logging output and generated model. These are stored under the dated subdirectory of --logdir_root. Cannot use with --logdir.'
    )
    parser.add_argument(
        '--restore_from',
        type=str,
        default=None,
        help=
        'Directory in which to restore the model from. This creates the new model under the dated directory in --logdir_root. Cannot use with --logdir.'
    )

    CHECKPOINT_EVERY = 5000  # checkpoint saving interval, in steps
    parser.add_argument(
        '--checkpoint_every',
        type=int,
        default=CHECKPOINT_EVERY,
        help='How many steps to save each checkpoint after. Default: ' +
        str(CHECKPOINT_EVERY) + '.')

    config = parser.parse_args()  # options given on the command line
    config.data_dir = config.data_dir.split(",")

    try:
        directories = validate_directories(config, hparams)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    log_path = os.path.join(logdir, 'train.log')
    infolog.init(log_path, logdir)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Create coordinator.
    coord = tf.train.Coordinator()
    num_speakers = len(config.data_dir)
    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero. NOTE(review): silence_threshold is not passed to the reader
        # below, so it currently has no effect.
        silence_threshold = hparams.silence_threshold if hparams.silence_threshold > EPSILON else None
        gc_enable = num_speakers > 1

        # Original author's note (translated): the reader cuts wav files into
        # inputs of (receptive_field + sample_size) samples; the leading
        # receptive_field samples are padded or taken from the previous piece.
        reader = DataFeederWavenet(
            coord,
            config.data_dir,
            batch_size=hparams.wavenet_batch_size,
            receptive_field=WaveNetModel.calculate_receptive_field(
                hparams.filter_width, hparams.dilations, hparams.scalar_input,
                hparams.initial_filter_width),
            gc_enable=gc_enable)
        if gc_enable:
            audio_batch, lc_batch, gc_id_batch = reader.inputs_wav, reader.local_condition, reader.speaker_id
        else:
            audio_batch, lc_batch = reader.inputs_wav, reader.local_condition
            # No speaker conditioning; gc_id_batch is still passed to
            # net.add_loss below, so it must be defined.
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=hparams.wavenet_batch_size,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        quantization_channels=hparams.quantization_channels,
        out_channels=hparams.out_channels,
        skip_channels=hparams.skip_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        global_condition_channels=hparams.gc_channels,
        global_condition_cardinality=num_speakers,
        local_condition_channels=hparams.num_mels,
        upsample_factor=hparams.upsample_factor,
        train_mode=True)

    # A strength of 0 means "no L2 regularization".
    if hparams.l2_regularization_strength == 0:
        hparams.l2_regularization_strength = None

    net.add_loss(input_batch=audio_batch,
                 local_condition=lc_batch,
                 global_condition_batch=gc_id_batch,
                 l2_regularization_strength=hparams.l2_regularization_strength)
    net.add_optimizer(hparams, global_step)

    run_metadata = tf.RunMetadata()

    # Set up session. log_device_placement=False only disables logging of
    # which device each op is placed on.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(
        var_list=tf.global_variables(),
        max_to_keep=hparams.max_checkpoints)  # max number of checkpoints kept

    try:
        start_step = load(saver, sess, restore_from)  # checkpoint load
        if is_overwritten_training or start_step is None:
            # New or overwritten training: restart the global step from 0.
            zero_step_assign = tf.assign(global_step, 0)
            sess.run(zero_step_assign)
    except Exception:
        print(
            "Something went wrong while restoring checkpoint. We will terminate training to avoid accidentally overwriting the previous model."
        )
        raise

    start_step = sess.run(global_step)
    last_saved_step = start_step
    # Defined before the loop: the store_metadata check reads it on the
    # first iteration, before any sess.run has produced a step.
    step = start_step
    try:
        reader.start_in_session(sess, start_step)
        while not coord.should_stop():

            start_time = time.time()
            if hparams.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                log('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                step, loss_value, _ = sess.run(
                    [global_step, net.loss, net.optimize],
                    options=run_options,
                    run_metadata=run_metadata)

                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                step, loss_value, _ = sess.run(
                    [global_step, net.loss, net.optimize])

            duration = time.time() - start_time
            log('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % config.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

            if step >= hparams.num_steps:
                break

    except Exception as e:
        # Record the failure and tell the input threads to stop.
        log('Training stopped by an exception: {!r}'.format(e))
        coord.request_stop(e)
    else:
        # Normal termination: make sure the final weights are on disk.
        if step > last_saved_step:
            save(saver, sess, logdir, step)
    finally:
        # Also covers KeyboardInterrupt, which the handler above lets through.
        coord.request_stop()
Ejemplo n.º 22
0
class TestNet(tf.test.TestCase):
    """End-to-end training test of a small WaveNet on synthetic sine waves.

    Training data comes from ``make_sine_waves`` (defined elsewhere in this
    module). Per-iteration training error and a periodic validation score are
    pickled to ``complicated_error.pkl`` and ``complicated_validation.pkl`` in
    the current working directory. When ``self.generate`` is set, waveforms
    are generated and checked or saved under ``../data``.
    """

    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        # Read by the global+local conditioning generation branch; it was
        # previously never initialised, which raised AttributeError there.
        self.generate_two_waves = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.local_conditioning = False
        self.train_iters = TRAIN_ITERATIONS
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)

    def _save_net(self, sess):
        """Save the trainable variables of ``sess`` to ``tmp/test.ckpt``."""
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    @staticmethod
    def _feed_dict(audio_placeholder, audio_value,
                   gc_placeholder=None, gc_value=None,
                   lc_placeholder=None, lc_value=None):
        """Build a ``feed_dict``, skipping conditioning placeholders that are None.

        The conditioning placeholders are only created when the corresponding
        conditioning flag is enabled; a disabled one is ``None`` and must not
        be used as a ``feed_dict`` key.
        """
        feed_dict = {audio_placeholder: audio_value}
        if gc_placeholder is not None:
            feed_dict[gc_placeholder] = gc_value
        if lc_placeholder is not None:
            feed_dict[lc_placeholder] = lc_value
        return feed_dict

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def shuffle_row(audio, gc, lc):
            """Randomly swap 10 pairs of rows, keeping audio/gc/lc aligned."""
            from copy import deepcopy
            for _ in range(10):
                index1 = random.randint(0, audio.shape[0] - 1)
                index2 = random.randint(0, audio.shape[0] - 1)
                audio1 = deepcopy(audio[index1, :])
                audio2 = deepcopy(audio[index2, :])
                audio[index1, :] = audio2
                audio[index2, :] = audio1
                lc1 = deepcopy(lc[index1, :])
                lc2 = deepcopy(lc[index2, :])
                lc[index1, :] = lc2
                lc[index2, :] = lc1
                gc1 = deepcopy(gc[index1])
                gc2 = deepcopy(gc[index2])
                gc[index1] = gc2
                gc[index2] = gc1
            return audio, gc, lc

        def CreateTrainingFeedDict(audio, gc, lc, audio_placeholder,
                                   gc_placeholder, lc_placeholder, i):
            """Slice batch ``i`` (reshuffling at the start of each epoch).

            Returns ``(feed_dict, speaker_index, audio, gc, lc)``.
            """
            speaker_index = 0

            i = i % (audio.shape[0] // self.net.batch_size)
            if i == 0:
                audio, gc, lc = shuffle_row(audio, gc, lc)
            batch = slice(i * self.net.batch_size,
                          (i + 1) * self.net.batch_size)
            _audio = audio[batch]
            _gc = gc[batch]
            _lc = lc[batch]

            if gc is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: _audio}
            else:
                # Only feed the conditioning inputs that are enabled; this
                # also covers the "neither enabled" case, which previously
                # left feed_dict unbound.
                feed_dict = self._feed_dict(
                    audio_placeholder, _audio,
                    gc_placeholder if self.global_conditioning else None, _gc,
                    lc_placeholder if self.local_conditioning else None, _lc)
            return feed_dict, speaker_index, audio, gc, lc

        np.random.seed(42)

        receptive_field = self.net.receptive_field
        audio, gc, lc, duration_lists = make_sine_waves(
            self.global_conditioning, self.local_conditioning, True)

        print("shape check 1")
        print(audio.shape)
        print(gc.shape)
        print(lc.shape)

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = (tf.placeholder(dtype=tf.int32)
                          if self.global_conditioning else None)
        lc_placeholder = (tf.placeholder(dtype=tf.int32)
                          if self.local_conditioning else None)

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder,
                             local_condition_batch=lc_placeholder)
        # The validation graph is built with batch size 1, training with 3.
        self.net.batch_size = 1
        validation = self.net.loss(input_batch=audio_placeholder,
                                   global_condition_batch=gc_placeholder,
                                   local_condition_batch=lc_placeholder)
        self.net.batch_size = 3
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        operations = [loss, optim]
        results = None
        with self.test_session() as sess:
            sess.run(init)

            _gc = np.zeros(3)
            # Validation data: 900 samples in three 300-sample segments,
            # labelled 1, 2, 3 and carrying note1, note2, note3 respectively,
            # left-padded with receptive_field - 1 zeros.
            val_lc = np.zeros((1, 900))
            val_lc[0, :300] = 1
            val_lc[0, 300:600] = 2
            val_lc[0, 600:] = 3
            val_lc = np.pad(val_lc, ((0, 0), (receptive_field - 1, 0)),
                            'constant')
            sample_period = 1.0 / SAMPLE_RATE_HZ
            times = np.arange(0.0, SAMPLE_DURATION, sample_period)
            note1 = 0.6 * np.sin(times * 2.0 * np.pi * F1)
            note2 = 0.5 * np.sin(times * 2.0 * np.pi * F2)
            note3 = 0.4 * np.sin(times * 2.0 * np.pi * F3)
            val_audio = np.zeros((1, 900))
            val_audio[0, :300] = note1[:300]
            val_audio[0, 300:600] = note2[300:600]
            val_audio[0, 600:] = note3[600:900]
            val_audio = np.pad(val_audio, ((0, 0), (receptive_field - 1, 0)),
                               'constant')
            val_list = []
            error_list = []

            for i in range(self.train_iters):
                self.net.batch_size = 3
                current_audio = audio[i % 10]
                current_lc = lc[i % 10]
                duration_list = duration_lists[i % 10]
                start_time = 0
                error_total = 0
                # Train on consecutive segments; each slice is extended by
                # receptive_field samples past the segment boundary.
                for duration in duration_list:
                    _audio = current_audio[:, start_time:duration +
                                           receptive_field]
                    _lc = current_lc[:, start_time:duration + receptive_field]
                    start_time = duration

                    [results] = sess.run(
                        [operations],
                        feed_dict=self._feed_dict(
                            audio_placeholder, _audio,
                            gc_placeholder, _gc,
                            lc_placeholder, _lc))
                    error_total += results[0]

                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, results[0]))
                    error_list.append(error_total / len(duration_list))

                    self.net.batch_size = 1
                    validation_score = 0
                    # Distinct loop variable so the training index `i`
                    # printed below is not clobbered.
                    for segment in range(3):
                        window = slice(300 * segment,
                                       300 * (segment + 1) + receptive_field)
                        validation_score += sess.run(
                            validation,
                            feed_dict=self._feed_dict(
                                audio_placeholder, val_audio[:, window],
                                gc_placeholder, _gc[0],
                                lc_placeholder, val_lc[:, window]))
                    val_list.append(validation_score / 3)
                    print("i: %d validation: %f" % (i, validation_score / 3))

            with open('complicated_error.pkl', 'wb') as f1:
                pickle.dump(error_list, f1)

            with open('complicated_validation.pkl', 'wb') as f2:
                pickle.dump(val_list, f2)

            # NOTE(review): the convergence assertions used by the simpler
            # sine-wave tests are intentionally not applied here.

            if self.generate:
                if self.global_conditioning and not self.local_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, True, np.array((0, )))
                    for waveform, gc_id in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform, gc_id)

                elif self.global_conditioning and self.local_conditioning:
                    third = int(GENERATE_SAMPLES / 3)
                    lc = np.concatenate((np.full(third, 1),
                                         np.full(third, 2),
                                         np.full(third, 3)))
                    lc = lc.reshape((lc.shape[0], 1))
                    print(lc.shape)

                    # Incremental (fast) generation.
                    if self.generate_two_waves:
                        gc_ids = np.array((0, 1))
                    else:
                        gc_ids = np.array((0, ))
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, True, gc_ids, lc)
                    for waveform, gc_id in zip(generated_waveforms, ids):
                        if gc_id == 0:
                            np.save("../data/wave_fast", waveform)
                            np.save("../data/lc_fast", lc)
                        else:
                            np.save("../data/wave_t", waveform)

                    # Non-incremental (slow) generation.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, np.array((0, )), lc[:, 0])
                    for waveform, gc_id in zip(generated_waveforms, ids):
                        if gc_id == 0:
                            np.save("../data/wave_slow", waveform)
                            np.save("../data/lc_slow", lc)
                        else:
                            np.save("../data/wave_t", waveform)

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    # Check incremental generation (previously the stale
                    # non-incremental result was checked twice).
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
Ejemplo n.º 23
0
def main():
    """Generate a motion sequence with a trained WaveNet and save it as .mat.

    Restores the model from ``args.checkpoint`` and seeds generation with the
    first ``cut_index`` frames returned by ``create_seed``. It then predicts
    ``args.samples`` further frames one at a time. Each prediction is
    concatenated with the remaining ground-truth channels
    ``gt_list[cut_index + step][input_channels:]`` before being appended.

    When ``args.skeleton_out_path`` is set, the sequence and the ground truth
    are written to ``<skeleton_out_path>/<checkpoint dir>/<seed name>.mat``.

    Raises:
        ValueError: if no motion seed is given or the seed contains NaN.
    """
    args = get_arguments()
    # Validate before the seed path is used: previously create_seed ran
    # first, so this friendlier error could never be reached.
    if not args.motion_seed:
        raise ValueError('motion seed not specified!')

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    input_channels = wavenet_params['input_channels']
    output_channels = wavenet_params['output_channels']
    gt, cut_index = create_seed(
        os.path.join(args.motion_seed, os.path.basename(args.motion_seed)),
        args.window)
    if np.isnan(np.sum(gt)):
        print('nan detected')
        raise ValueError('NAN detected in seed file')
    seed = tf.constant(gt)

    net = WaveNetModel(batch_size=1,
                       dilations=wavenet_params['dilations'],
                       filter_width=wavenet_params['filter_width'],
                       residual_channels=wavenet_params['residual_channels'],
                       dilation_channels=wavenet_params['dilation_channels'],
                       skip_channels=wavenet_params['skip_channels'],
                       use_biases=wavenet_params['use_biases'],
                       input_channels=input_channels,
                       output_channels=output_channels,
                       global_condition_channels=args.gc_channels)

    samples = tf.placeholder(dtype=tf.float32)

    # NOTE(review): both the fast and the non-fast paths build the
    # incremental predictor (the original had identical if/else branches).
    # The commented-out code suggested predict_proba for the non-fast path.
    # Confirm which is intended.
    next_sample = net.predict_proba_incremental(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # Restore everything except the incremental-generation state variables.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    gt_list = sess.run(seed).tolist()
    motion = gt_list[:cut_index]  # motion[i] is the feature list of frame i
    if args.fast_generation:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(motion[-net.receptive_field:-1]):
            if i % 10 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs,
                     feed_dict={samples: np.reshape(x, (1, input_channels))})
        print('Done.')

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = motion[-1]
        else:
            if len(motion) > net.receptive_field:
                window = motion[-net.receptive_field:]
            else:
                window = motion
            outputs = [next_sample]
            # NOTE(review): reshaping a multi-frame window to
            # (1, input_channels) only works for a single frame — confirm.

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(
            outputs,
            feed_dict={samples: np.reshape(window, (1, input_channels))})[0]
        motion.append(
            np.concatenate(
                (prediction, gt_list[cut_index + step][input_channels:]),
                axis=1))
        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    print()

    # save result in .mat file
    if args.skeleton_out_path:
        outdir = os.path.join(
            args.skeleton_out_path,
            os.path.basename(os.path.dirname(args.checkpoint)))
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        filedir = os.path.join(
            outdir,
            str(os.path.basename(args.motion_seed)) + '.mat')
        sio.savemat(filedir, {'wavenet_predict': motion, 'gt': gt})
        print(len(motion))
        print('generated filedir:{0}'.format(filedir))
    print('Finished generating. The result can be viewed in Matlab.')
Ejemplo n.º 24
0
class TestNet(tf.test.TestCase):
    """Smoke test: a small WaveNet should overfit a short sine-wave clip."""

    def setUp(self):
        # Two identical stacks of power-of-two dilations, 1 through 64.
        dilation_stack = [2 ** power for power in range(7)]
        self.net = WaveNetModel(batch_size=1,
                                dilations=dilation_stack + dilation_stack,
                                filter_width=2,
                                residual_channels=32,
                                dilation_channels=32,
                                quantization_channels=256,
                                skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM

    def testEndToEndTraining(self):
        """Train on a chord of three sine waves (an e-flat chord).

        Asserts that the loss starts above 0.1, ends below 0.1, and improves
        by more than two orders of magnitude. When ``self.generate`` is set,
        both non-incremental and incremental generation are checked as well.
        """
        audio = make_sine_waves()
        np.random.seed(42)

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        loss = self.net.loss(audio_tensor)
        build_optimizer = optimizer_factory[self.optimizer_type]
        optimizer = build_optimizer(learning_rate=self.learning_rate,
                                    momentum=self.momentum)
        optim = optimizer.minimize(loss, var_list=tf.trainable_variables())
        init = tf.initialize_all_variables()

        max_allowed_loss = 0.1
        final_loss = max_allowed_loss
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for _iteration in range(TRAIN_ITERATIONS):
                final_loss = sess.run([loss, optim])[0]

            # The untrained loss must exceed the threshold...
            self.assertGreater(initial_loss, max_allowed_loss)
            # ...the trained loss must fall below it...
            self.assertLess(final_loss, max_allowed_loss)
            # ...and training must improve it by more than 100x.
            self.assertLess(final_loss / initial_loss, 0.01)

            if not self.generate:
                return
            # Non-incremental generation, then incremental generation.
            for incremental in (False, True):
                waveform = generate_waveform(sess, self.net, incremental)
                check_waveform(self.assertGreater, waveform)
Ejemplo n.º 25
0
def main():
    """Train a WaveNet on skeleton data, checkpointing every N steps.

    Reads hyper-parameters from the JSON file ``args.wavenet_params`` and
    feeds batches from a ``SkeletonReader`` over ``args.data_dir``. Summaries
    go to TensorBoard in ``logdir``. A checkpoint is saved every
    ``args.checkpoint_every`` steps, plus once more on exit if the last step
    was not saved yet.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return
    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    # Build the skeleton input pipeline.
    with tf.name_scope('create_inputs'):
        gc_enabled = args.gc_channels is not None
        reader = SkeletonReader(
            args.data_dir,
            coord,
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"]),
            input_channels=wavenet_params["input_channels"],
            sample_size=args.sample_size)
        skeleton_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(batch_size=args.batch_size,
                       dilations=wavenet_params["dilations"],
                       filter_width=wavenet_params["filter_width"],
                       residual_channels=wavenet_params["residual_channels"],
                       input_channels=wavenet_params["input_channels"],
                       output_channels=wavenet_params["output_channels"],
                       dilation_channels=wavenet_params["dilation_channels"],
                       skip_channels=wavenet_params["skip_channels"],
                       use_biases=wavenet_params["use_biases"],
                       histograms=args.histograms,
                       global_condition_channels=args.gc_channels)

    # A strength of 0 disables L2 regularization (passed on as None).
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(input_batch=skeleton_batch,
                    global_condition_batch=gc_id_batch,
                    l2_regularization_strength=args.l2_regularization_strength)
    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate,
        momentum=args.momentum,
        epsilon=args.epsilon)
    trainable = tf.trainable_variables()
    total_parameters = 0
    for variable in trainable:
        variable_parameters = 1
        for dim in variable.get_shape():
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('total number of parameters: {0}'.format(total_parameters))
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # `step` stays None if the loop never runs (e.g. the restored step is
    # already past num_steps); the finally-clause must tolerate that.
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        # Comparing None with an int raises TypeError in Python 3.
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 26
0
def main():
    """Generate audio with a mel-conditioned WaveNet using fast generation.

    Loads hyperparameters and a checkpoint from ``config.checkpoint_dir``,
    builds the incremental generation graph on the CPU, conditions it on the
    mel spectrogram stored at ``config.mel`` (tiled over the batch) and writes
    one wav file per batch element to ``<config.logdir>/generate/<timestamp>/``.
    """
    config = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(config.logdir, 'generate', started_datestring)

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    load_hparams(hparams, config.checkpoint_dir)

    with tf.device('/cpu:0'):

        sess = tf.Session()
        scalar_input = hparams.scalar_input
        # During training global_condition_cardinality is determined by the
        # AudioReader; at generation time it has to be passed in explicitly.
        net = WaveNetModel(
            batch_size=config.batch_size,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=hparams.scalar_input,
            initial_filter_width=hparams.initial_filter_width,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            local_condition_channels=hparams.num_mels,
            upsample_factor=hparams.upsample_factor,
            train_mode=False
        )

        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            # Mu-law encoded samples before one-hot conversion:
            # (batch_size, length).
            samples = tf.placeholder(
                tf.int32, shape=[net.batch_size, None]
            )

        # The local condition is (N, T, num_mels) in general, but only one time
        # step is fed per run: (N, 1, num_mels) squeezed to (N, num_mels).
        upsampled_local_condition = tf.placeholder(
            tf.float32, shape=[net.batch_size, hparams.num_mels])

        # Fast WaveNet generation algorithm (arXiv:1611.09482).
        next_sample = net.predict_proba_incremental(
            samples, upsampled_local_condition, [config.gc_id] * net.batch_size
        )

        # Build the upsampled local-condition data that is fed, one time step
        # at a time, into the upsampled_local_condition placeholder.
        mel_input = np.load(config.mel)
        sample_size = mel_input.shape[0] * hparams.hop_size
        mel_input = np.tile(mel_input, (config.batch_size, 1, 1))
        with tf.variable_scope('wavenet', reuse=tf.AUTO_REUSE):
            upsampled_local_condition_data = net.create_upsample(mel_input)

        # Variables whose name contains 'queue' are excluded from the restore.
        var_list = [
            var for var in tf.global_variables() if 'queue' not in var.name
        ]
        saver = tf.train.Saver(var_list)
        print('Restoring model from {}'.format(config.checkpoint_dir))

        load(saver, sess, config.checkpoint_dir)

        # Initialize the generation queues. They are not restored from the
        # checkpoint (see var_list above), so they must be reset explicitly
        # before generation.
        sess.run(net.queue_initializer)

        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # create_seed returns the (mu-law encoded) seed. A seed shorter
            # than the receptive field is returned as-is, without padding.
            seed = create_seed(config.wav_seed, hparams.sample_rate,
                               quantization_channels, net.receptive_field,
                               scalar_input)
            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(
                    seed).tolist()  # e.g. [116, 114, 120, 121, 127, ...]

            print('Priming generation...')
            # The last seed sample is not fed here; the first iteration of the
            # generation loop below feeds it.
            for i, x in enumerate(waveform[-net.receptive_field:-1]):
                if i % 100 == 0:
                    print('Priming sample {}/{}'.format(
                        i, net.receptive_field),
                          end='\r')
                sess.run(next_sample,
                         feed_dict={
                             samples:
                             np.array([x] * net.batch_size).reshape(
                                 net.batch_size, 1),
                             upsampled_local_condition:
                             np.zeros([net.batch_size, hparams.num_mels])
                         })
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] *
                                net.batch_size)
        else:
            # Silence with a single random sample at the end.
            if scalar_input:
                waveform = [0.0] * (net.receptive_field - 1)
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                # Append one random number in [-1, 1) per batch element.
                waveform = np.concatenate(
                    [
                        waveform, 2 * np.random.rand(net.batch_size).reshape(
                            net.batch_size, -1) - 1
                    ],
                    axis=-1)
            else:
                # One sample shorter than the receptive field; a random
                # sample is appended below.
                waveform = [quantization_channels / 2] * (
                    net.receptive_field - 1
                )
                waveform = np.array(waveform * net.batch_size).reshape(
                    net.batch_size, -1)
                # Still before one-hot conversion: (batch_size, receptive_field).
                waveform = np.concatenate(
                    [
                        waveform,
                        np.random.randint(quantization_channels,
                                          size=net.batch_size).reshape(
                                              net.batch_size, -1)
                    ],
                    axis=-1)

        # Number of priming samples at the front of `waveform`; they are
        # stripped from the output below. This equals net.receptive_field
        # unless a wav seed shorter than the receptive field was supplied.
        num_priming_samples = waveform.shape[1]

        start_time = time.time()
        upsampled_local_condition_data = sess.run(
            upsampled_local_condition_data)
        last_sample_timestamp = datetime.now()
        for step in range(sample_size):  # one iteration per output sample

            # Fast generation: only the most recent sample is fed in.
            window = waveform[:, -1:]

            # Run the WaveNet to predict the next sample. `samples` holds
            # mu-law encoded values; the graph converts them to one-hot.
            prediction = sess.run(
                next_sample,
                feed_dict={
                    samples:
                    window,
                    upsampled_local_condition:
                    upsampled_local_condition_data[:, step, :]
                }
            )  # non-scalar case: presumably (batch_size, quantization_channels)

            if scalar_input:
                # Sampled inside the graph (from a logistic distribution), so
                # this value is already random.
                sample = prediction
            else:
                # Scale prediction distribution using temperature: divide the
                # log-probabilities by the temperature and renormalize so each
                # row sums to 1. With temperature 1.0 the (already softmaxed)
                # distribution is unchanged.
                with np.errstate(divide='ignore'):
                    scaled_prediction = np.log(
                        prediction) / config.temperature
                    scaled_prediction = (
                        scaled_prediction - np.logaddexp.reduce(
                            scaled_prediction, axis=-1, keepdims=True)
                    )
                    scaled_prediction = np.exp(scaled_prediction)

                # Prediction distribution at temperature=1.0 should be unchanged after
                # scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(
                        prediction,
                        scaled_prediction,
                        atol=1e-5,
                        err_msg=
                        'Prediction scaling at temperature=1.0 is not working as intended.'
                    )

                # Sampled rather than argmax'ed, so identical inputs can yield
                # different outputs. One sample per batch element.
                sample = [[
                    np.random.choice(np.arange(quantization_channels), p=p)
                ] for p in scaled_prediction]

            waveform = np.concatenate([waveform, sample], axis=-1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                # Average time per generated step so far.
                duration = (time.time() - start_time) / (step + 1)
                print('Sample {:3<d}/{:3<d}, ({:.3f} sec/step)'.format(
                    step + 1, sample_size, duration),
                      end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()

        # Drop the priming samples; keep only the generated audio.
        if scalar_input:
            out = waveform[:, num_priming_samples:]
        else:
            decode = mu_law_decode(samples, quantization_channels)
            out = sess.run(
                decode, feed_dict={samples: waveform[:, num_priming_samples:]})

        # Save one wav file per batch element.
        for i in range(net.batch_size):
            wav_out_path = os.path.join(logdir, 'test-{}.wav'.format(i))
            audio.save_wav(
                out[i],
                wav_out_path,
                hparams.sample_rate,
            )

        print('Finished generating.')
Ejemplo n.º 27
0
def main():
    """Generate audio from a WaveNet checkpoint with optional dynamic temperature.

    The model configuration is read from a JSON file under ``wavenet_params/``.
    ``args.wavenet_params`` may be a path starting with ``wavenet_params/``, a
    file name starting with ``wavenet_params``, or a bare suffix of
    ``wavenet_params_<suffix>``.

    ``args.samples`` mu-law values are sampled one at a time. The sampling
    temperature is either ``args.temperature`` (static, when
    ``args.temperature_change`` is None) or varies over time when
    ``args.temperature_change == "dynamic"``:
      * random, when ``args.tform`` is None;
      * a shaped curve from ``dynamic.generate_value``, otherwise.

    The result is written as a TensorBoard audio summary to the log directory
    and, if ``args.wav_out_path`` is set, as a wav file.

    Raises:
        ValueError: if both ``args.tfrequency`` and ``args.tperiod`` are set.
        Exception: if ``args.temperature_change`` is neither None nor
            "dynamic" (raised inside the generation loop).
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)

    # Resolve the parameter file path.
    if args.wavenet_params.startswith('wavenet_params/'):
        params_path = args.wavenet_params
    elif args.wavenet_params.startswith('wavenet_params'):
        params_path = 'wavenet_params/' + args.wavenet_params
    else:
        params_path = 'wavenet_params/wavenet_params_' + args.wavenet_params
    with open(params_path, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Restore everything except the fast-generation buffers (names containing
    # 'state_buffer' or 'pointer'). var.name[:-2] strips the ':0' suffix so
    # the keys match the checkpoint names.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field,
                           args.silence_threshold)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end.
        waveform = [quantization_channels / 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        outputs = [next_sample]
        outputs.extend(net.push_ops)

        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    last_sample_timestamp = datetime.now()

    # Period of the dynamic temperature curve: converted from --tfrequency,
    # taken from --tperiod, or 1 by default.
    if args.tfrequency is not None:
        if args.tperiod is not None:
            raise ValueError("Frequency and Period both assigned. Assign only one of them.")
        PERIOD = dynamic.frequency_to_period(args.tfrequency)
    elif args.tperiod is not None:
        PERIOD = args.tperiod
    else:
        PERIOD = 1

    # Precompute the per-step temperatures for the "dynamic" mode with a
    # curve shape (args.tform).
    if args.temperature_change == "dynamic" and args.tform is not None:
        temp_array = dynamic.generate_value(0, wavenet_params['sample_rate'], args.tform, args.tmin, args.tmax, PERIOD, args.samples, args.tphaseshift)

    # In "dynamic" mode without a curve shape, a new random temperature is
    # drawn every `resample_every` steps. max(1, ...) avoids a modulo by zero
    # when fewer than 5 samples are requested.
    resample_every = max(1, int(args.samples / 5))

    for step in range(args.samples):
        if args.fast_generation:
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
            else:
                window = waveform
            outputs = [next_sample]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Choose the temperature for this step.
        if args.temperature_change is None:  # static
            _temp_temperature = args.temperature
        elif args.temperature_change == "dynamic":
            if args.tform is None:  # random
                if step % resample_every == 0:
                    _temp_temperature = args.temperature * np.random.rand()
            else:
                _temp_temperature = temp_array[1][step]
        else:
            raise Exception("wrong temperature_change value")

        # Scale prediction distribution using temperature. log(0) is -inf by
        # design here, so divide-by-zero warnings are suppressed locally.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / _temp_temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0 and args.temperature_change is None:
            np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg = 'Prediction scaling at temperature=1.0 is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        # Feed the new sample back: it is the next step's input and part of
        # the generated output.
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}, temperature {:3<f}'.format(step + 1, args.samples, _temp_temperature),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Ejemplo n.º 28
0
def main():
    """Generate feature frames (predicted AP columns + copied MFSC columns).

    Each output frame is the model's prediction (reshaped to ``AP_channels``
    values) concatenated with the matching MFSC frame returned by
    ``load_lc``. The model is conditioned on the local-condition features from
    ``load_lc``. Only the first ``AP_channels`` columns of the generated frames
    are written to ``args.wav_out_path`` with ``np.save``.

    Raises:
        NotImplementedError: if ``args.wav_seed`` is given; seeding is not
            implemented for this generator.
    """
    args = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        MFSC_channels=wavenet_params["MFSC_channels"],
        AP_channels=wavenet_params["AP_channels"],
        F0_channels=wavenet_params["F0_channels"],
        phone_channels=wavenet_params["phones_channels"],
        phone_pos_channels=wavenet_params["phone_pos_channels"])

    samples = tf.placeholder(tf.float32)
    lc = tf.placeholder(tf.float32)

    AP_channels = wavenet_params["AP_channels"]
    MFSC_channels = wavenet_params["MFSC_channels"]

    if args.fast_generation:
        # TODO: confirm the shape of next_sample.
        next_sample = net.predict_proba_incremental(samples,  args.gc_id)
    else:
        outputs = net.predict_proba(samples, AP_channels, lc, args.gc_id)
        outputs = tf.reshape(outputs, [1, AP_channels])

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # Restore everything except the fast-generation buffers; var.name[:-2]
    # strips the ':0' suffix so keys match the checkpoint names.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    if args.wav_seed:
        # Seeding was never implemented for this generator. Without this check
        # the code below fails with a NameError on `waveform`/`lc_array`.
        raise NotImplementedError(
            'Seeding generation from --wav_seed is not supported.')

    # Start from silence (zero frames) with a single random frame at the end.
    waveform = np.zeros((net.receptive_field - 1, AP_channels + MFSC_channels))
    waveform = np.append(waveform, np.random.randn(1, AP_channels + MFSC_channels), axis=0)

    lc_array, mfsc_array = load_lc(scramble=True)  # clip_id: [003, 004, 007, 010, 012, ...]
    # Left-pad both feature arrays with receptive_field zero frames, so padded
    # index (step + receptive_field) refers to original frame `step`.
    lc_array = np.pad(lc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0), (0, 0)))
    mfsc_array = np.pad(mfsc_array, ((net.receptive_field, 0), (0, 0)), 'constant', constant_values=((0, 0), (0, 0)))

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # NOTE(review): this path looks unfinished. `window` is a single 1-D
            # frame here, but the reshape below indexes shape[-2]. Confirm
            # before relying on --fast_generation.
            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]
        elif len(waveform) > net.receptive_field:
            window = waveform[-net.receptive_field:, :]
        else:
            window = waveform

        # Run the WaveNet to predict the next frame's AP values. Temperature
        # control currently lives in model.py.
        window = window.reshape(1, window.shape[-2], window.shape[-1])
        lc_window = lc_array[step + 1: step + 1 + net.receptive_field, :].reshape(
            1, net.receptive_field, -1)
        prediction = sess.run(outputs, feed_dict={samples: window, lc: lc_window})
        # Complete the frame with the known MFSC frame for this step.
        prediction = np.concatenate(
            (prediction, mfsc_array[step + net.receptive_field].reshape(1, -1)),
            axis=-1)
        waveform = np.append(waveform, prediction, axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Frame {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far. Like the final
        # save, keep only the AP columns so both writes have the same layout.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            np.save(args.wav_out_path, waveform[:, :AP_channels])

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as a numpy file. The first AP_channels columns hold the
    # predicted values; the remaining MFSC columns were copied from the input.
    if args.wav_out_path:
        np.save(args.wav_out_path, waveform[:, :AP_channels])

    print('Finished generating.')
Ejemplo n.º 29
0
class TestMoveNet(tf.test.TestCase):
    """Smoke test: train a small WaveNet to map one sine clip onto another.

    testEndToEndTraining writes 'sine_train.wav' and
    'sine_expected_answered.wav' to the working directory and trains on sliding
    windows of the clips. If ``self.generate`` is set, it then runs
    generate_waveform and checks the result with check_waveform.
    """

    def generate_waveform(self, sess):
        """Predict over consecutive windows of 'sine_train.wav' and decode.

        Args:
            sess: session holding the trained variables of ``self.net``.

        Returns:
            The mu-law decoded concatenation of the per-window outputs.
        """
        samples = tf.placeholder(tf.int32)
        next_sample_probs = self.net.predict_proba_all(samples)
        operations = [next_sample_probs]

        waveform = []
        # 'sine_train.wav' is written by testEndToEndTraining.
        seed = create_seed("sine_train.wav",
                           SAMPLE_RATE_HZ,
                           QUANTIZATION_CHANNELS,
                           window_size=WINDOW_SIZE,
                           silence_threshold=0)
        input_waveform = sess.run(seed).tolist()
        decode = mu_law_decode(samples, QUANTIZATION_CHANNELS)
        slide_windows = 256
        for slide_start in range(0, len(input_waveform), slide_windows):
            # Stop at the first incomplete window. A window that ends exactly
            # at the last sample is complete, hence '>' rather than '>='.
            if slide_start + slide_windows > len(input_waveform):
                break
            input_audio_window = input_waveform[slide_start:slide_start +
                                                slide_windows]

            # Run the WaveNet on this window.
            all_prediction = sess.run(operations,
                                      feed_dict={samples:
                                                 input_audio_window})[0]
            all_prediction = np.asarray(all_prediction)
            output_waveform = get_all_output_from_predictions(all_prediction)
            print("Prediction {}".format(output_waveform))
            waveform.extend(output_waveform)

        waveform = np.array(waveform)
        decoded_waveform = sess.run(decode, feed_dict={samples: waveform})
        return decoded_waveform

    def setUp(self):
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=256,
            use_biases=True,
            skip_channels=32)
        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = True
        self.momentum = MOMENTUM

    def testEndToEndTraining(self):
        """Train on sliding windows of the sine clips, then optionally generate."""
        audio, output_audio = make_sine_waves()
        np.random.seed(42)
        librosa.output.write_wav('sine_train.wav', audio, int(SAMPLE_RATE_HZ))
        librosa.output.write_wav('sine_expected_answered.wav', output_audio,
                                 int(SAMPLE_RATE_HZ))

        input_samples = tf.placeholder(tf.float32)
        output_samples = tf.placeholder(tf.float32)

        loss = self.net.loss(input_samples, output_samples)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        # Non-deprecated equivalent of tf.initialize_all_variables().
        init = tf.global_variables_initializer()

        slide_windows = 256
        slide_start = 0
        with self.test_session() as sess:
            sess.run(init)
            for i in range(TRAIN_ITERATIONS):
                # Wrap around once the window would run past the end of either
                # clip; a window ending exactly at the end is still used.
                if slide_start + slide_windows > min(len(audio),
                                                     len(output_audio)):
                    slide_start = 0
                    print("slide from beginning...")
                input_audio_window = audio[slide_start:slide_start +
                                           slide_windows]
                output_audio_window = output_audio[slide_start:slide_start +
                                                   slide_windows]
                slide_start += 1
                loss_val, _ = sess.run(
                    [loss, optim],
                    feed_dict={
                        input_samples: input_audio_window,
                        output_samples: output_audio_window
                    })
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, loss_val))
            if self.generate:
                # Check non-incremental generation
                generated_waveform = self.generate_waveform(sess)
                check_waveform(self.assertGreater, generated_waveform)
Ejemplo n.º 30
0
class TestNet(tf.test.TestCase):
    """Smoke test that trains a small WaveNet on a sine-wave chord.

    ``setUp`` builds a WaveNet with two stacks of dilations 1..64 and stores
    the training configuration (optimizer, learning rate, iteration count,
    generation / global-conditioning flags) read by
    ``testEndToEndTraining``.
    """

    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.train_iters = TRAIN_ITERATIONS
        self.net = WaveNetModel(
            batch_size=1,
            dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
            filter_width=2,
            residual_channels=32,
            dilation_channels=32,
            quantization_channels=QUANTIZATION_CHANNELS,
            skip_channels=32,
            global_condition_channels=None,
            global_condition_cardinality=None)

    @staticmethod
    def _save_net(sess):
        """Save every trainable variable of ``sess``'s graph to a checkpoint.

        Static so it works both as ``self._save_net(sess)`` and as
        ``TestNet._save_net(sess)``. The checkpoint goes to the platform's
        temporary directory; a literal like '\\tmp\\test.ckpt' would contain
        tab escapes rather than path separators.
        """
        import os
        import tempfile

        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join(tempfile.gettempdir(), 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an E-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        """Train on the sine chord and require the loss to fall sharply.

        Asserts that the untrained loss exceeds ``max_allowed_loss``, that
        the trained loss is below it, and that the trained loss is under 2%
        of the initial one. When ``self.generate`` is set, generated audio
        is additionally checked with ``check_waveform``.
        """
        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = (tf.placeholder(dtype=tf.int32)
                          if self.global_conditioning else None)

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        # The whole clip is fed on every step, so one feed dict suffices.
        feed_dict = {audio_placeholder: audio}
        if speaker_ids is not None:
            feed_dict[gc_placeholder] = speaker_ids

        max_allowed_loss = 0.1
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            # Start from the initial loss so loss_val is always bound, even
            # when train_iters is 0.
            loss_val = initial_loss
            for i in range(self.train_iters):
                loss_val, _ = sess.run([loss, optim], feed_dict=feed_dict)
                if i % 10 == 0:
                    print("i: %d loss: %f" % (i, loss_val))

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should drop to under 2% of its pre-training value.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                if self.global_conditioning:
                    # Check non-fast-generated waveforms, one per speaker id.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for waveform, speaker_id in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform,
                                       speaker_id[0])

                    # TODO: fast-generated waveforms with global conditioning
                    # are not checked; re-enable by calling
                    # generate_waveforms(sess, self.net, True, speaker_ids)
                    # and checking each waveform as above.
                else:
                    # Check non-incremental generation.
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    if not self.net.scalar_input:
                        # Check incremental generation: rebind
                        # generated_waveforms so the incremental output, not
                        # the non-incremental one, is what gets checked.
                        generated_waveforms, _ = generate_waveforms(
                            sess, self.net, True, None)
                        check_waveform(self.assertGreater,
                                       generated_waveforms[0], None)
Ejemplo n.º 31
0
class TestNet(tf.test.TestCase):
    """Smoke test that trains a small WaveNet on a sine-wave chord.

    ``setUp`` builds a WaveNet (two stacks of dilations 1..64) inside the
    'test_net' variable scope with ``tf.AUTO_REUSE`` and stores the training
    configuration read by ``testEndToEndTraining``.
    """

    def setUp(self):
        print('TestNet setup.')
        sys.stdout.flush()

        self.optimizer_type = 'sgd'
        self.learning_rate = 0.02
        self.generate = False
        self.momentum = MOMENTUM
        self.global_conditioning = False
        self.train_iters = TRAIN_ITERATIONS

        with tf.variable_scope('test_net', reuse=tf.AUTO_REUSE):
            self.net = WaveNetModel(
                batch_size=1,
                dilations=[1, 2, 4, 8, 16, 32, 64, 1, 2, 4, 8, 16, 32, 64],
                filter_width=2,
                residual_channels=32,
                dilation_channels=32,
                quantization_channels=QC,
                skip_channels=32,
                global_condition_channels=None,
                global_condition_cardinality=None)

    @staticmethod
    def _save_net(sess):
        """Save every trainable variable of ``sess``'s graph to tmp/test.ckpt.

        Static so it works both as ``self._save_net(sess)`` and as
        ``TestNet._save_net(sess)``. The path is relative to the current
        working directory.
        """
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an E-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        """Train on the (left-padded) sine chord and require the loss to drop.

        Asserts that the untrained loss exceeds ``max_allowed_loss``, that
        the trained loss is below it, and that the trained loss is under 2%
        of the initial one. When ``self.generate`` is set, generated audio
        is additionally checked with ``check_waveform``.
        """
        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)
        # Prepend (receptive_field - 1) zeros (silence) along the last axis.
        # The first sample of the training data is 0, and if the network
        # learns to predict silence from silence it will generate only
        # silence.
        if self.global_conditioning:
            audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)),
                           'constant')
        else:
            audio = np.pad(audio, (self.net.receptive_field - 1, 0),
                           'constant')

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = (tf.placeholder(dtype=tf.int32)
                          if self.global_conditioning else None)

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
            learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        # The whole clip is fed on every step, so one feed dict suffices.
        feed_dict = {audio_placeholder: audio}
        if speaker_ids is not None:
            feed_dict[gc_placeholder] = speaker_ids

        max_allowed_loss = 0.1
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            # Start from the initial loss so loss_val is always bound, even
            # when train_iters is 0.
            loss_val = initial_loss
            for i in range(self.train_iters):
                loss_val, _ = sess.run([loss, optim], feed_dict=feed_dict)
                if i % 100 == 0:
                    print("i: %d loss: %f" % (i, loss_val))

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should drop to under 2% of its pre-training value.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                if self.global_conditioning:
                    # Check non-fast-generated waveforms, one per speaker id.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for waveform, speaker_id in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform,
                                       speaker_id[0])

                    # TODO: fast-generated waveforms with global conditioning
                    # are not checked; re-enable by calling
                    # generate_waveforms(sess, self.net, True, speaker_ids)
                    # and checking each waveform as above.
                else:
                    # Check non-incremental generation.
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
                    # Check incremental generation: rebind generated_waveforms
                    # so the incremental output, not the non-incremental one,
                    # is what gets checked.
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(self.assertGreater, generated_waveforms[0],
                                   None)
Ejemplo n.º 32
0
def main():
    """Sample from a trained WaveNet and write the sequence as text.

    Options come from ``get_arguments()``. The network is rebuilt from the
    JSON file ``args.wavenet_params`` and restored from ``args.checkpoint``.
    Generation starts from the single seed value 32 and draws
    ``args.samples`` values, each sampled from the predicted distribution
    over ``quantization_channels`` classes. With ``args.fast_generation``
    the incremental predictor is fed only the newest sample per step;
    otherwise the regular predictor sees the last ``args.window`` samples.
    The sequence is passed to ``write_text`` at ``args.text_out_path`` every
    ``args.save_every`` steps (when both are set) and once at the end (when
    the path is set).
    """
    args = get_arguments()
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples)
        # The 'state_buffer' / 'pointer' variables are excluded from the
        # restore below, so initialize everything (and run net.init_ops)
        # before restoring the checkpoint over the remaining variables.
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples)

    # Restore every variable except the incremental-generation state,
    # keyed by the variable name with its trailing ':0' stripped.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    # No decoding is applied: fetching the fed placeholder just returns the
    # sample values as an int32 array.
    decode = samples

    quantization_channels = wavenet_params['quantization_channels']
    classes = np.arange(quantization_channels)
    waveform = [32.]

    # The fetches are identical on every step, so build them once.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # The incremental predictor only needs the newest sample.
            window = waveform[-1]
        else:
            # At most the last args.window samples (all of them if fewer).
            window = waveform[-args.window:]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(classes, p=prediction)
        waveform.append(sample)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.text_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_text(out, args.text_out_path)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Write the complete generated sequence as text.
    if args.text_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_text(out, args.text_out_path)

    print('Finished generating.')
Ejemplo n.º 33
0
def main():
    """Build a MIDI WaveNet, restore it from its log directory and score the test set.

    The data set and architecture come from ``get_arguments()``. The
    hyper-parameters returned by ``loadParams`` are written to
    ``<logdir>wavenet_params.json``. After ``load`` restores from
    ``logdir``, the mean loss over the sequences returned by
    ``load_all_audio(<data_dir>test/)`` is saved to ``<logdir>test.npz``
    under the key ``test_loss``.
    """
    args = get_arguments()
    data_dir = 'midi-Corpus/' + args.data_set + '/'
    logdir = data_dir + 'max_dilation=%d_reps=%d/' % (args.max_dilation_pow,
                                                      args.expansion_reps)
    print('*************************************************')
    print(logdir)
    print('*************************************************')
    sys.stdout.flush()
    restore_from = logdir
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    # (restore_from is logdir here, so this is currently always False.)
    is_overwritten_training = logdir != restore_from

    wavenet_params = loadParams(args.max_dilation_pow, args.expansion_reps,
                                args.dil_chan, args.res_chan, args.skip_chan)

    with open(logdir + 'wavenet_params.json', 'w') as outfile:
        json.dump(wavenet_params, outfile)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Build the input queue over the MIDI training files.
    with tf.name_scope('create_inputs'):
        # Global conditioning is disabled for this data set.
        gc_enabled = False
        # Data queue for the training set.
        train_dir = data_dir + 'train/'
        train_reader = MidiReader(
            train_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=args.sample_size)
        train_batch = train_reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = train_reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None

    # Create network.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
        histograms=False,
        global_condition_channels=None,
        global_condition_cardinality=train_reader.gc_category_cardinality)
    # A strength of 0 means "no L2 regularization".
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    print('constructing training loss')
    sys.stdout.flush()
    train_loss, target_output, prediction = net.loss(
        input_batch=train_batch,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)

    print('making optimizer')
    sys.stdout.flush()
    optimizer = optimizer_factory['adam'](learning_rate=args.learning_rate,
                                          momentum=args.momentum)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(train_loss, var_list=trainable)

    print('setting up tensorboard')
    sys.stdout.flush()
    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    print('constructing validation loss')
    sys.stdout.flush()
    # Test sequences are fed one at a time as a batch of one; the last
    # dimension is fixed at 88 by this placeholder.
    test_input = tf.placeholder(dtype=tf.float32, shape=(1, None, 88))
    test_loss, test_target_output, test_prediction = net.loss(
        input_batch=test_input,
        global_condition_batch=gc_id_batch,
        l2_regularization_strength=args.l2_regularization_strength)
    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    print('saver')
    sys.stdout.flush()
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=5)

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    test_audio = load_all_audio(data_dir + 'test/')
    num_test_files = len(test_audio)
    test_losses = np.zeros((num_test_files, ))
    for i in range(num_test_files):
        test_i = np.expand_dims(test_audio[i], 0)
        test_losses[i] = sess.run(test_loss, {test_input: test_i})
    test_loss_value = np.mean(test_losses)
    np.savez(logdir + 'test.npz', test_loss=test_loss_value)

    # Flush pending TensorBoard events and release session resources.
    writer.close()
    sess.close()
Ejemplo n.º 34
0
def main():
    """Train a WaveNet on raw audio, checkpointing periodically.

    Directories come from ``validate_directories(args)``. Audio is read by
    an ``AudioReader`` queue; each step runs one Adam update and writes a
    TensorBoard summary. A checkpoint is saved every
    ``args.checkpoint_every`` steps and, on exit (including Ctrl-C), for the
    last step if it was not saved yet.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero (values <= EPSILON become None).
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    # A strength of 0 means "no L2 regularization".
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    loss = net.loss(audio_batch, args.l2_regularization_strength)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # Bound before the try so the finally clause can always read them, even
    # when the loop body never executes.
    step = last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
                  .format(step, loss_value, duration))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 35
0
def main(waveform, num_predictions,
         checkpoint_path="logdir/train/last144kwl2/model.ckpt-27050"):
    """Continue ``waveform`` with samples generated by a trained WaveNet.

    The graph is reset, a WaveNet is rebuilt from ``WAVENET_PARAMS`` and
    restored from ``checkpoint_path``. Generation runs on CPU only with the
    regular (non-incremental) predictor.

    Args:
        waveform: Priming audio handed to ``create_seed``. The seed it
            produces becomes the start of the generated sequence.
        num_predictions: Number of new samples to draw.
        checkpoint_path: Checkpoint to restore. The default is the path
            that was previously hard-coded.

    Returns:
        ``mu_law_decode`` of the seed followed by the ``num_predictions``
        generated samples.
    """
    # Reduce TensorFlow's C++ log verbosity.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    with open(WAVENET_PARAMS, 'r') as config_file:
        wavenet_params = json.load(config_file)
    tf.reset_default_graph()
    # GPUs are hidden deliberately so test output can be generated on
    # machines without one.
    config = tf.ConfigProto(device_count={'GPU': 0})
    sess = tf.Session(config=config)
    try:
        net = WaveNetModel(
            batch_size=1,
            dilations=wavenet_params['dilations'],
            filter_width=wavenet_params['filter_width'],
            residual_channels=wavenet_params['residual_channels'],
            dilation_channels=wavenet_params['dilation_channels'],
            quantization_channels=wavenet_params['quantization_channels'],
            skip_channels=wavenet_params['skip_channels'],
            use_biases=wavenet_params['use_biases'],
            scalar_input=wavenet_params['scalar_input'],
            initial_filter_width=wavenet_params['initial_filter_width'],
            global_condition_channels=None,
            global_condition_cardinality=None)

        samples = tf.placeholder(tf.int32)

        # Regular generation: each step re-runs the network over the most
        # recent receptive_field samples. (Incremental generation via
        # net.predict_proba_incremental is not used here.)
        next_sample = net.predict_proba(samples, None)

        # Restore every variable except incremental-generation state,
        # keyed by the variable name with its trailing ':0' stripped.
        variables_to_restore = {
            var.name[:-2]: var for var in tf.global_variables()
            if not ('state_buffer' in var.name or 'pointer' in var.name)}
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, checkpoint_path)

        quantization_channels = wavenet_params['quantization_channels']
        decode = mu_law_decode(samples, quantization_channels)
        classes = np.arange(quantization_channels)

        seed = create_seed(waveform,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field)
        generated = sess.run(seed).tolist()

        for _ in range(num_predictions):
            # At most the last receptive_field samples (all if fewer).
            window = generated[-net.receptive_field:]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(next_sample, feed_dict={samples: window})

            # Renormalize the distribution in log space. log(0) is expected
            # for zero-probability classes, so suppress that warning only
            # here; errstate restores the previous settings afterwards.
            with np.errstate(divide='ignore'):
                log_prediction = np.log(prediction)
                log_prediction = (log_prediction -
                                  np.logaddexp.reduce(log_prediction))
                scaled_prediction = np.exp(log_prediction)

            # TODO consider taking the most likely class instead of sampling.
            sample = np.random.choice(classes, p=scaled_prediction)
            generated.append(sample)

        return sess.run(decode, feed_dict={samples: generated})
    finally:
        # Release the session even if restoring or generation fails.
        sess.close()
Ejemplo n.º 36
0
def main():
    """Generate a piano-roll sequence with a trained WaveNet and save it as MIDI.

    Loads ``wavenet_params.json`` and the latest checkpoint from
    ``args.resdir``, builds a batch-size-1 ``WaveNetModel`` whose inputs are
    ``midi_dims``-wide frames, and autoregressively generates ``args.samples``
    new frames.

    Generation is primed either from ``args.wav_seed`` (via ``create_seed``)
    or from silence followed by a single random active note. Each new frame
    is the network prediction thresholded element-wise at 0.5.

    The resulting roll is written with ``midiwrite`` to
    ``args.wav_out_path + 'sample_<gen_num>.mid'``. ``wav_out_path``
    defaults to ``args.resdir``.
    """
    midi_dims = 88  # presumably one channel per piano key -- TODO confirm
    args = get_arguments()
    checkpoint = tf.train.latest_checkpoint(args.resdir)
    print('checkpoint: ', checkpoint)
    # resdir is a checkpoint directory; joining avoids a broken file name
    # when it is given without a trailing separator.
    wavenet_params_fname = os.path.join(args.resdir, 'wavenet_params.json')
    print('wavenet params fname', wavenet_params_fname)
    with open(wavenet_params_fname, 'r') as config_file:
        wavenet_params = json.load(config_file)
    wavenet_params['midi_dims'] = midi_dims

    sess = tf.Session()

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        midi_dims=wavenet_params['midi_dims'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality)

    samples = tf.placeholder(tf.float32, shape=(None, midi_dims))

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(samples, args.gc_id)
        # The incremental generator's state must be initialised explicitly.
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)
    else:
        next_sample = net.predict_proba(samples, args.gc_id)

    # Restore everything except the incremental-generation queue state
    # ('state_buffer' / 'pointer' variables); [:-2] strips the ':0' suffix.
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    print('vars to restore')
    for name, var in variables_to_restore.items():
        print(name, var)

    print('Restoring model from {}'.format(checkpoint))
    saver.restore(sess, checkpoint)

    if args.wav_seed:
        seed = create_seed(args.wav_seed, wavenet_params['sample_rate'],
                           midi_dims, net.receptive_field)
        # Keep the seed as an ndarray: the loop below uses waveform[-1, :]
        # and waveform.shape. Assumes the seed is (frames, midi_dims) --
        # TODO confirm against create_seed.
        waveform = np.asarray(sess.run(seed))
    else:
        # Silence with a single random note in the last frame.
        random_note = np.zeros((1, midi_dims))
        # randint's upper bound is exclusive, so every one of the midi_dims
        # notes can be chosen.
        random_note[0, np.random.randint(midi_dims)] = 1.0
        waveform = np.concatenate(
            (np.zeros((net.receptive_field - 1, midi_dims)), random_note),
            axis=0)

    # In incremental mode every run must also push the fed frame into the
    # generator's queues.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        print('fast gen')
        # When using the incremental generation, we need to feed in all
        # priming frames one by one before starting the actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        print('Priming generation...')
        for i, frame in enumerate(waveform[-net.receptive_field:-1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            # The placeholder is (None, midi_dims): feed each frame as one row.
            sess.run(outputs, feed_dict={samples: np.expand_dims(frame, 0)})
        print('Done.')

    print('receptive field is %d' % net.receptive_field)
    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if args.fast_generation:
            # Incremental mode only needs the most recent frame.
            window = np.expand_dims(waveform[-1, :], 0)
        else:
            # Full mode re-runs the net over at most one receptive field
            # (the whole sequence while it is still shorter than that).
            window = waveform[-net.receptive_field:]

        print(step, 'wave shape', waveform.shape, 'window shape', window.shape)
        # Run the WaveNet to predict the next frame.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]

        # Each channel is switched on independently when its predicted value
        # exceeds 0.5 (no temperature sampling is applied here).
        sample = 1 * (prediction > 0.5)
        print('num notes', np.count_nonzero(sample))
        waveform = np.concatenate((waveform, np.expand_dims(sample, 0)),
                                  axis=0)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples),
                  end='\r')
            last_sample_timestamp = current_sample_timestamp

    # Introduce a newline to clear the carriage return from the progress.
    print()
    sess.close()

    # Save the result as a MIDI file.
    if args.wav_out_path is None:
        args.wav_out_path = args.resdir

    print(args.wav_out_path)
    filename = args.wav_out_path + ('sample_%d.mid' % int(args.gen_num))
    midiwrite(filename, waveform)
Ejemplo n.º 37
0
def main():
    """Train a WaveNet on the audio corpus in ``args.data_dir``.

    Builds an ``AudioReader`` input pipeline and a ``WaveNetModel`` from the
    JSON hyper-parameters in ``args.wavenet_params``. Gradients are computed
    on a single tower pinned to ``/gpu:0`` and applied on ``/cpu:0``.

    Every ``args.checkloss_every`` steps it prints the loss and L1 term,
    averaged over the steps since the previous report. It writes a checkpoint
    every ``args.checkpoint_every`` steps. On exit, including Ctrl-C, it saves
    any progress made since the last checkpoint.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero: any threshold <= EPSILON is replaced by None.
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size,
                             silence_threshold=silence_threshold)

    # Create network. The batch is divided by num_gpus, although only one
    # tower (/gpu:0) is built below.
    batch_size_single_GPU = int(1.0 * args.batch_size / args.num_gpus)
    net = WaveNetModel(
        batch_size=batch_size_single_GPU,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        ID_channels=wavenet_params["ID_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],  # scalar vs. vector input?
        initial_filter_width=wavenet_params["initial_filter_width"])
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    optimizer = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)
    trainable = tf.trainable_variables()

    # Single tower; a multi-GPU version would loop over range(args.num_gpus).
    tower_grads = []
    with tf.device('/gpu:0'):
        with tf.name_scope('losstower_0'):
            audio_batch, input_IDs = reader.dequeue(batch_size_single_GPU)
            all_loss = net.loss(audio_batch, input_IDs,
                                args.l2_regularization_strength)
            # all_loss is (total loss, L1 term); only the total is optimised.
            loss, _ = all_loss
            tf.get_variable_scope().reuse_variables()
            grads_vars = optimizer.compute_gradients(loss, var_list=trainable)
            tower_grads.append(grads_vars)
    update_wei_op = []
    with tf.device('/cpu:0'):
        for gv in tower_grads:
            update_wei_op.append(optimizer.apply_gradients(gv))

    with tf.control_dependencies(update_wei_op):
        train_op = tf.no_op()

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess, N_THREADS)

    # step stays None if the loop body never runs (e.g. training finished).
    step = None
    last_saved_step = saved_global_step
    try:
        loss_sum = 0.0
        l1_sum = 0.0
        steps_in_window = 0
        start_time = time.time()
        for step in range(saved_global_step + 1, args.num_steps):
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, all_loss_value, _ = sess.run(
                    [summaries, all_loss, train_op],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                all_loss_value, _ = sess.run([all_loss, train_op])
            loss_value, L1_value = all_loss_value
            loss_sum += loss_value
            l1_sum += L1_value
            steps_in_window += 1

            if step % args.checkloss_every == 0:
                # Average over the steps actually accumulated since the last
                # report; the first window may be shorter than checkloss_every.
                avg_loss_value = loss_sum / steps_in_window
                avg_L1_value = l1_sum / steps_in_window
                duration = (time.time() - start_time) / steps_in_window
                print(
                    'step {:d} - avg_loss = {:.3f}, avg_L1 = {:.3f}, ({:.3f} sec/step)'
                    .format(step, avg_loss_value, avg_L1_value, duration))
                sys.stdout.flush()
                loss_sum = 0.0
                l1_sum = 0.0
                steps_in_window = 0
                start_time = time.time()

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Ejemplo n.º 38
0
def main():
    """Evaluate a trained WaveNet on the test set in ``args.data_dir_test``.

    Restores the trainable variables from the checkpoint in ``args.logdir``
    and runs the model over every chunk yielded by
    ``WavMidReader.single_pass``. For each chunk it accumulates
    ``calc_stats`` counts.

    Afterwards it writes the resulting metrics and saves the sub-sampled
    estimated and reference rolls as ``est.npy`` / ``ref.npy`` in the log
    directory. With ``args.media`` it also renders images and audio of the
    results.
    """
    args = get_arguments()

    if args.logdir is not None and os.path.isdir(args.logdir):
        logdir = args.logdir
    else:
        print('Argument --logdir=\'{}\' is not (but should be) '
              'a path to valid directory.'.format(args.logdir))
        return

    with open(args.model_params, 'r') as f:
        model_params = json.load(f)
    with open(RUNTIME_SWITCHES, 'r') as f:
        switch = json.load(f)

    receptive_field = WaveNetModel.calculate_receptive_field(
        model_params['filter_width'],
        model_params['dilations'],
        model_params['initial_filter_width'])

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Create data loader.
    with tf.name_scope('create_inputs'):
        reader = WavMidReader(data_dir=args.data_dir_test,
                              coord=coord,
                              audio_sample_rate=model_params['audio_sr'],
                              receptive_field=receptive_field,
                              velocity=args.velocity,
                              sample_size=args.sample_size,
                              queues_size=(100, 100*BATCH_SIZE))

    # Create model.
    net = WaveNetModel(
        batch_size=BATCH_SIZE,
        dilations=model_params['dilations'],
        filter_width=model_params['filter_width'],
        residual_channels=model_params['residual_channels'],
        dilation_channels=model_params['dilation_channels'],
        skip_channels=model_params['skip_channels'],
        output_channels=model_params['output_channels'],
        use_biases=model_params['use_biases'],
        initial_filter_width=model_params['initial_filter_width'])

    input_data = tf.placeholder(dtype=tf.float32,
                                shape=(BATCH_SIZE, None, 1))
    input_labels = tf.placeholder(dtype=tf.float32,
                                  shape=(BATCH_SIZE, None,
                                         model_params['output_channels']))

    _, probs = net.loss(input_data=input_data,
                        input_labels=input_labels,
                        pos_weight=1.0,
                        l2_reg_str=None)

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        load(saver, sess, logdir)
    except Exception:
        print('Something went wrong while restoring checkpoint.')
        raise

    try:
        stats = 0, 0, 0, 0, 0, 0
        n_channels = model_params['output_channels']
        # Collect per-chunk rolls and concatenate once at the end; calling
        # np.append inside the loop would copy the whole accumulated array
        # on every chunk. The leading empty arrays keep the result well
        # defined when the reader yields no chunks.
        est_chunks = [np.empty([n_channels, 0])]
        ref_chunks = [np.empty([n_channels, 0])]
        sub_fac = int(model_params['audio_sr']/switch['midi_sr'])
        for data, labels in reader.single_pass(sess,
                                               args.data_dir_test):

            predictions = sess.run(probs, feed_dict={input_data: data})
            # Aggregate sums for metrics calculation
            stats_chunk = calc_stats(predictions, labels, args.threshold)
            stats = tuple(sum(x) for x in zip(stats, stats_chunk))
            est_chunks.append(roll_subsample(predictions.T, sub_fac))
            ref_chunks.append(roll_subsample(labels.T, sub_fac, b=True))
        est = np.concatenate(est_chunks, axis=1)
        ref = np.concatenate(ref_chunks, axis=1)

        metrics = calc_metrics(None, None, None, stats=stats)
        write_metrics(metrics, None, None, None, None, None, logdir=logdir)

        # Save subsampled data for further arbitrary evaluation
        np.save(os.path.join(logdir, 'est.npy'), est)
        np.save(os.path.join(logdir, 'ref.npy'), ref)

        # Render evaluation results
        if args.media:
            figsize = (int(args.plot_scale*est.shape[1]/switch['midi_sr']),
                       int(args.plot_scale*n_channels/12))
            write_images(est, ref, switch['midi_sr'],
                         args.threshold, figsize,
                         None, None, None, 0, None,
                         noterange=(21, 109),
                         legend=args.plot_legend,
                         logdir=logdir)
            write_audio(est, ref, switch['midi_sr'],
                        model_params['audio_sr'], 0.007,
                        None, None, None, 0, None, logdir=logdir)

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        coord.request_stop()
Ejemplo n.º 39
0
class TestNet(tf.test.TestCase):
    def setUp(self):
        """Create default training settings and a small two-stack WaveNet.

        Each stack uses dilations 1, 2, 4, ..., 64. Global conditioning is
        disabled.
        """
        print('TestNet setup.')
        sys.stdout.flush()

        # Optimisation settings consumed by testEndToEndTraining.
        self.train_iters = TRAIN_ITERATIONS
        self.momentum = MOMENTUM
        self.learning_rate = 0.02
        self.optimizer_type = 'sgd'

        # Feature switches: no speaker conditioning, no waveform generation.
        self.global_conditioning = False
        self.generate = False

        single_stack = [2 ** power for power in range(7)]
        self.net = WaveNetModel(batch_size=1,
                                dilations=single_stack * 2,
                                filter_width=2,
                                residual_channels=32,
                                dilation_channels=32,
                                skip_channels=32,
                                quantization_channels=QUANTIZATION_CHANNELS,
                                global_condition_channels=None,
                                global_condition_cardinality=None)

    def _save_net(self, sess):
        """Save the trainable variables of the session's graph.

        Args:
            sess: the TensorFlow session holding the trained variables.

        The checkpoint is written to ``tmp/test.ckpt``, relative to the
        current working directory.
        """
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(sess, os.path.join('tmp', 'test.ckpt'))

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                                   gc_placeholder, i):
            speaker_index = 0
            if speaker_ids is None:
                # No global conditioning.
                feed_dict = {audio_placeholder: audio}
            else:
                feed_dict = {audio_placeholder: audio,
                             gc_placeholder: speaker_ids}
            return feed_dict, speaker_index

        np.random.seed(42)
        audio, speaker_ids = make_sine_waves(self.global_conditioning)
        # Pad with 0s (silence) times size of the receptive field minus one,
        # because the first sample of the training data is 0 and if the network
        # learns to predict silence based on silence, it will generate only
        # silence.
        if self.global_conditioning:
            audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)),
                           'constant')
        else:
            audio = np.pad(audio, (self.net.receptive_field - 1, 0),
                           'constant')

        audio_placeholder = tf.placeholder(dtype=tf.float32)
        gc_placeholder = tf.placeholder(dtype=tf.int32)  \
            if self.global_conditioning else None

        loss = self.net.loss(input_batch=audio_placeholder,
                             global_condition_batch=gc_placeholder)
        optimizer = optimizer_factory[self.optimizer_type](
                      learning_rate=self.learning_rate, momentum=self.momentum)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.global_variables_initializer()

        generated_waveform = None
        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        operations = [loss, optim]
        with self.test_session() as sess:
            feed_dict, speaker_index = CreateTrainingFeedDict(
                audio, speaker_ids, audio_placeholder, gc_placeholder, 0)
            sess.run(init)
            initial_loss = sess.run(loss, feed_dict=feed_dict)
            for i in range(self.train_iters):
                feed_dict, speaker_index = CreateTrainingFeedDict(
                    audio, speaker_ids, audio_placeholder, gc_placeholder, i)
                [results] = sess.run([operations], feed_dict=feed_dict)
                if i % 100 == 0:
                    print("i: %d loss: %f" % (i, results[0]))

            loss_val = results[0]

            # Sanity check the initial loss was larger.
            self.assertGreater(initial_loss, max_allowed_loss)

            # Loss after training should be small.
            self.assertLess(loss_val, max_allowed_loss)

            # Loss should be at least two orders of magnitude better
            # than before training.
            self.assertLess(loss_val / initial_loss, 0.02)

            if self.generate:
                # self._save_net(sess)
                if self.global_conditioning:
                    # Check non-fast-generated waveform.
                    generated_waveforms, ids = generate_waveforms(
                        sess, self.net, False, speaker_ids)
                    for (waveform, id) in zip(generated_waveforms, ids):
                        check_waveform(self.assertGreater, waveform, id[0])

                    # Check fast-generated wveform.
                    # generated_waveforms, ids = generate_waveforms(sess,
                    #     self.net, True, speaker_ids)
                    # for (waveform, id) in zip(generated_waveforms, ids):
                    #     print("Checking fast wf for id{}".format(id[0]))
                    #     check_waveform( self.assertGreater, waveform, id[0])

                else:
                    # Check non-incremental generation
                    generated_waveforms, _ = generate_waveforms(
                        sess, self.net, False, None)
                    check_waveform(
                        self.assertGreater, generated_waveforms[0], None)
                    # Check incremental generation
                    generated_waveform = generate_waveforms(
                        sess, self.net, True, None)
                    check_waveform(
                        self.assertGreater, generated_waveforms[0], None)