Example #1
def infer(args):
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that generates latent vectors
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, args.specgan_latent_dim],
                               -1.0,
                               1.0,
                               dtype=tf.float32,
                               name='samp_z')

    # Input z
    z = tf.placeholder(tf.float32, [None, args.specgan_latent_dim], name='z')
    ngl = tf.placeholder(tf.int32, [], name='ngl')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Execute generator
    with tf.variable_scope('G'):
        G_z_norm = SpecGANGenerator(z, train=False, **args.specgan_g_kwargs)
    G_z_norm = tf.identity(G_z_norm, name='G_z_norm')
    G_z = f_to_t(G_z_norm, args.data_moments_mean, args.data_moments_std, ngl)
    G_z = tf.identity(G_z, name='G_z')

    G_z_norm_uint8 = f_to_img(G_z_norm)
    G_z_norm_uint8 = tf.identity(G_z_norm_uint8, name='G_z_norm_uint8')

    # Flatten batch
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode to int16
    def float_to_int16(x, name=None):
        x_int16 = x * 32767.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16

    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(filename=infer_metagraph_fp,
                               clear_devices=True,
                               saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards)
    tf.reset_default_graph()
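
For reference, the MetaGraph exported above can be restored by tensor name to sample from a trained generator. The sketch below is a minimal, hypothetical loader: it assumes a checkpoint exists in args.train_dir and that ngl is the number of Griffin-Lim inversion iterations; sample_from_infer_graph, n, and ngl_iters are illustrative names, not part of the original code.

def sample_from_infer_graph(train_dir, n=10, ngl_iters=16):
    # Hypothetical helper: restore the exported inference MetaGraph and
    # decode n random latent vectors into waveforms via the named tensors.
    infer_metagraph_fp = os.path.join(train_dir, 'infer', 'infer.meta')
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(infer_metagraph_fp)
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(train_dir))
        # Draw latent vectors using the sampling subgraph defined in infer()
        _z = sess.run(graph.get_tensor_by_name('samp_z:0'),
                      {graph.get_tensor_by_name('samp_z_n:0'): n})
        feeds = {
            graph.get_tensor_by_name('z:0'): _z,
            graph.get_tensor_by_name('ngl:0'): ngl_iters,
            graph.get_tensor_by_name('flat_pad:0'): 0,
        }
        # 'G_z:0' is the float waveform batch; 'G_z_flat_int16:0' would give a
        # single int16 stream suitable for writing to a WAV file.
        return sess.run(graph.get_tensor_by_name('G_z:0'), feeds)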
Example #2
def train(fps, args):
    global train_dataset_size

    with tf.name_scope('loader'):
        # Note: this subsetting approach was not ideal, but it is kept as a reference point for the 115-tfrecord setup
        # train_fps, _ = loader.split_files_test_val(fps, train_data_percentage, 0)
        # fps = train_fps
        # fps = fps[:gan_train_data_size]

        logging.info("Full training datasize = " +
                     str(find_data_size(fps, None)))
        length = len(fps)
        fps = fps[:(int(train_data_percentage / 100.0 * length))]
        logging.info("GAN training datasize (before exclude) = " +
                     str(find_data_size(fps, None)))

        if args.exclude_class is None:
            pass
        elif args.exclude_class != -1:
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info("GAN training datasize (after exclude) = " +
                         str(train_dataset_size))
        elif args.exclude_class == -1:
            fps, _ = loader.split_files_test_val(fps, 0.9, 0)
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info(
                "GAN training datasize (after exclude - random sampling) = " +
                str(train_dataset_size))
        else:
            raise ValueError(
                "args.exclude_class must be None, an integer in [0, num_class), "
                "or -1 to randomly sample 90% of the data")

        training_iterator = loader.get_batch(fps,
                                             args.train_batch_size,
                                             _WINDOW_LEN,
                                             args.data_first_window,
                                             repeat=True,
                                             initializable=True,
                                             labels=True,
                                             exclude_class=args.exclude_class)
        x_wav, _ = training_iterator.get_next()  # Important: ignore the labels
        print("x_wav.shape = %s" % str(x_wav.shape))
        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)
        print("x.shape = %s" % str(x.shape))

        logging.info("train_dataset_size = " + str(train_dataset_size))

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, _D_Z],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    logging.info('-' * 80)
    logging.info('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n,
                                          v.name))
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))

    # Summarize
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    print("x_gl.shape = %s" % str(x_gl.shape))
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, _FS)
    tf.summary.audio('x', x_gl, _FS)
    tf.summary.audio('G_z', G_z_gl, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    logging.info('-' * 80)
    logging.info('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        logging.info('{} ({}): {}'.format(v.get_shape().as_list(), v_n,
                                          v.name))
    logging.info('Total params: {} ({:.2f} MB)'.format(
        nparams, (float(nparams) * 4) / (1024 * 1024)))
    logging.info('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    current_step = -1
    scaffold = tf.train.Scaffold(local_init_op=tf.group(
        tf.local_variables_initializer(), training_iterator.initializer),
                                 saver=tf.train.Saver(max_to_keep=3))
    gpu_options = tf.GPUOptions(allow_growth=True,
                                per_process_gpu_memory_fraction=0.5)
    with tf.train.MonitoredTrainingSession(
            hooks=[SaveAtEnd(os.path.join(args.train_dir, 'model'))],
            config=tf.ConfigProto(gpu_options=gpu_options),
            scaffold=scaffold,
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs,
    ) as sess:
        # sess.run(training_iterator.initializer)
        while True:
            global_step = sess.run(tf.train.get_or_create_global_step())
            logging.info("Global step: " + str(global_step))

            if args.stop_at_global_step != 0 and global_step >= args.stop_at_global_step:
                logging.info(
                    "Stopping because args.stop_at_global_step is set to  " +
                    str(args.stop_at_global_step))
                break
                # last_saver.save(sess, os.path.join(args.train_dir, 'model'), global_step=global_step)

            # Train discriminator
            # for i in range(args.specgan_disc_nupdates):
            #   try:
            #     sess.run(D_train_op)
            #     current_step += 1
            #     # Stop training after x% of training data seen
            #     if current_step * args.train_batch_size > math.ceil(train_dataset_size * train_data_percentage / 100.0):
            #       logging.info("Stopping at batch: " + str(current_step))
            #       current_step = -1
            #       sess.run(training_iterator.initializer)
            #
            #   except tf.errors.OutOfRangeError:
            #     # End of training dataset
            #     if train_data_percentage != 100:
            #       logging.info("ERROR: end of dataset for only part of data! Achieved end of training dataset with train_data_percentage = " + str(train_data_percentage))
            #     else:
            #       current_step = -1
            #       sess.run(training_iterator.initializer)

            # Train discriminator
            try:
                for i in range(args.specgan_disc_nupdates):
                    sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
            except tf.errors.OutOfRangeError:
                sess.run(training_iterator.initializer)

            # Train generator
            sess.run(G_train_op)
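
The SaveAtEnd hook passed to MonitoredTrainingSession above is referenced but not defined in this example. A minimal sketch of what such a hook could look like, assuming it simply writes one final checkpoint when the session closes (the actual implementation may differ):

class SaveAtEnd(tf.train.SessionRunHook):
    # Hypothetical hook: save a final checkpoint when the monitored
    # session ends (e.g. after the stop_at_global_step break above).
    def __init__(self, save_path):
        self._save_path = save_path

    def begin(self):
        # begin() runs before the graph is finalized, so a Saver can
        # still be constructed here.
        self._saver = tf.train.Saver()
        self._global_step = tf.train.get_or_create_global_step()

    def end(self, session):
        step = session.run(self._global_step)
        self._saver.save(session, self._save_path, global_step=step)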
Example #3
def train(fps, args):
    with tf.name_scope('loader'):
        x_wav = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=_SLICE_LEN,
            decode_fs=args.data_sample_rate,
            decode_num_channels=1,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

        x = t_to_f(x_wav, args.data_moments_mean, args.data_moments_std)

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.specgan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = SpecGANGenerator(z, train=True, **args.specgan_g_kwargs)
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    x_gl = f_to_t(x, args.data_moments_mean, args.data_moments_std,
                  args.specgan_ngl)
    G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std,
                    args.specgan_ngl)
    tf.summary.audio('x_wav', x_wav, args.data_sample_rate)
    tf.summary.audio('x', x_gl, args.data_sample_rate)
    tf.summary.audio('G_z', G_z_gl, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.image('x', f_to_img(x))
    tf.summary.image('G_z', f_to_img(G_z))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = SpecGANDiscriminator(x, **args.specgan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = SpecGANDiscriminator(G_z, **args.specgan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.specgan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.specgan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.specgan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.specgan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = SpecGANDiscriminator(interpolates,
                                            **args.specgan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.specgan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.specgan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.specgan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.specgan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for i in xrange(args.specgan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
Example #4
def train(fps, args):
  with tf.name_scope('loader'):
    right_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)
    right_x = t_to_f(right_x_wav, args.data_moments_mean, args.data_moments_std)
    wrong_x_wav = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)
    wrong_x = t_to_f(wrong_x_wav, args.data_moments_mean, args.data_moments_std)

  # Make z vector
  z = tf.random_uniform([args.train_batch_size, _D_Z], -1., 1., dtype=tf.float32)
  # static_condition means pitch
  right_static_condition = tf.random_uniform([args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)
  wrong_static_condition = tf.random_uniform([args.train_batch_size, _STATIC_PITCH_DIM], -1., 1., dtype=tf.float32)


  # Make generator
  with tf.variable_scope('G'):
    # encode the spectrum into a vector
    En_right_x = SpecGANEncoder(right_x)
    En_wrong_x = SpecGANEncoder(wrong_x)
    Condition_z = tf.concat([En_right_x, z, right_static_condition], 1)

    G_z, G_z_static = SpecGANGenerator(Condition_z, train=True, **args.specgan_g_kwargs)
  G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

  # Print G summary
  print('-' * 80)
  print('Generator vars')
  nparams = 0
  for v in G_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))

  # Summarize
  x_gl = f_to_t(right_x, args.data_moments_mean, args.data_moments_std, args.specgan_ngl)
  G_z_gl = f_to_t(G_z, args.data_moments_mean, args.data_moments_std, args.specgan_ngl)
  tf.summary.audio('x_wav', right_x_wav, _FS)
  tf.summary.audio('x', x_gl, _FS)
  tf.summary.audio('G_z', G_z_gl, _FS)
  G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z_gl[:, :, 0]), axis=1))
  x_rms = tf.sqrt(tf.reduce_mean(tf.square(x_gl[:, :, 0]), axis=1))
  tf.summary.histogram('x_rms_batch', x_rms)
  tf.summary.histogram('G_z_rms_batch', G_z_rms)
  tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
  tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
  tf.summary.image('x', f_to_img(right_x))
  tf.summary.image('G_z', f_to_img(G_z))

  # Real input to the discriminator (random placeholder tensors in this example)
  dynamic_x = tf.random_uniform([args.train_batch_size, 128, 128, 1], -1., 1., dtype=tf.float32)
  static_x = tf.random_uniform([args.train_batch_size, _STATIC_TRACT_DIM], -1., 1., dtype=tf.float32)

  # Make real-right discriminator
  with tf.name_scope('D_x'), tf.variable_scope('D'):
    real_logits = SpecGANDiscriminator(dynamic_x, static_x, En_right_x, right_static_condition, **args.specgan_d_kwargs)
  D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

  # Print D summary
  print('-' * 80)
  print('Discriminator vars')
  nparams = 0
  for v in D_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make real-wrong discriminator
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    wrong_logits = SpecGANDiscriminator(dynamic_x, static_x, En_wrong_x, wrong_static_condition, **args.specgan_d_kwargs)

  # Make fake-right discriminator
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    fake_logits = SpecGANDiscriminator(G_z, G_z_static, En_right_x, right_static_condition, **args.specgan_d_kwargs)


  # Create loss
  D_clip_weights = None
  if args.specgan_loss == 'dcgan':
    fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
    real = tf.ones([args.train_batch_size], dtype=tf.float32)

    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=fake_logits,
      labels=real
    ))

    real_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=real_logits,
      labels=real
    ))
    wrong_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=wrong_logits,
      labels=fake
    ))
    fake_D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=fake_logits,
      labels=fake
    ))
    D_loss = real_D_loss + (wrong_D_loss + fake_D_loss)/2.
    
  elif args.specgan_loss == 'lsgan':
    G_loss = tf.reduce_mean((fake_logits - 1.) ** 2)
    D_loss = tf.reduce_mean((real_logits - 1.) ** 2)
    D_loss += tf.reduce_mean(fake_logits ** 2)
    D_loss /= 2.
  elif args.specgan_loss == 'wgan':
    G_loss = -tf.reduce_mean(fake_logits)
    D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    with tf.name_scope('D_clip_weights'):
      clip_ops = []
      for var in D_vars:
        clip_bounds = [-.01, .01]
        clip_ops.append(
          tf.assign(
            var,
            tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
          )
        )
      D_clip_weights = tf.group(*clip_ops)
  elif args.specgan_loss == 'wgan-gp':
    G_loss = -tf.reduce_mean(fake_logits)
    D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits)

    alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1, 1], minval=0., maxval=1.)
    differences = G_z - right_x
    interpolates = right_x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
      D_interp = SpecGANDiscriminator(interpolates, **args.specgan_d_kwargs)

    LAMBDA = 10
    gradients = tf.gradients(D_interp, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
    D_loss += LAMBDA * gradient_penalty
  else:
    raise NotImplementedError()

  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)

  # Create (recommended) optimizer
  if args.specgan_loss == 'dcgan':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
  elif args.specgan_loss == 'lsgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  elif args.specgan_loss == 'wgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
  elif args.specgan_loss == 'wgan-gp':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
  else:
    raise NotImplementedError()

  # Create training ops
  G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

  # Run training
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    while True:
      # Train discriminator
      for i in xrange(args.specgan_disc_nupdates):
        sess.run(D_train_op)

        # Enforce Lipschitz constraint for WGAN
        if D_clip_weights is not None:
          sess.run(D_clip_weights)

      # Train generator
      sess.run(G_train_op)