def moments(fps, args):
    """Compute per-frequency-bin mean/std of the log-magnitude STFT.

    Streams the (non-repeating) dataset batch-by-batch, accumulates the
    log-magnitude spectrogram frames, and pickles ``(mean, std)`` to
    ``args.train_dir/args.data_moments_file``.

    Args:
        fps: List of input file paths handed to ``loader.get_batch``.
        args: Namespace providing ``_WINDOW_LEN``, ``data_first_window``,
            ``_LOG_EPS``, ``train_dir`` and ``data_moments_file``.
    """
    x = loader.get_batch(fps,
                         1,
                         args._WINDOW_LEN,
                         args.data_first_window,
                         repeat=False)[0, :, 0]

    # 256-point STFT with hop 128; pad_end keeps the trailing partial frame.
    X = tf.contrib.signal.stft(x, 256, 128, pad_end=True)
    X_mag = tf.abs(X)
    X_lmag = tf.log(X_mag + args._LOG_EPS)

    _X_lmags = []
    with tf.Session() as sess:
        while True:
            try:
                _X_lmag = sess.run(X_lmag)
            except tf.errors.OutOfRangeError:
                # End of the repeat=False dataset — stop collecting.
                # (Previously a bare `except:` which also hid real errors.)
                break

            _X_lmags.append(_X_lmag)

    _X_lmags = np.concatenate(_X_lmags, axis=0)
    mean, std = np.mean(_X_lmags, axis=0), np.std(_X_lmags, axis=0)

    with open(os.path.join(args.train_dir, args.data_moments_file), 'wb') as f:
        pickle.dump((mean, std), f)
Example no. 2
0
def moments(fps, args):
    """Compute per-bin mean/std of the log-magnitude STFT (labelled loader).

    Uses an initializable, labelled iterator (labels are discarded) so a
    class can be excluded via ``args.exclude_class``. The resulting
    ``(mean, std)`` pair is pickled to ``args.data_moments_fp``.

    Args:
        fps: List of input file paths handed to ``loader.get_batch``.
        args: Namespace providing ``data_first_window``, ``exclude_class``
            and ``data_moments_fp``.
    """
    # x = loader.get_batch(fps, 1, _WINDOW_LEN, args.data_first_window, repeat=False)[0, :, 0]
    training_iterator = loader.get_batch(fps,
                                         1,
                                         _WINDOW_LEN,
                                         args.data_first_window,
                                         repeat=False,
                                         initializable=True,
                                         labels=True,
                                         exclude_class=args.exclude_class)
    x, _ = training_iterator.get_next()  # Important: ignore the labels
    x = x[0, :, 0]

    # 256-point STFT with hop 128; pad_end keeps the trailing partial frame.
    X = tf.contrib.signal.stft(x, 256, 128, pad_end=True)
    X_mag = tf.abs(X)
    X_lmag = tf.log(X_mag + _LOG_EPS)

    _X_lmags = []
    with tf.Session() as sess:
        sess.run(training_iterator.initializer)
        while True:
            try:
                _X_lmag = sess.run(X_lmag)
            except tf.errors.OutOfRangeError:
                # End of the repeat=False dataset — stop collecting.
                # (Previously a bare `except:` which also hid real errors.)
                break

            _X_lmags.append(_X_lmag)

    _X_lmags = np.concatenate(_X_lmags, axis=0)
    mean, std = np.mean(_X_lmags, axis=0), np.std(_X_lmags, axis=0)

    with open(args.data_moments_fp, 'wb') as f:
        pickle.dump((mean, std), f)
Example no. 3
0
def find_data_size(fps, exclude_class):
    """Count how many examples the loader yields for `fps`.

    Runs one full (repeat=False) pass over an initializable, labelled
    iterator — optionally excluding `exclude_class` — on the CPU, and
    returns the number of examples seen.
    """
    counting_it = loader.get_batch(fps,
                                   1,
                                   _WINDOW_LEN,
                                   False,
                                   repeat=False,
                                   initializable=True,
                                   labels=True,
                                   exclude_class=exclude_class)
    example, _ = counting_it.get_next()

    num_examples = 0
    with tf.device('cpu:0'), tf.Session() as sess:
        sess.run(counting_it.initializer)
        try:
            while True:
                sess.run(example)
                num_examples += 1
        except tf.errors.OutOfRangeError:
            # One full epoch consumed — counting done.
            pass

    return num_examples
Example no. 4
0
def train(fps, args):
    """Train a (label-conditioned) WaveGAN.

    Builds the generator/discriminator graph, one of four GAN losses
    (dcgan / lsgan / wgan / wgan-gp) with its recommended optimizer, and
    runs an infinite training loop inside a MonitoredTrainingSession that
    checkpoints into ``args.train_dir``.

    Args:
        fps: List of training file paths for ``loader.get_batch``.
        args: Namespace of hyperparameters (batch size, loss name, kwargs
            for the generator/discriminator, summary/checkpoint cadence).
    """
    with tf.name_scope('loader'):
        x, y = loader.get_batch(fps,
                                args.train_batch_size,
                                _WINDOW_LEN,
                                args.data_first_window,
                                labels=True)

    # Make inputs
    y_fill = tf.expand_dims(y, axis=2)
    z = tf.random_uniform([args.train_batch_size, _D_Z],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Condition both real data and latent vector on the labels.
    x = tf.concat([x, y_fill], 1)
    z = tf.concat([z, y], 1)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            # Optional learned post-processing filter on the raw waveform.
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator (shares weights with the real one).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping enforces the Lipschitz constraint for plain WGAN.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolations between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            # Train discriminator (Py3: `range`, not Py2-only `xrange`).
            for i in range(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
Example no. 5
0
def train(fps, args):
    """Train SpecGAN on normalized log-magnitude spectrograms.

    Trims `fps` to `train_data_percentage` of the files, optionally excludes
    a class (or randomly samples 90% when ``args.exclude_class == -1``),
    then trains generator/discriminator under one of four GAN losses
    (dcgan / lsgan / wgan / wgan-gp) inside a MonitoredTrainingSession that
    checkpoints into ``args.train_dir`` and stops once the global step
    reaches ``args.stop_at_global_step`` (if non-zero).

    Args:
        fps: List of training file paths for ``loader.get_batch``.
        args: Namespace of hyperparameters (batch size, loss name, data
            moments, specgan kwargs, summary/checkpoint cadence).

    Side effects:
        Updates the module-level ``train_dataset_size`` global.
    """
    global train_dataset_size

    with tf.name_scope('loader'):
        # This was actually not necessarily good. However, we can keep it as a point for 115 tfrecords
        # train_fps, _ = loader.split_files_test_val(fps, train_data_percentage, 0)
        # fps = train_fps
        # fps = fps[:gan_train_data_size]

        logging.info("Full training datasize = " +
                     str(find_data_size(fps, None)))
        # Keep only the first train_data_percentage% of the file list.
        length = len(fps)
        fps = fps[:(int(train_data_percentage / 100.0 * length))]
        logging.info("GAN training datasize (before exclude) = " +
                     str(find_data_size(fps, None)))

        if args.exclude_class is None:
            pass
        elif args.exclude_class != -1:
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info("GAN training datasize (after exclude) = " +
                         str(train_dataset_size))
        elif args.exclude_class == -1:
            # -1 means: keep a random 90% split instead of excluding a class.
            fps, _ = loader.split_files_test_val(fps, 0.9, 0)
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info(
                "GAN training datasize (after exclude - random sampling) = " +
                str(train_dataset_size))
        else:  # Unreachable guard documenting the valid values.
            raise ValueError(
                "args.exclude_class should be either [0, num_class), None, or -1 for random sampling 90%"
            )

        training_iterator = loader.get_batch(fps,
                                             args.train_batch_size,
                                             _WINDOW_LEN,
                                             args.data_first_window,
                                             repeat=True,
                                             initializable=True,
                                             labels=True,
                                             exclude_class=args.exclude_class)
        x_wav, _ = training_iterator.get_next()  # Important: ignore the labels
        print("x_wav.shape = %s" % str(x_wav.shape))
        # Waveform -> normalized log-magnitude spectrogram features.
        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)
        print("x.shape = %s" % str(x.shape))

        logging.info("train_dataset_size = " + str(train_dataset_size))

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, _D_Z],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    logging.info('-' * 80)
    logging.info('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n,
                                          v.name))
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))

    # Summarize: invert features back to audio (Griffin-Lim) for listening.
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    print("x_gl.shape = %s" % str(x_gl.shape))
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, _FS)
    tf.summary.audio('x', x_gl, _FS)
    tf.summary.audio('G_z', G_z_gl, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    logging.info('-' * 80)
    logging.info('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n,
                                          v.name))
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    logging.info('-' * 80)

    # Make fake discriminator (shares weights with the real one).
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Weight clipping enforces the Lipschitz constraint for plain WGAN.
        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        # Gradient penalty on random interpolations between real and fake.
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training. The scaffold's local_init_op also initializes the
    # dataset iterator so the MonitoredTrainingSession can restart it.
    current_step = -1
    scaffold = tf.train.Scaffold(local_init_op=tf.group(
        tf.local_variables_initializer(), training_iterator.initializer),
                                 saver=tf.train.Saver(max_to_keep=3))
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.5)
    with tf.train.MonitoredTrainingSession(
            hooks=[SaveAtEnd(os.path.join(args.train_dir, 'model'))],
            config=tf.ConfigProto(gpu_options=gpu_options),
            scaffold=scaffold,
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs,
    ) as sess:
        # sess.run(training_iterator.initializer)
        while True:
            global_step = sess.run(tf.train.get_or_create_global_step())
            logging.info("Global step: " + str(global_step))

            if args.stop_at_global_step != 0 and global_step >= args.stop_at_global_step:
                logging.info(
                    "Stopping because args.stop_at_global_step is set to  " +
                    str(args.stop_at_global_step))
                break
                # last_saver.save(sess, os.path.join(args.train_dir, 'model'), global_step=global_step)

            # Train discriminator
            # for i in range(args.specgan_disc_nupdates):
            #   try:
            #     sess.run(D_train_op)
            #     current_step += 1
            #     # Stop training after x% of training data seen
            #     if current_step * args.train_batch_size > math.ceil(train_dataset_size * train_data_percentage / 100.0):
            #       logging.info("Stopping at batch: " + str(current_step))
            #       current_step = -1
            #       sess.run(training_iterator.initializer)
            #
            #   except tf.errors.OutOfRangeError:
            #     # End of training dataset
            #     if train_data_percentage != 100:
            #       logging.info("ERROR: end of dataset for only part of data! Achieved end of training dataset with train_data_percentage = " + str(train_data_percentage))
            #     else:
            #       current_step = -1
            #       sess.run(training_iterator.initializer)

            # Train discriminator; on dataset exhaustion, re-initialize the
            # iterator and continue.
            try:
                for i in range(args.specgan_disc_nupdates):
                    sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
            except tf.errors.OutOfRangeError:
                sess.run(training_iterator.initializer)

            # Train generator
            sess.run(G_train_op)
Example no. 6
0
def train():
    """Train the emotion CNN with queue-runner-based input pipelines.

    Builds train/validation batches via ``loader``, trains for ``MAX_STEP``
    steps, logs loss/accuracy periodically, and checkpoints into
    ``logs_train_dir``.
    """
    train_data, train_label = loader.get_file(file_dir=train_dir)
    train_batch, train_label_batch = loader.get_batch(train_data, train_label,
                                                      IMG_W, IMG_H,
                                                      TRAIN_BATCH_SIZE, CAPACITY)

    validation, validation_label = loader.get_file(file_dir=logs_validation_dir)
    validation_batch, validation_label_batch = loader.get_batch(validation, validation_label,
                                                                IMG_W, IMG_H,
                                                                VALIDATION_BATCH_SIZE, CAPACITY)

    # NOTE(review): both inference calls pass reuse=False — if CNN.inference
    # creates variables in a shared scope the second call will duplicate or
    # clash; the validation pass should likely use reuse=True. Confirm
    # against CNN.inference's implementation.
    train_logits_op = CNN.inference(input=train_batch, reuse=False)
    validation_logits_op = CNN.inference(input=validation_batch, reuse=False)
    train_losses_op = CNN.losses(logits=train_logits_op, labels=train_label_batch)
    validation_losses_op = CNN.losses(logits=validation_logits_op, labels=validation_label_batch)

    train_op = CNN.training(train_losses_op, learning_rate=LEARNING_RATE)

    train_accuracy_op = CNN.evaluation(logits=train_logits_op, labels=train_label_batch, size=TRAIN_BATCH_SIZE)

    validation_accuracy_op = CNN.evaluation(logits=validation_logits_op, labels=validation_label_batch, size=VALIDATION_BATCH_SIZE)

    # merge_all() returns None when no tf.summary.* ops exist in the graph;
    # guard sess.run accordingly below.
    merge_summary = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter("./train/summary")
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                _, train_loss, train_accuracy = sess.run([train_op, train_losses_op, train_accuracy_op])
                if step % 100 == 0:
                    print('Step %d, train loss = %.2f, train accuracy = %.2f' % (step, train_loss, train_accuracy * 100.0))
                    if merge_summary is not None:
                        summary_str = sess.run(merge_summary)
                        train_writer.add_summary(summary_str, step)

                if step % 2000 == 0 or (step + 1) == MAX_STEP:
                    checkpoint_path = os.path.join(logs_train_dir, 'EMOTION_CNN.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % 500 == 0 or (step + 1) == MAX_STEP:
                    val_loss, val_accuracy = sess.run([validation_losses_op, validation_accuracy_op])
                    print('** step %d, val loss = %.2f, val accuracy = %.2f' % (step, val_loss, val_accuracy * 100.0))
                    if merge_summary is not None:
                        summary_str = sess.run(merge_summary)
                        train_writer.add_summary(summary_str, step)

        except tf.errors.OutOfRangeError:
            print("Done training -- epoch limit reached")

        finally:
            coord.request_stop()
            # Wait for queue-runner threads to actually exit before the
            # session closes (previously they were never joined).
            coord.join(threads)
Example no. 7
0
def train(fps, args):
  """Train WaveGAN with a (optionally sequential) GRU-conditioned generator.

  Builds generator/discriminator, one of four GAN losses, and an extra
  "long" generation path (G_z_long) that feeds additional noise through a
  shared CuDNNGRU layer for longer audio summaries. Runs an infinite
  training loop under MonitoredTrainingSession, periodically dumping the
  generator weights to a pickle file.

  Args:
    fps: List of training file paths for ``loader.get_batch``.
    args: Namespace of hyperparameters (batch size, d_z, use_sequence,
        loss name, wavegan kwargs, summary/checkpoint cadence).
  """
  with tf.name_scope('loader'):
    x = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)

  # Make z vector
  if args.use_sequence:
    z = tf.random_uniform([args.train_batch_size, 16, args.d_z], -1., 1., dtype=tf.float32)
  else:
    z = tf.random_uniform([args.train_batch_size, args.d_z], -1., 1., dtype=tf.float32)#tf.random_normal([args.train_batch_size, _D_Z])

  # Make generator
  with tf.variable_scope('G'):
    gru_layer = tf.keras.layers.CuDNNGRU(args.d_z, return_sequences=True)
    G_z, gru = WaveGANGenerator(z, gru_layer=gru_layer, train=True, return_gru=True, reuse=False, 
                                use_sequence=args.use_sequence, **args.wavegan_g_kwargs)
    print('G_z.shape:',G_z.get_shape().as_list())
    if args.wavegan_genr_pp:
      # Optional learned post-processing filter on the raw waveform.
      with tf.variable_scope('pp_filt'):
        G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
  G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
  G_var_names = [g_var.name for g_var in G_vars]

  # Print G summary
  print('-' * 80)
  print('Generator vars')
  nparams = 0
  for v in G_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

  # Build a longer z sequence so the shared GRU generates extended audio
  # for the summaries.
  extra_secs = 1
  if not args.use_sequence:
    z_feed_long = z
  else:
    added_noise = tf.random_uniform([args.train_batch_size, 16*extra_secs, args.d_z], -1., 1., dtype=tf.float32)
    z_feed_long = tf.concat([z, added_noise], axis=1)

  with tf.variable_scope('G', reuse=True):
    #gru_layer.reset_states()
    G_z_long, gru_long = WaveGANGenerator(z_feed_long, gru_layer=gru_layer, train=False, length=16*extra_secs, 
                                          return_gru=True, 
                                          reuse=True, use_sequence=args.use_sequence, **args.wavegan_g_kwargs)
    print('G_z_long.shape:',G_z_long.get_shape().as_list())
    if args.wavegan_genr_pp:
      with tf.variable_scope('pp_filt', reuse=True):
        G_z_long = tf.layers.conv1d(G_z_long, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')

  # Summarize
  tf.summary.audio('x', x, _FS)
  tf.summary.audio('G_z', G_z, _FS)
  tf.summary.audio('G_z_long', G_z_long, _FS)
  G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
  x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
  tf.summary.histogram('x_rms_batch', x_rms)
  tf.summary.histogram('G_z_rms_batch', G_z_rms)
  tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
  tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

  # Make real discriminator
  with tf.name_scope('D_x'), tf.variable_scope('D'):
    D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
  D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
  print('D_vars:', D_vars)

  # Print D summary
  print('-' * 80)
  print('Discriminator vars')
  nparams = 0
  for v in D_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make fake discriminator (shares weights with the real one).
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

  # Create loss
  D_clip_weights = None
  if args.wavegan_loss == 'dcgan':
    fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
    real = tf.ones([args.train_batch_size], dtype=tf.float32)

    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_G_z,
      labels=real
    ))

    D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_G_z,
      labels=fake
    ))
    D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_x,
      labels=real
    ))

    D_loss /= 2.
  elif args.wavegan_loss == 'lsgan':
    G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
    D_loss = tf.reduce_mean((D_x - 1.) ** 2)
    D_loss += tf.reduce_mean(D_G_z ** 2)
    D_loss /= 2.
  elif args.wavegan_loss == 'wgan':
    G_loss = -tf.reduce_mean(D_G_z)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

    # Weight clipping enforces the Lipschitz constraint for plain WGAN.
    with tf.name_scope('D_clip_weights'):
      clip_ops = []
      for var in D_vars:
        clip_bounds = [-.01, .01]
        clip_ops.append(
          tf.assign(
            var,
            tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
          )
        )
      D_clip_weights = tf.group(*clip_ops)
  elif args.wavegan_loss == 'wgan-gp':
    G_loss = -tf.reduce_mean(D_G_z)# - D_x)#-tf.reduce_mean(D_G_z) + tf.reduce_mean(D_x)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)# - tf.reduce_mean()

    # Gradient penalty on random interpolations between real and fake.
    alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
    differences = G_z - x
    interpolates = x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): #
      #stft = tf.log1p(tf.abs(tf.contrib.signal.stft(interpolates[:,:,0], 512,128,fft_length=512)[:,:,:,tf.newaxis]))
    
      #D_interp = WaveGANDiscriminator(interpolates, x_cqt=stft, **args.wavegan_d_kwargs)
      #D_interp = tf.reduce_sum(tf.log1p(tf.abs(tf.contrib.signal.stft(interpolates[:,:,0], 2048,512,fft_length=2048)[:,:,:,tf.newaxis])))
      D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

    LAMBDA = 10
    gradients = tf.gradients(D_interp, [interpolates])[0]
    print('gradients:', gradients)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
    D_loss += LAMBDA * gradient_penalty
  else:
    raise NotImplementedError()

  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)

  # Create (recommended) optimizer
  if args.wavegan_loss == 'dcgan':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
  elif args.wavegan_loss == 'lsgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  elif args.wavegan_loss == 'wgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
  elif args.wavegan_loss == 'wgan-gp':
    # NOTE(review): tf.get_collection(GLOBAL_STEP) returns a *list*;
    # exponential_decay expects a scalar step tensor. Verify this decays
    # as intended (tf.train.get_or_create_global_step() would be typical).
    my_learning_rate = tf.train.exponential_decay(1e-4, 
                                                  tf.get_collection(tf.GraphKeys.GLOBAL_STEP), 
                                                  decay_steps=100000,
                                                  decay_rate=0.5)

    G_opt = tf.train.AdamOptimizer(
        learning_rate=my_learning_rate,
        beta1=0.5,
        beta2=0.9)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=my_learning_rate,
        beta1=0.5,
        beta2=0.9)
  else:
    raise NotImplementedError()

  # Create training ops
  G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

  saver = tf.train.Saver(max_to_keep=10)

  # GLOBAL_STEP collection: sess.run of this yields a list, indexed [0] below.
  global_step = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)

  # Run training
  with tf.train.MonitoredTrainingSession(
      scaffold=tf.train.Scaffold(saver=saver),
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    #saver.restore(sess, tf.train.latest_checkpoint(args.train_dir))
    iterator_count = 0
    while True:
      # Train discriminator (Py3: `range`, not Py2-only `xrange`).
      for i in range(args.wavegan_disc_nupdates):
        sess.run(D_train_op)

        # Enforce Lipschitz constraint for WGAN
        if D_clip_weights is not None:
          sess.run(D_clip_weights)

      # Train generator
      _, g_losses, d_losses, global_step_ = sess.run([G_train_op, G_loss, D_loss, global_step])
      print('i:', global_step_[0], 'G_loss:', g_losses, 'D_loss:', d_losses)
      # On the first iteration, snapshot all generator weights to a pickle
      # for offline inspection.
      if iterator_count == 0:
        G_var_dict = {}
        G_vars_np = sess.run(G_vars)
        for g_var_name, g_var in zip(G_var_names, G_vars_np):
            G_var_dict[g_var_name] = g_var
        with open('saved_G_vars_iteration-{}.pkl'.format(global_step_[0]), 'wb') as f:
            pickle.dump(G_var_dict, f)
      iterator_count += 1
Example no. 8
0
def test_model(params,
               training_fps,
               test_fps,
               args,
               processing_specgan=processing_specgan,
               MAX_EPOCHS=MAX_EPOCHS,
               fine_tuning=False,
               predictions_pickle=None):
    """Train a classifier head on a pre-trained discriminator, then test it.

    Two training phases run against the data in ``training_fps``: first only
    the decision layers are trained (feature extraction), then — when
    ``fine_tuning`` is True — the whole CNN is trained. After training, the
    accuracy over ``test_fps`` is computed and returned.

    Args:
        params: hyper-parameter dict; must contain 'batch_size' and whatever
            get_cnn_model / get_optimizer consume.
        training_fps: tfrecord file paths used for training.
        test_fps: tfrecord file paths used for evaluation.
        args: argparse-like namespace. Optional attrs (checkpoint_iter,
            load_model_dir, checkpoints_dir, load_generator_dir) are defaulted
            to None below; also reads train_dir, tensorboard_dir and — when
            processing_specgan is True — data_moments_mean/std.
        processing_specgan: if True, convert waveforms to normalized
            log-spectrograms (t_to_f) before feeding the CNN.
        MAX_EPOCHS: upper bound on epochs per training phase.
        fine_tuning: if True, run the second (full-CNN) training phase.
        predictions_pickle: optional path; when given, (labels, predictions)
            over the test set are pickled there.

    Returns:
        float: test-set accuracy after the final training phase.
    """
    batch_size = params['batch_size']

    # Default the optional checkpoint/restore attributes so the code below
    # can uniformly test them with `is not None` regardless of the caller.
    for attr in ('checkpoint_iter', 'load_model_dir', 'checkpoints_dir',
                 'load_generator_dir'):
        if not hasattr(args, attr):
            setattr(args, attr, None)

    logging.info("Testing configuration %s", params)

    with tf.Graph().as_default() as g:

        with tf.name_scope('loader'):
            training_dataset = loader.get_batch(training_fps,
                                                batch_size,
                                                _WINDOW_LEN,
                                                labels=True,
                                                repeat=False,
                                                return_dataset=True)
            test_dataset = loader.get_batch(test_fps,
                                            batch_size,
                                            _WINDOW_LEN,
                                            labels=True,
                                            repeat=False,
                                            return_dataset=True)

            train_dataset_size = find_data_size(training_fps,
                                                exclude_class=None)
            logging.info("Training datasize = " + str(train_dataset_size))

            test_dataset_size = find_data_size(test_fps, exclude_class=None)
            logging.info("Test datasize = " + str(test_dataset_size))

            # One reinitializable iterator shared by both datasets: running
            # training_init_op / test_init_op switches the data source.
            iterator = tf.data.Iterator.from_structure(
                training_dataset.output_types, training_dataset.output_shapes)
            training_init_op = iterator.make_initializer(training_dataset)
            test_init_op = iterator.make_initializer(test_dataset)
            x_wav, labels = iterator.get_next()

            x = t_to_f(x_wav, args.data_moments_mean,
                       args.data_moments_std) if processing_specgan else x_wav

        with tf.name_scope('D_x'):
            cnn_output_logits = get_cnn_model(
                params, x, processing_specgan=processing_specgan)

        # Discriminator body plus the classifier decision layers.
        D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='D') + \
                 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decision_layers')

        # Print D summary
        logging.info('-' * 80)
        logging.info('Discriminator vars')
        nparams = 0
        for v in D_vars:
            v_shape = v.get_shape().as_list()
            v_n = reduce(lambda x, y: x * y, v_shape)
            nparams += v_n
            logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n,
                                              v.name))
        logging.info('Total params: {} ({:.2f} MB)'.format(
            nparams, (float(nparams) * 4) / (1024 * 1024)))
        logging.info('-' * 80)

        # Define loss and optimizers (cnn_trainer updates everything,
        # d_decision_trainer only the decision layers).
        cnn_loss = tf.nn.softmax_cross_entropy_with_logits(
            logits=cnn_output_logits, labels=tf.one_hot(labels, 10))
        cnn_trainer, d_decision_trainer = get_optimizer(params, cnn_loss)

        predictions = tf.argmax(cnn_output_logits, axis=1)

        # Define accuracy for validation; reuse the `predictions` op instead
        # of building a second identical argmax.
        acc_op, acc_update_op, acc_reset_op = resettable_metric(
            tf.metrics.accuracy,
            'foo',
            labels=labels,
            predictions=predictions)

        # Saver for the classifier (same variable set as D_vars above).
        saver = tf.train.Saver(var_list=D_vars)

        # Saver that restores only the pre-trained discriminator body.
        load_model_saver = tf.train.Saver(var_list=tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='D'))
        if args.checkpoint_iter is not None:
            latest_model_ckpt_fp = tf.train.get_checkpoint_state(
                checkpoint_dir=args.train_dir).all_model_checkpoint_paths[
                    args.checkpoint_iter]
        else:
            latest_model_ckpt_fp = tf.train.latest_checkpoint(
                args.load_model_dir
            ) if args.load_model_dir is not None else None

        global_step_op = tf.train.get_or_create_global_step()
        latest_ckpt_fp = tf.train.latest_checkpoint(
            args.checkpoints_dir) if args.checkpoints_dir is not None else None

        def get_accuracy():
            # Full pass over the test split with a freshly reset metric.
            sess.run([acc_reset_op, test_init_op])
            try:
                while True:
                    sess.run(acc_update_op)
            except tf.errors.OutOfRangeError:
                current_accuracy = sess.run(acc_op)
            return current_accuracy

        def get_mean_delta_accuracy(accuracy):
            # Mean accuracy improvement over the last 3 epochs; a large
            # sentinel is returned until enough history has accumulated.
            accuracies_feature_extraction.append(accuracy)
            if len(accuracies_feature_extraction) >= 4:
                mean_delta_accuracy = (
                    accuracies_feature_extraction[-1] -
                    accuracies_feature_extraction[-4]) * 1.0 / 3
                return mean_delta_accuracy
            return 99999

        def run_trainer(trainer, message_prefix):
            # Train until MAX_EPOCHS or until accuracy plateaus; the
            # early-stopping check only starts after MAX_EPOCHS/2 epochs.
            # NOTE(review): relies on module-global
            # mean_delta_accuracy_threshold — confirm it is defined.
            for current_epoch in range(MAX_EPOCHS):
                sess.run(training_init_op)
                try:
                    while True:
                        sess.run(trainer)
                except tf.errors.OutOfRangeError:
                    if current_epoch < MAX_EPOCHS / 2.0:
                        logging.info("%s finished epoch %d" %
                                     (message_prefix, current_epoch))
                        continue
                    accuracy = get_accuracy()
                    logging.info("%s epoch %d accuracy = %f" %
                                 (message_prefix, current_epoch, accuracy))
                    mean_delta_accuracy = get_mean_delta_accuracy(accuracy)

                    if mean_delta_accuracy < mean_delta_accuracy_threshold:
                        logging.info("Early stopping, mean_delta_accuracy = " +
                                     str(mean_delta_accuracy))
                        break

        logging.info("Creating session ... ")
        with tf.Session(graph=g) as sess:

            if args.tensorboard_dir is not None:
                writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)

            # FIX: tf.initialize_all_variables() is deprecated; use the
            # supported tf.global_variables_initializer().
            sess.run(tf.global_variables_initializer())

            # Prefer a full classifier checkpoint; otherwise restore only
            # the pre-trained discriminator body.
            if latest_ckpt_fp is not None:
                saver.restore(sess, latest_ckpt_fp)
            elif latest_model_ckpt_fp is not None:
                load_model_saver.restore(sess, latest_model_ckpt_fp)

            # Feature extraction
            accuracies_feature_extraction = []
            run_trainer(d_decision_trainer, "Feature extraction")

            # Fine tuning
            if fine_tuning is True:
                accuracies_feature_extraction = []
                run_trainer(cnn_trainer, "Fine tuning")

            # Save model
            if args.checkpoints_dir is not None:
                save_path = saver.save(sess,
                                       os.path.join(args.checkpoints_dir,
                                                    "model.ckpt"),
                                       global_step=sess.run(global_step_op))
                print("Model saved in path: %s" % save_path)

            fine_tuning_accuracy = get_accuracy()
            logging.info("Fine tuning accuracy = " + str(fine_tuning_accuracy))

            sess.run(test_init_op)

            if predictions_pickle is not None:
                logging.info("Testing 2")
                numpy_predictions, numpy_labels = np.array([]), np.array([])
                try:
                    while True:
                        tmp_pred, tmp_labels = sess.run([predictions, labels])

                        numpy_predictions = np.append(numpy_predictions,
                                                      tmp_pred)
                        numpy_labels = np.append(numpy_labels, tmp_labels)
                except tf.errors.OutOfRangeError:

                    logging.info("Pickling predictions to " +
                                 predictions_pickle)

                    with open(predictions_pickle, 'wb') as f:
                        pickle.dump((numpy_labels, numpy_predictions), f)

        return fine_tuning_accuracy
Exemplo n.º 9
0
def infer(vid, args):
    """Build the generator inference graph and export it for later use.

    Constructs named placeholders and ops (samp_z, z, flat_pad, samp_y, y,
    G_z, G_z_flat, and int16 encodings) so a consumer can load the exported
    MetaGraph and address tensors by name, writes 'infer.pbtxt' and
    'infer.meta' under <args.train_dir>/infer, then resets the default graph.

    Args:
        vid: file paths handed to loader.get_batch to produce the prior
            batch `samp_y`.
        args: namespace providing train_dir, train_batch_size,
            data_first_window, wavegan_genr_pp, wavegan_genr_pp_len and
            wavegan_g_kwargs.
    """
    print('inferring computational graph...')
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that generates latent vectors: samp_z_n latent vectors of
    # dimension _D_Z drawn uniformly from [-1, 1).
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, _D_Z],
                               -1.0,
                               1.0,
                               dtype=tf.float32,
                               name='samp_z')

    # Input z0: the latent batch actually fed to the generator.
    z = tf.placeholder(tf.float32, [None, _D_Z], name='z')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Input y0: prior signal batch from the loader.
    # NOTE(review): samp_y_n is declared but not consumed by any op below;
    # presumably kept only so the exported graph exposes it by name — confirm.
    samp_y_n = tf.placeholder(tf.int32, [], name='samp_y_n')
    samp_y = loader.get_batch(vid, args.train_batch_size, _PRIOR_SIZE,
                              args.data_first_window)
    samp_y = tf.identity(samp_y, name='samp_y')
    y = tf.placeholder(tf.float32, [None, _PRIOR_SIZE, 1], name='y')

    # Execute generator (inference mode: train=False).
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, y, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            # Optional learned post-processing filter on the generator output.
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_z = tf.identity(G_z, name='G_z')

    # Flatten batch: pad each example by flat_pad samples, then concatenate
    # the whole batch into one long [n, nch] signal.
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode to int16
    def float_to_int16(x, name=None):
        # Scale floats (assumed in [-1, 1] — TODO confirm) to the int16
        # range and clip before casting to avoid wrap-around.
        x_int16 = x * 32767.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16

    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver covering generator weights and the global step.
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(filename=infer_metagraph_fp,
                               clear_devices=True,
                               saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards)
    tf.reset_default_graph()
0
def training_procedure(sess,
                       x,
                       gt,
                       raw_fg,
                       train_file_list,
                       test_file_list,
                       pred,
                       train_writer,
                       test_writer,
                       ex_writer,
                       saver,
                       t_str,
                       starting_point=0):
    """Train an alpha-matting model, validating and checkpointing each epoch.

    The loss is an even blend of an L1 alpha loss and an L1 compositional
    loss. After every training epoch a validation loss is computed and its
    relative improvement reported, a few example predictions are written to
    TensorBoard, and a checkpoint is saved under params.LOG_DIR.

    Args:
        sess: live tf.Session hosting the graph.
        x: input placeholder with composite + background stacked on the
            channel axis (3 + 3 channels).
        gt: ground-truth alpha placeholder.
        raw_fg: raw-foreground placeholder used to re-composite predictions.
        train_file_list: training file list (consumed via shuffled copies).
        test_file_list: validation file list (consumed via shuffled copies).
        pred: model output tensor (predicted alpha matte).
        train_writer: summary writer for training summaries.
        test_writer: summary writer for validation (currently unused).
        ex_writer: summary writer for example images.
        saver: tf.train.Saver used for checkpoints.
        t_str: timestamp string naming the checkpoint directory.
        starting_point: initial iteration counter (for resumed runs).
    """
    improvement = 'NaN'
    # Split the 6-channel input into composite image and background.
    in_cmp, in_bg = tf.split(value=x, num_or_size_splits=[3, 3], axis=3)
    with tf.variable_scope('loss'):
        alpha_loss = regular_l1(pred, gt, name='alpha_loss')
        pred_cmp = composite(raw_fg, in_bg, pred)
        cmp_loss = regular_l1(pred_cmp, in_cmp, name='compositional_loss')
        # Equal-weight blend of the two L1 terms, reduced to a scalar.
        s_loss = tf.add(0.5 * alpha_loss, 0.5 * cmp_loss)
        loss = tf.reduce_mean(s_loss, name='loss')
    with tf.variable_scope('resume_training'):
        lr = 1e-5
        print('Training with learning rate of {}'.format(lr))
        optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                           beta1=0.9,
                                           beta2=0.999,
                                           epsilon=1e-08)
        train_op = optimizer.minimize(loss=loss,
                                      global_step=tf.train.get_global_step())
    with tf.variable_scope('summary'):
        summary_loss = tf.summary.scalar('loss', loss)
        summary_alpha_loss = tf.summary.scalar('alpha_loss',
                                               tf.reduce_mean(alpha_loss))
        summary_cmp_loss = tf.summary.scalar('compositional_loss',
                                             tf.reduce_mean(cmp_loss))
        summary_cmp = tf.summary.image('composite', bgr2rgb(in_cmp))
        summary_gt = tf.summary.image('ground_truth', gt)
        summary_pred = tf.summary.image('prediction', pred)
    train_summaries = [summary_loss, summary_alpha_loss, summary_cmp_loss]
    test_summaries = [summary_loss, summary_alpha_loss, summary_cmp_loss]
    ex_summaries = [summary_cmp, summary_gt, summary_pred]
    train_merged = tf.summary.merge(train_summaries)
    test_merged = tf.summary.merge(test_summaries)
    ex_merged = tf.summary.merge(ex_summaries)
    # -1. is a sentinel meaning "no previous validation loss yet".
    prev_val_loss = -1.
    iteration = starting_point
    sess.run(tf.global_variables_initializer())

    for epoch in range(params.N_EPOCHS):
        # Work on shuffled copies so the caller's lists are not consumed.
        training_list = train_file_list.copy()
        test_list = test_file_list.copy()
        random.shuffle(training_list)
        random.shuffle(test_list)
        # training
        while not loader.epoch_is_over(training_list, params.BATCH_SIZE):
            print('Training model, epoch {}/{}, iteration {}.'.format(
                epoch + 1, params.N_EPOCHS, iteration + 1))
            batch_list = loader.get_batch_list(training_list,
                                               params.BATCH_SIZE)
            inp, lab, rfg = loader.get_batch(batch_list,
                                             params.INPUT_SIZE,
                                             rd_scale=False,
                                             rd_mirror=True)
            feed_dict = {x: inp, gt: lab, raw_fg: rfg}
            summary, _ = sess.run([train_merged, train_op],
                                  feed_dict=feed_dict)
            train_writer.add_summary(summary, iteration)
            iteration += 1
        # validation
        print('Training completed. Computing validation loss...')
        val_loss = 0.
        n_batch = 0
        while not loader.epoch_is_over(test_list, params.BATCH_SIZE):
            batch_list = loader.get_batch_list(test_list, params.BATCH_SIZE)
            inp, lab, rfg = loader.get_batch(batch_list,
                                             params.INPUT_SIZE,
                                             rd_scale=False,
                                             rd_mirror=True)
            feed_dict = {x: inp, gt: lab, raw_fg: rfg}
            ls = sess.run([loss], feed_dict=feed_dict)
            # test_writer.add_summary(summary, iteration)
            val_loss += np.mean(ls)
            n_batch += 1
        val_loss /= n_batch
        if prev_val_loss != -1.:
            # BUG FIX: '{:.2%}' renders a true percentage (scales by 100 and
            # appends '%'); the old '{:2f}%' printed the raw fraction.
            improvement = '{:.2%}'.format(
                (prev_val_loss - val_loss) / prev_val_loss)
        # BUG FIX: remember this epoch's loss — previously prev_val_loss was
        # never updated, so 'improvement' stayed 'NaN' forever.
        prev_val_loss = val_loss
        print('Validation loss: {:.3f}. Improvement: {}'.format(
            val_loss, improvement))
        print('Saving examples')
        # loads and visualize example prediction of current model
        n_ex = 5
        ex_list = [
            test_file_list[np.random.randint(0, len(test_file_list))]
            for _ in range(n_ex)
        ]
        ex_inp, ex_lab, _ = loader.get_batch(ex_list,
                                             params.INPUT_SIZE,
                                             rd_scale=False,
                                             rd_mirror=True)
        feed_dict = {x: ex_inp, gt: ex_lab}
        summary = sess.run([ex_merged], feed_dict)[0]
        ex_writer.add_summary(summary, iteration)
        print('Saving checkpoint...')
        saver.save(sess,
                   os.path.join(params.LOG_DIR, 'weights_{}'.format(t_str),
                                'model'),
                   global_step=iteration)
Exemplo n.º 11
0
def train(fps, args, cond=False):
	"""Build the SpecGAN training graph and run the training loop.

	Args:
		fps: list of training file paths passed to loader.get_batch.
		args: namespace of hyper-parameters (train_batch_size, _WINDOW_LEN,
			SpecGAN_* settings, data moments, train_dir, save intervals, ...).
		cond: if True, condition G and D on word embeddings and add a
			mismatched-word term to the D loss; only the 'wgan-gp' loss
			supports this (the other losses raise NotImplementedError).
	"""
	
	with tf.name_scope('loader'):
		if cond:
			#---data flow of a batch of (x, y) training data--@
			x_wav, y_label = loader.get_batch(fps, args.train_batch_size, args._WINDOW_LEN, args.data_first_window, labels=True)
			
			#---parse label from tensor to categorical---#
			y, w = loader.label_to_tensor(label=y_label, args=args, fps=fps)  # (right, wrong) labels

		else:
			x_wav = loader.get_batch(fps, args.train_batch_size, args._WINDOW_LEN, args.data_first_window)
		# Map waveforms into normalized spectrogram feature space.
		x = helper.t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)
	
	#---word embedding---#
	# Same variable scope with reuse=True so right/wrong words share one table.
	if cond:
		with tf.variable_scope('word_embedding'):
			y_emb = word_embedding(word=y, vocab_size=args._VOCAB_SIZE, embedding_dim=args.SpecGAN_word_embedding_dim, train=False)
		with tf.variable_scope('word_embedding', reuse=True):
			w_emb = word_embedding(word=w, vocab_size=args._VOCAB_SIZE, embedding_dim=args.SpecGAN_word_embedding_dim, train=False)
	
	# Make z vector
	if args.SpecGAN_prior_noise == 'uniform':
		z = tf.random_uniform([args.train_batch_size, args._D_Z], minval=-1., maxval=1., dtype=tf.float32)
	elif args.SpecGAN_prior_noise == 'normal':
		z = tf.random_normal([args.train_batch_size, args._D_Z], mean=0., stddev=1., dtype=tf.float32)
	else:
		raise NotImplementedError()


	# Make generator
	with tf.variable_scope('G'):
		if cond:
			G_z = Spec_GAN_Generator(z, word=y_emb, train=True, cond=True, **args.SpecGAN_g_kwargs)
		else:
			G_z = Spec_GAN_Generator(z, train=True, **args.SpecGAN_g_kwargs)
	G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')


	# Print G summary
	print('-' * 80)
	print('####### Generator vars #######')
	nparams = 0
	for v in G_vars:
		v_shape = v.get_shape().as_list()
		v_n = reduce(lambda x, y: x * y, v_shape)
		nparams += v_n
		print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
	# Each float32 param is 4 bytes.
	print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))


	# Make real discriminator
	with tf.name_scope('D_x'), tf.variable_scope('D'):
		if cond:
			D_x = Spec_GAN_Discriminator(x, word=y_emb, cond=True, **args.SpecGAN_d_kwargs)
		else:
			D_x = Spec_GAN_Discriminator(x, **args.SpecGAN_d_kwargs)
	D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')


	# Print D summary
	print('-' * 80)
	print('####### Discriminator vars #######')
	nparams = 0
	for v in D_vars:
		v_shape = v.get_shape().as_list()
		v_n = reduce(lambda x, y: x * y, v_shape)
		nparams += v_n
		print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
	print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
	print('-' * 80)


	# Make fake discriminator (weights shared with D_x via reuse=True)
	with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
		if cond:
			D_G_z = Spec_GAN_Discriminator(G_z, word=y_emb, cond=True, **args.SpecGAN_d_kwargs)
		else:
			D_G_z = Spec_GAN_Discriminator(G_z, **args.SpecGAN_d_kwargs)

	# Make mismatch discriminator: real audio paired with the wrong word
	with tf.name_scope('D_G_w'), tf.variable_scope('D', reuse=True):
		if cond:
			D_x_w = Spec_GAN_Discriminator(x, word=w_emb, cond=True, **args.SpecGAN_d_kwargs)


	# Create loss
	D_clip_weights = None
	if args.SpecGAN_loss == 'dcgan':
		if cond: raise NotImplementedError()
		fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
		real = tf.ones([args.train_batch_size], dtype=tf.float32)

		G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
			logits=D_G_z,
			labels=real
		))

		D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
			logits=D_G_z,
			labels=fake
		))
		D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
			logits=D_x,
			labels=real
		))

		D_loss /= 2.
	elif args.SpecGAN_loss == 'lsgan':
		if cond: raise NotImplementedError()
		G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
		D_loss = tf.reduce_mean((D_x - 1.) ** 2)
		D_loss += tf.reduce_mean(D_G_z ** 2)
		D_loss /= 2.
	elif args.SpecGAN_loss == 'wgan':
		if cond: raise NotImplementedError()
		G_loss = -tf.reduce_mean(D_G_z)
		D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

		# Plain WGAN enforces Lipschitz-ness by clipping D weights to [-.01, .01].
		with tf.name_scope('D_clip_weights'):
			clip_ops = []
			for var in D_vars:
				clip_bounds = [-.01, .01]
				clip_ops.append(
					tf.assign(
						var,
						tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
					)
				)
			D_clip_weights = tf.group(*clip_ops)
	elif args.SpecGAN_loss == 'wgan-gp':
		G_loss = - tf.reduce_mean(D_G_z) # - D fake
		if cond:
			# Conditional critic: penalize both fake pairs and mismatched real pairs.
			D_loss = - (tf.reduce_mean(D_x) - tf.reduce_mean(D_G_z))
			D_loss += - (tf.reduce_mean(D_x) - tf.reduce_mean(D_x_w))
		else:
			D_loss = - (tf.reduce_mean(D_x) - tf.reduce_mean(D_G_z)) # min (D real - D fake) => D fake - D real

		# Gradient penalty on random interpolations between real and fake.
		alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1], minval=0., maxval=1.)
		differences = G_z - x
		interpolates = x + (alpha * differences)
		# NOTE(review): the interpolated critic is built WITHOUT the word
		# condition even when cond=True — confirm this is intentional.
		with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
			D_interp = Spec_GAN_Discriminator(interpolates, **args.SpecGAN_d_kwargs)

		LAMBDA = 10
		gradients = tf.gradients(D_interp, [interpolates])[0]
		# Gradient norm per example, reduced over the non-batch axes [1, 2].
		slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
		gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
		D_loss += LAMBDA * gradient_penalty
	else:
		raise NotImplementedError()




	# Create (recommended) optimizer
	if args.SpecGAN_loss == 'dcgan':
		G_opt = tf.train.AdamOptimizer(
				learning_rate=2e-4,
				beta1=0.5)
		D_opt = tf.train.AdamOptimizer(
				learning_rate=2e-4,
				beta1=0.5)
	elif args.SpecGAN_loss == 'lsgan':
		G_opt = tf.train.RMSPropOptimizer(
				learning_rate=1e-4)
		D_opt = tf.train.RMSPropOptimizer(
				learning_rate=1e-4)
	elif args.SpecGAN_loss == 'wgan':
		G_opt = tf.train.RMSPropOptimizer(
				learning_rate=5e-5)
		D_opt = tf.train.RMSPropOptimizer(
				learning_rate=5e-5)
	elif args.SpecGAN_loss == 'wgan-gp':
		G_opt = tf.train.AdamOptimizer(
				learning_rate=1e-4,
				beta1=0.5,
				beta2=0.9)
		D_opt = tf.train.AdamOptimizer(
				learning_rate=1e-4,
				beta1=0.5,
				beta2=0.9)
	else:
		raise NotImplementedError()

	# Summarize: audio (Griffin-Lim reconstructions), RMS stats and images.
	x_gl = helper.f_to_t(x, args.data_moments_mean, args.data_moments_std, args.SpecGAN_ngl)
	G_z_gl = helper.f_to_t(G_z, args.data_moments_mean, args.data_moments_std, args.SpecGAN_ngl)
	tf.summary.audio('x_wav', x_wav, args._FS)
	tf.summary.audio('x', x_gl, args._FS)
	tf.summary.audio('G_z', G_z_gl, args._FS)
	G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
	x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
	tf.summary.histogram('x_rms_batch', x_rms)
	tf.summary.histogram('G_z_rms_batch', G_z_rms)
	tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
	tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
	tf.summary.image('x', helper.f_to_img(x))
	tf.summary.image('G_z', helper.f_to_img(G_z))
	# W_distance is also consumed by the LoggerHook below; the bare except
	# is a best-effort guard — NOTE(review): consider narrowing it.
	try:
		W_distance = tf.reduce_mean(D_x) - 2*tf.reduce_mean(D_G_z)
		tf.summary.scalar('W_distance', W_distance)
	except: pass
	tf.summary.scalar('G_loss', G_loss)
	tf.summary.scalar('D_loss', D_loss)

	# Create training ops; only the G step advances the global step.
	G_train_op = G_opt.minimize(G_loss, var_list=G_vars, global_step=tf.train.get_or_create_global_step())
	D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

	# Global step defiend for StopAtStepHook
	global_step = tf.train.get_or_create_global_step()

	# Run training: N critic updates (plus optional weight clipping for
	# plain WGAN) per generator update, until StopAtStepHook fires.
	with tf.train.MonitoredTrainingSession(
					hooks=[tf.train.StopAtStepHook(last_step=args.train_max_step), # hook that stops training at max_step
						   tf.train.NanTensorHook(D_loss), # hook that monitors the loss, terminate if loss is NaN
						   helper.tf_train_LoggerHook(args, losses=[D_loss, G_loss, W_distance])], # user defiend log printing hook
					checkpoint_dir=args.train_dir,
					save_checkpoint_secs=args.train_save_secs,
					save_summaries_secs=args.train_summary_secs) as sess:
		while not sess.should_stop():
			# Train discriminator
			for i in range(args.SpecGAN_disc_nupdates):
				sess.run(D_train_op)

				# Enforce Lipschitz constraint for WGAN
				if D_clip_weights is not None:
					sess.run(D_clip_weights)

			# Train generator
			sess.run(G_train_op)
Exemplo n.º 12
0
def evaluate_model_process(params, seed, q):
    """Build, train and evaluate a CNN classifier for one hyperopt trial.

    Intended to run in a worker process: builds a fresh graph, restores the
    pre-trained GAN discriminator weights (scope 'D') from ``args.train_dir``
    if available, optionally trains the decision layer only ("feature
    extraction") and/or the whole network ("fine tuning"), and reports a
    hyperopt-style result dict on ``q`` exactly once.

    Args:
        params: dict of hyperparameters; must contain 'batch_size', the rest
            is consumed by ``get_cnn_model`` / ``get_optimizer``.
        seed: int seed controlling the train/validation file split.
        q: multiprocessing queue receiving ``{'loss': ..., 'status': ...}``
            (loss is the negated validation accuracy, since hyperopt
            minimizes).

    The whole trial is retried up to ``max_tries`` times on GPU OOM
    (ResourceExhaustedError) or any other exception; after the last failed
    try a STATUS_FAIL result is put on ``q`` instead.
    """

    global args

    max_tries = 10
    for current_try in range(max_tries):  # retry loop for flaky GPU failures
        with tf.Graph().as_default() as g:

            batch_size = params['batch_size']

            # Define input training and validation data.
            # TODO: think of a deterministic way to do the data split, GAN +
            # Classifier train + valid. Valid out of Train must be seed-able.
            logging.info("Preparing data ... ")
            training_fps = args.training_fps
            # 90/10 train/validation split, deterministic for a given seed.
            training_fps, validation_fps = loader.split_files_test_val(
                training_fps, train_size=0.9, seed=seed)

            with tf.name_scope('loader'):
                training_dataset = loader.get_batch(training_fps,
                                                    batch_size,
                                                    _WINDOW_LEN,
                                                    labels=True,
                                                    repeat=False,
                                                    return_dataset=True)
                validation_dataset = loader.get_batch(validation_fps,
                                                      batch_size,
                                                      _WINDOW_LEN,
                                                      labels=True,
                                                      repeat=False,
                                                      return_dataset=True)

                train_dataset_size = find_data_size(training_fps,
                                                    exclude_class=None)
                logging.info("Training datasize = " + str(train_dataset_size))

                valid_dataset_size = find_data_size(validation_fps,
                                                    exclude_class=None)
                logging.info("Validation datasize = " +
                             str(valid_dataset_size))

                # One reinitializable iterator shared by both datasets so the
                # same graph tensors can read either split.
                iterator = tf.data.Iterator.from_structure(
                    training_dataset.output_types,
                    training_dataset.output_shapes)
                training_init_op = iterator.make_initializer(training_dataset)
                validation_init_op = iterator.make_initializer(
                    validation_dataset)
                x_wav, labels = iterator.get_next()

                # SpecGAN consumes normalized log-spectrograms; otherwise the
                # raw waveform is fed directly.
                x = t_to_f(
                    x_wav, args.data_moments_mean,
                    args.data_moments_std) if processing_specgan else x_wav

            # Get the discriminator and put extra (classification) layers on top.
            with tf.name_scope('D_x'):
                cnn_output_logits = get_cnn_model(
                    params, x, processing_specgan=processing_specgan)

            # Loss and optimizers for 10-way classification.
            cnn_loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=cnn_output_logits, labels=tf.one_hot(labels, 10))
            cnn_trainer, d_decision_trainer = get_optimizer(params, cnn_loss)

            # Streaming accuracy metric that can be reset between epochs.
            acc_op, acc_update_op, acc_reset_op = resettable_metric(
                tf.metrics.accuracy,
                'foo',
                labels=labels,
                predictions=tf.argmax(cnn_output_logits, axis=1))

            # Saver restricted to the pre-trained discriminator variables.
            saver = tf.train.Saver(var_list=tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, scope='D'))
            logging.info("args.train_dir = %s" % args.train_dir)
            if args.train_dir is not None:
                latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)

            def initialize_iterator(sess, skip=False):
                # NOTE(review): `skip` is currently ignored (the original
                # batch-skipping logic was commented out) -- the training
                # iterator is simply rewound.
                sess.run(training_init_op)

            def run_training_phase(sess, train_op, phase_label):
                """Run up to MAX_EPOCHS epochs of `train_op` with early stopping.

                After each epoch, accuracy over the validation split is
                measured; training stops early when the mean accuracy gain
                over the last 3 epochs drops below
                mean_delta_accuracy_threshold. Returns the list of per-epoch
                validation accuracies.
                """
                accuracies = []
                for current_epoch in range(MAX_EPOCHS):
                    initialize_iterator(sess, skip=True)
                    # Step within this epoch, distinct from the global step.
                    current_step = -1
                    try:
                        while True:
                            sess.run(train_op)

                            current_step += 1

                            # Stop training after the configured amount of
                            # training data has been seen.
                            if current_step > nr_training_batches:
                                break

                    except tf.errors.OutOfRangeError:
                        # End of training dataset
                        pass

                    logging.info("Stopped training at epoch step: " +
                                 str(current_step))

                    # Validation: stream the whole validation split through
                    # the running-accuracy metric, then read it once.
                    sess.run([acc_reset_op, validation_init_op])
                    try:
                        while True:
                            sess.run(acc_update_op)
                    except tf.errors.OutOfRangeError:
                        # End of dataset
                        current_accuracy = sess.run(acc_op)
                        logging.info(phase_label + str(current_epoch) +
                                     " accuracy = " + str(current_accuracy))
                        accuracies.append(current_accuracy)

                    # Early stopping on a plateau of validation accuracy.
                    if len(accuracies) >= 4:
                        mean_delta_accuracy = (accuracies[-1] -
                                               accuracies[-4]) * 1.0 / 3

                        if mean_delta_accuracy < mean_delta_accuracy_threshold:
                            logging.info(
                                "Early stopping, mean_delta_accuracy = " +
                                str(mean_delta_accuracy))
                            break

                    # BUGFIX: the original tested `current_epoch >= MAX_EPOCHS`,
                    # which can never be true inside range(MAX_EPOCHS), so the
                    # message below never fired.
                    if current_epoch == MAX_EPOCHS - 1:
                        logging.info("Stopping after " + str(MAX_EPOCHS) +
                                     " epochs!")
                return accuracies

            # Tiny GPU footprint so many trials can share one device.
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = False
            config.gpu_options.per_process_gpu_memory_fraction = 0.01
            logging.info("Creating session ... ")
            try:
                with tf.Session(config=config, graph=g) as sess:
                    sess.run(tf.global_variables_initializer())
                    # Restore pre-trained discriminator weights AFTER the
                    # global init so they are not overwritten.
                    if args.train_dir is not None:
                        saver.restore(sess, latest_ckpt_fp)

                    # Main training loop
                    status = STATUS_OK
                    logging.info("Entering main loop ...")

                    # NOTE(review): this is the number of SAMPLES, not
                    # batches (no division by batch_size) -- confirm intended.
                    nr_training_batches = train_dataset_size

                    logging.info("Training batches: " +
                                 str(nr_training_batches))

                    # Step 1: Train decision layer only
                    if perform_feature_extraction:
                        accuracies_feature_exctraction = run_training_phase(
                            sess, d_decision_trainer,
                            "Feature extraction epoch")

                        logging.info(
                            "Result feature extraction: %s %s %s %s %s",
                            params, global_train_data_percentage, seed,
                            global_checkpoint_iter,
                            str(accuracies_feature_exctraction[-1]))

                    # Step 2: Continue training everything
                    if perform_fine_tuning:
                        accuracies_fine_tuning = run_training_phase(
                            sess, cnn_trainer, "Fine tuning epoch")

                        logging.info("Result fine tuning: %s %s %s %s %s",
                                     params, global_train_data_percentage,
                                     seed, global_checkpoint_iter,
                                     str(accuracies_fine_tuning[-1]))

                    # NOTE(review): if the flag combination disables the phase
                    # whose accuracy is recorded here, this raises NameError
                    # (caught below and reported as FAIL after retries).
                    recorded_accuracy = accuracies_feature_exctraction[
                        -1] if record_hyperopt_feature_extraction is True else accuracies_fine_tuning[
                            -1]
                    q.put({
                        'loss':
                        -recorded_accuracy,  # negated: hyperopt minimizes the loss
                        'status': status,
                    })
            except tf.errors.ResourceExhaustedError:
                if current_try == max_tries - 1:
                    logging.info(
                        "Got Resources Exhausted Error - Returning FAIL: %s %s %s %s",
                        params, global_train_data_percentage, seed,
                        global_checkpoint_iter)
                    q.put({'loss': -999, 'status': STATUS_FAIL})
                else:
                    logging.info(
                        "Got Resources Exhausted Error - Retrying %d: %s %s %s %s",
                        current_try, params, global_train_data_percentage,
                        seed, global_checkpoint_iter)
            except Exception as e:
                if current_try == max_tries - 1:
                    logging.info(
                        "Got Following Exception - Returning FAIL: %s %s %s %s",
                        params, global_train_data_percentage, seed,
                        global_checkpoint_iter)
                    logging.info(e)
                    q.put({'loss': -998, 'status': STATUS_FAIL})
                else:
                    logging.info(
                        "Got Following Exception - Retrying %d: %s %s %s %s",
                        current_try, params, global_train_data_percentage,
                        seed, global_checkpoint_iter)
                    logging.info(e)
            else:
                # Everything good, success, so don't retry
                break
# Exemplo n.º 13
# 0
def train(fps, args):
    with tf.name_scope('loader'):
        x, cond_text, _ = loader.get_batch(fps,
                                           args.train_batch_size,
                                           _WINDOW_LEN,
                                           args.data_first_window,
                                           conditionals=True,
                                           name='batch')
        wrong_audio = loader.get_batch(fps,
                                       args.train_batch_size,
                                       _WINDOW_LEN,
                                       args.data_first_window,
                                       conditionals=False,
                                       name='wrong_batch')
    # wrong_cond_text, wrong_cond_text_embed = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, wavs=False, conditionals=True, name='batch')

    # Make z vector
    z = tf.random_normal([args.train_batch_size, _D_Z])

    embed = hub.Module('https://tfhub.dev/google/elmo/2',
                       trainable=False,
                       name='embed')
    cond_text_embed = embed(cond_text)

    # Add conditioning input to the model
    args.wavegan_g_kwargs['context_embedding'] = cond_text_embed
    args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs[
        'context_embedding']

    lod = tf.placeholder(tf.float32, shape=[])

    with tf.variable_scope('G'):
        # Make generator
        G_z, c_kl_loss = WaveGANGenerator(z,
                                          lod,
                                          train=True,
                                          **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')

    # Summarize
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    x_rms_lod_4 = tf.sqrt(
        tf.reduce_mean(tf.square(avg_downsample(x)[:, :, 0]), axis=1))
    x_rms_lod_3 = tf.sqrt(
        tf.reduce_mean(tf.square(avg_downsample(avg_downsample(x))[:, :, 0]),
                       axis=1))
    x_rms_lod_2 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(avg_downsample(avg_downsample(x)))[:, :, 0]),
                       axis=1))
    x_rms_lod_1 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(avg_downsample(avg_downsample(
                avg_downsample(x))))[:, :, 0]),
                       axis=1))
    x_rms_lod_0 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(
                avg_downsample(
                    avg_downsample(avg_downsample(avg_downsample(x)))))[:, :,
                                                                        0]),
                       axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('x_rms_lod_4', tf.reduce_mean(x_rms_lod_4))
    tf.summary.scalar('x_rms_lod_3', tf.reduce_mean(x_rms_lod_3))
    tf.summary.scalar('x_rms_lod_2', tf.reduce_mean(x_rms_lod_2))
    tf.summary.scalar('x_rms_lod_1', tf.reduce_mean(x_rms_lod_1))
    tf.summary.scalar('x_rms_lod_0', tf.reduce_mean(x_rms_lod_0))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.audio('x', x, _FS, max_outputs=10)
    tf.summary.audio('G_z', G_z, _FS, max_outputs=10)
    tf.summary.text('Conditioning Text', cond_text[:10])

    # with tf.variable_scope('G'):
    #   # Make history buffer
    #   history_buffer = HistoryBuffer(_WINDOW_LEN, args.train_batch_size * 100, args.train_batch_size)

    #   # Select half of batch from history buffer
    #   g_from_history, r_from_history, embeds_from_history = history_buffer.get_from_history_buffer()
    #   new_fake_batch = tf.concat([G_z[:tf.shape(G_z)[0] - tf.shape(g_from_history)[0]], g_from_history], 0) # Use tf.shape to handle case when g_from_history is empty
    #   new_cond_embeds = tf.concat([cond_text_embed[:tf.shape(cond_text_embed)[0] - tf.shape(embeds_from_history)[0]], embeds_from_history], 0)
    #   new_real_batch = tf.concat([x[:tf.shape(x)[0] - tf.shape(r_from_history)[0]], r_from_history], 0)
    #   with tf.control_dependencies([new_fake_batch, new_real_batch, new_cond_embeds]):
    #     with tf.control_dependencies([history_buffer.add_to_history_buffer(G_z, x, cond_text_embed)]):
    #       G_z = tf.identity(new_fake_batch)
    #       x = tf.identity(new_real_batch)
    #       args.wavegan_g_kwargs['context_embedding'] = tf.identity(new_cond_embeds)
    #       args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs['context_embedding']
    #   G_z.set_shape([args.train_batch_size, _WINDOW_LEN, 1])
    #   x.set_shape([args.train_batch_size, _WINDOW_LEN, 1])

    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    # tf.summary.scalar('history_buffer_size', history_buffer.current_size)
    # tf.summary.scalar('g_from_history_size', tf.shape(g_from_history)[0])
    # tf.summary.scalar('r_from_history_size', tf.shape(r_from_history)[0])
    # tf.summary.scalar('embeds_from_history_size', tf.shape(embeds_from_history)[0])
    # tf.summary.audio('G_z_history', g_from_history, _FS, max_outputs=10)
    # tf.summary.audio('x_history', r_from_history, _FS, max_outputs=10)
    tf.summary.audio('wrong_audio', wrong_audio, _FS, max_outputs=10)
    tf.summary.scalar('Conditional Resample - KL-Loss', c_kl_loss)
    # tf.summary.scalar('embed_error_cosine', tf.reduce_sum(tf.multiply(cond_text_embed, expected_embed)) / (tf.norm(cond_text_embed) * tf.norm(expected_embed)))
    # tf.summary.scalar('embed_error_cosine_wrong', tf.reduce_sum(tf.multiply(wrong_cond_text_embed, expected_embed)) / (tf.norm(wrong_cond_text_embed) * tf.norm(expected_embed)))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, lod, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake / wrong discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, lod, **args.wavegan_d_kwargs)
    with tf.name_scope('D_w'), tf.variable_scope('D', reuse=True):
        D_w = WaveGANDiscriminator(wrong_audio, lod, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size, 1], dtype=tf.float32)
        real = tf.ones([args.train_batch_size, 1], dtype=tf.float32)

        # Conditional G Loss
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0],
                                                    labels=real))
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1],
                                                        labels=real))
            G_loss /= 2

        # Conditional D Losses
        D_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0],
                                                    labels=fake))
        D_loss_wrong = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[0],
                                                    labels=fake))
        D_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[0],
                                                    labels=real))

        # Unconditional D Losses
        if args.use_extra_uncond_loss:
            D_loss_fake_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1],
                                                        labels=fake))
            D_loss_wrong_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[1],
                                                        labels=real))
            D_loss_real_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[1],
                                                        labels=real))

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong
    elif args.wavegan_loss == 'lsgan':
        # Conditional G Loss
        G_loss = tf.reduce_mean((D_G_z[0] - 1.)**2)
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += tf.reduce_mean((D_G_z[1] - 1.)**2)
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = tf.reduce_mean((D_x[0] - 1.)**2)
        D_loss_wrong = tf.reduce_mean(D_w[0]**2)
        D_loss_fake = tf.reduce_mean(D_G_z[0]**2)

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = tf.reduce_mean((D_x[1] - 1.)**2)
            D_loss_wrong_uncond = tf.reduce_mean((D_w[1] - 1.)**2)
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1]**2)

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong
    elif args.wavegan_loss == 'wgan':
        # Conditional G Loss
        G_loss = -tf.reduce_mean(D_G_z[0])
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += -tf.reduce_mean(D_G_z[1])
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = -tf.reduce_mean(D_x[0])
        D_loss_wrong = tf.reduce_mean(D_w[0])
        D_loss_fake = tf.reduce_mean(D_G_z[0])

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = -tf.reduce_mean(D_x[1])
            D_loss_wrong_uncond = -tf.reduce_mean(D_w[1])
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1])

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        # Conditional G Loss
        G_loss = -tf.reduce_mean(D_G_z[0])
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += -tf.reduce_mean(D_G_z[1])
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = -tf.reduce_mean(D_x[0])
        D_loss_wrong = tf.reduce_mean(D_w[0])
        D_loss_fake = tf.reduce_mean(D_G_z[0])

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = -tf.reduce_mean(D_x[1])
            D_loss_wrong_uncond = -tf.reduce_mean(D_w[1])
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1])

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong

        # Conditional Gradient Penalty
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        real = x
        fake = tf.concat([
            G_z[:args.train_batch_size // 2],
            wrong_audio[:args.train_batch_size // 2]
        ], 0)
        differences = fake - real
        interpolates = real + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(
                interpolates, lod,
                **args.wavegan_d_kwargs)[0]  # Only want conditional output
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        cond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)

        # Unconditional Gradient Penalty
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        real = tf.concat([
            x[:args.train_batch_size // 2],
            wrong_audio[:args.train_batch_size // 2]
        ], 0)
        fake = G_z
        differences = fake - real
        interpolates = real + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(
                interpolates, lod,
                **args.wavegan_d_kwargs)[1]  # Only want unconditional output
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        uncond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)

        # Warmup Gradient Penalty
        # alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
        # real = x
        # fake = wrong_audio
        # differences = fake - real
        # interpolates = real + (alpha * differences)
        # with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
        #   D_interp = WaveGANDiscriminator(interpolates, lod, **args.wavegan_d_kwargs)[0] # Only want conditional output
        # gradients = tf.gradients(D_interp, [interpolates])[0]
        # slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        # warmup_gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)

        gradient_penalty = (cond_gradient_penalty +
                            uncond_gradient_penalty) / 2

        LAMBDA = 10
        D_loss += LAMBDA * gradient_penalty
        # D_warmup_loss += LAMBDA * warmup_gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    if (args.wavegan_loss == 'wgan-gp'):
        tf.summary.scalar('Gradient Penalty', LAMBDA * gradient_penalty)
    if (args.wavegan_loss == 'wgan' or args.wavegan_loss == 'wgan-gp'):
        if args.use_extra_uncond_loss:
            tf.summary.scalar('Critic Score - Real Data - Condition Match',
                              -D_loss_real)
            tf.summary.scalar('Critic Score - Fake Data - Condition Match',
                              D_loss_fake)
            tf.summary.scalar('Critic Score - Wrong Data - Condition Match',
                              D_loss_wrong)
            tf.summary.scalar('Critic Score - Real Data', -D_loss_real_uncond)
            tf.summary.scalar('Critic Score - Wrong Data',
                              -D_loss_wrong_uncond)
            tf.summary.scalar('Critic Score - Fake Data', D_loss_fake_uncond)
            tf.summary.scalar('Wasserstein Distance - No Regularization Term',
                              -((D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                               + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2))
            tf.summary.scalar('Wasserstein Distance - Real-Wrong Only',
                              -(D_loss_real + D_loss_wrong))
            tf.summary.scalar('Wasserstein Distance - Real-Fake Only',
                              -((D_loss_real + D_loss_fake \
                               + D_loss_real_uncond + D_loss_fake_uncond) / 2))
        else:
            tf.summary.scalar('Critic Score - Real Data', -D_loss_real)
            tf.summary.scalar('Critic Score - Wrong Data', D_loss_wrong)
            tf.summary.scalar('Critic Score - Fake Data', D_loss_fake)
            tf.summary.scalar(
                'Wasserstein Distance - No Regularization Term',
                -(D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)))
        tf.summary.scalar('Wasserstein Distance - With Regularization Term',
                          -D_loss)
    else:
        if args.use_extra_uncond_loss:
            tf.summary.scalar('D_acc_uncond', 0.5 * ((0.5 * (tf.reduce_mean(tf.sigmoid(D_x[1])) + tf.reduce_mean(tf.sigmoid(D_w[1])))) \
                                                   + tf.reduce_mean(1 - tf.sigmoid(D_G_z[1]))))
            tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                            + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0])))))
            tf.summary.scalar('D_acc_real_wrong_only', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                                            + tf.reduce_mean(1 - tf.sigmoid(D_w[0]))))
            tf.summary.scalar('D_loss_cond_real', D_loss_real)
            tf.summary.scalar('D_loss_uncond_real', D_loss_real_uncond)
            tf.summary.scalar('D_loss_cond_wrong', D_loss_wrong)
            tf.summary.scalar('D_loss_uncond_wrong', D_loss_wrong_uncond)
            tf.summary.scalar('D_loss_cond_fake', D_loss_fake)
            tf.summary.scalar('D_loss_uncond_fake', D_loss_fake_uncond)
            tf.summary.scalar('D_loss_unregularized',
                               (D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                              + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2)
        else:
            tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                            + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0])))))
            tf.summary.scalar('D_loss_real', D_loss_real)
            tf.summary.scalar('D_loss_wrong', D_loss_wrong)
            tf.summary.scalar('D_loss_fake', D_loss_fake)
            tf.summary.scalar('D_loss_unregularized',
                              D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake))
        tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=4e-4,
                                       beta1=0.0,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=4e-4,
                                       beta1=0.0,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Optimizer internal state reset ops
    reset_G_opt_op = tf.variables_initializer(G_opt.variables())
    reset_D_opt_op = tf.variables_initializer(D_opt.variables())

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    def smoothstep(x, mi, mx):
        """Hermite smoothstep eased between mi and mx.

        Maps x through the cubic 3t^2 - 2t^3: inputs below 0 yield mi,
        inputs above 1 yield mx, and the in-between range follows the
        smooth ease-in/ease-out curve. Works elementwise on arrays.
        """
        curve = np.where(x < 0, 0, np.where(x <= 1, 3 * x**2 - 2 * x**3, 1))
        return mi + (mx - mi) * curve

    def np_lerp_clip(t, a, b):
        """Linearly interpolate from a to b, clamping the parameter t to [0, 1]."""
        frac = np.clip(t, 0.0, 1.0)
        return a + (b - a) * frac

    def get_lod_at_step(step):
        """Return the (possibly fractional) level of detail for a global step.

        Schedule: each integer LOD is held for 10k steps, then faded
        linearly into the next over the following 10k steps — 0 until step
        10k, 0->1 over [10k, 20k), 1 until 30k, and so on, saturating at 5
        from step 100k onward (the trailing default entry of the funclist).
        """
        x = float(step)
        conds = [x < 10000]
        funcs = [0]  # LOD 0 is held over the first 10k steps.
        for lod in range(5):
            fade_start = (2 * lod + 1) * 10000
            hold_start = fade_start + 10000
            conds.append(fade_start <= x < hold_start)
            funcs.append(lambda s, lo=lod, t0=fade_start:
                         np_lerp_clip((s - t0) / 10000, lo, lo + 1))
            if lod < 4:  # after the last fade there is no further hold window
                conds.append(hold_start <= x < hold_start + 10000)
                funcs.append(lod + 1)
        funcs.append(5)  # default branch: every step >= 100k stays at LOD 5
        return np.piecewise(x, conds, funcs)

    def my_filter_callable(datum, tensor):
        """tfdbg tensor filter: flag float tensors containing any value
        >= 50.0 or <= -50.0.

        Tensors that could not be converted to numpy arrays
        (InconvertibleTensorProto) and non-float dtypes are never flagged.
        """
        if isinstance(tensor, debug_data.InconvertibleTensorProto):
            return False
        if tensor.dtype not in (np.float32, np.float64):
            return False
        return np.any(np.greater_equal(tensor, 50.0)) or np.any(
            np.less_equal(tensor, -50.0))

    # Create a LocalCLIDebugHook and use it as a monitor
    # debug_hook = tf_debug.LocalCLIDebugHook(dump_root='C:/d/t/')
    # debug_hook.add_tensor_filter('large_values', my_filter_callable)
    # hooks = [debug_hook]

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        # Get the summary writer for writing extra summary statistics
        summary_writer = SummaryWriterCache.get(args.train_dir)

        cur_lod = 0
        while True:
            # Calculate Maximum LOD to train
            step = sess.run(tf.train.get_or_create_global_step(),
                            feed_dict={lod: cur_lod})
            cur_lod = get_lod_at_step(step)
            prev_lod = get_lod_at_step(step - 1)

            # Reset optimizer internal state when new layers are introduced
            if np.floor(cur_lod) != np.floor(prev_lod) or np.ceil(
                    cur_lod) != np.ceil(prev_lod):
                print(
                    "Resetting optimizers' internal states at step {}".format(
                        step))
                sess.run([reset_G_opt_op, reset_D_opt_op],
                         feed_dict={lod: cur_lod})

            # Output current LOD and 'steps at currrent LOD' to tensorboard
            step = float(
                sess.run(tf.train.get_or_create_global_step(),
                         feed_dict={lod: cur_lod}))
            lod_summary = tf.Summary(value=[
                tf.Summary.Value(tag="current_lod",
                                 simple_value=float(cur_lod)),
            ])
            summary_writer.add_summary(lod_summary, step)

            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op, feed_dict={lod: cur_lod})

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights, feed_dict={lod: cur_lod})

            # Train generator
            sess.run(G_train_op, feed_dict={lod: cur_lod})
Exemplo n.º 14
0
def train(fps, args):
  """Train the conditional SpecGAN configured by ``args``.

  Builds the data pipeline, generator/encoder, and three discriminator
  heads (real-right, real-wrong, fake-right), assembles the loss selected
  by ``args.specgan_loss``, and runs the alternating D/G update loop
  inside a MonitoredTrainingSession until externally interrupted.

  Args:
    fps: List of training-data file paths, forwarded to the loader.
    args: Parsed CLI namespace carrying batch size, loss selection, data
      moments, SpecGAN network kwargs and checkpoint/summary settings.
  """
  with tf.name_scope('loader'):
    # Two independently drawn batches: the 'right' batch is paired with its
    # own condition; the 'wrong' batch supplies mismatched negatives for D.
    right_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)
    right_x = t_to_f(right_x_wav, args.data_moments_mean, args.data_moments_std)
    wrong_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)
    wrong_x = t_to_f(wrong_x_wav, args.data_moments_mean, args.data_moments_std)

  # Make z vector
  z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1., dtype=tf.float32)
  # static_condition means pitch
  right_static_condition = tf.random_uniform([args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)
  wrong_static_condition = tf.random_uniform([args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)

  # Make generator
  with tf.variable_scope('G'):
    # encode the spectrum into a vector
    En_right_x = SpecGANEncoder(right_x)
    En_wrong_x = SpecGANEncoder(wrong_x)
    # BUG FIX: original referenced an undefined `static_condition`; the
    # generator is conditioned on the matching ('right') pitch condition,
    # consistent with the fake-logits discriminator call below.
    Condition_z = tf.concat([En_right_x, z, right_static_condition], 1)

    G_z, G_z_static = SpecGANGenerator(Condition_z, train=True, **args.specgan_g_kwargs)
  G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

  # Print G summary
  print('-' * 80)
  print('Generator vars')
  nparams = 0
  for v in G_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

  # Summarize
  # BUG FIX: original referenced undefined `x`/`x_wav`; the real batch in
  # this graph is `right_x` / `right_x_wav`.
  x_gl = f_to_t(right_x, args.data_moments_mean, args.data_moments_std, args.specgan_ngl)
  G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std, args.specgan_ngl)
  tf.summary.audio('x_wav', right_x_wav, _FS)
  tf.summary.audio('x', x_gl, _FS)
  tf.summary.audio('G_z', G_z_gl, _FS)
  G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
  x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
  tf.summary.histogram('x_rms_batch', x_rms)
  tf.summary.histogram('G_z_rms_batch', G_z_rms)
  tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
  tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
  tf.summary.image('x', f_to_img(right_x))
  tf.summary.image('G_z', f_to_img(G_z))

  # Real input to discriminator
  # NOTE(review): these are random placeholders standing in for real
  # dynamic/static inputs — presumably to be replaced by real data; confirm.
  dynamic_x = tf.random_uniform([args.train_batch_size, 128, 128, 1], -1., 1., dtype=tf.float32)
  static_x = tf.random_uniform([args.train_batch_size, _STATIC_TRACT_DIM], -1., 1., dtype=tf.float32)

  # Make real-right discriminator
  with tf.name_scope('D_x'), tf.variable_scope('D'):
    real_logits = SpecGANDiscriminator(dynamic_x, static_x, En_right_x, right_static_condition, **args.specgan_d_kwargs)
  D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

  # Print D summary
  print('-' * 80)
  print('Discriminator vars')
  nparams = 0
  for v in D_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make real-wrong discriminator
  # CONSISTENCY FIX: this head scores real data against the *wrong*
  # condition; the original reused the 'D_G_z' name scope from the fake head.
  with tf.name_scope('D_w'), tf.variable_scope('D', reuse=True):
    wrong_logits = SpecGANDiscriminator(dynamic_x, static_x, En_wrong_x, wrong_static_condition, **args.specgan_d_kwargs)

  # Make fake-right discriminator
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    fake_logits = SpecGANDiscriminator(G_z, G_z_static, En_right_x, right_static_condition, **args.specgan_d_kwargs)

  # Create loss
  D_clip_weights = None
  if args.specgan_loss == 'dcgan':
    fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
    real = tf.ones([args.train_batch_size], dtype=tf.float32)

    # G wants the fake-right head to say "real".
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=fake_logits,
      labels=real
    ))

    real_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=real_logits,
      labels=real
    ))
    wrong_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=wrong_logits,
      labels=fake
    ))
    fake_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=fake_logits,
      labels=fake
    ))
    # Wrong-condition and fake negatives share equal weight.
    D_loss = real_D_loss + (wrong_D_loss + fake_D_loss)/2.

  elif args.specgan_loss == 'lsgan':
    # BUG FIX: original referenced undefined `D_G_z`/`D_x`; use the logits
    # tensors actually built above.
    G_loss = tf.reduce_mean((fake_logits - 1.) ** 2)
    D_loss = tf.reduce_mean((real_logits - 1.) ** 2)
    D_loss += tf.reduce_mean(fake_logits ** 2)
    D_loss /= 2.
  elif args.specgan_loss == 'wgan':
    # BUG FIX: original referenced undefined `D_G_z`/`D_x`.
    G_loss = -tf.reduce_mean(fake_logits)
    D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    # Weight clipping enforces the Lipschitz constraint for vanilla WGAN.
    with tf.name_scope('D_clip_weights'):
      clip_ops = []
      for var in D_vars:
        clip_bounds = [-.01, .01]
        clip_ops.append(
          tf.assign(
            var,
            tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
          )
        )
      D_clip_weights = tf.group(*clip_ops)
  elif args.specgan_loss == 'wgan-gp':
    # BUG FIX: original referenced undefined `D_G_z`/`D_x`/`x` and called
    # SpecGANDiscriminator with the wrong arity (copied from the
    # unconditional SpecGAN). Use the logits built above and interpolate
    # between the real and generated dynamic inputs.
    G_loss = -tf.reduce_mean(fake_logits)
    D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1], minval=0., maxval=1.)
    differences = G_z - dynamic_x
    interpolates = dynamic_x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
      # NOTE(review): only the dynamic input is interpolated; the static
      # and conditioning inputs mirror the fake-logits call. Confirm
      # whether the static path should be interpolated as well.
      D_interp = SpecGANDiscriminator(interpolates, G_z_static, En_right_x, right_static_condition, **args.specgan_d_kwargs)

    LAMBDA = 10  # gradient-penalty weight (standard WGAN-GP setting)
    gradients = tf.gradients(D_interp, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
    D_loss += LAMBDA * gradient_penalty
  else:
    raise NotImplementedError()

  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)

  # Create (recommended) optimizer
  if args.specgan_loss == 'dcgan':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
  elif args.specgan_loss == 'lsgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  elif args.specgan_loss == 'wgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
  elif args.specgan_loss == 'wgan-gp':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
  else:
    raise NotImplementedError()

  # Create training ops; only G_train_op advances the global step so that
  # one "step" corresponds to one full D/G round.
  G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

  # Run training; the session handles checkpointing and summary writes.
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    while True:
      # Train discriminator (multiple critic updates per generator update)
      for i in xrange(args.specgan_disc_nupdates):
        sess.run(D_train_op)

        # Enforce Lipschitz constraint for WGAN
        if D_clip_weights is not None:
          sess.run(D_clip_weights)

      # Train generator
      sess.run(G_train_op)