Example #1
def infer(args):
    infer_dir = os.path.join(args.train_dir, 'infer')
    if not os.path.isdir(infer_dir):
        os.makedirs(infer_dir)

    # Subgraph that generates latent vectors
    samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
    samp_z = tf.random_uniform([samp_z_n, _D_Z],
                               -1.0,
                               1.0,
                               dtype=tf.float32,
                               name='samp_z')

    # Input z (latent vector concatenated with label)
    z = tf.placeholder(tf.float32, [None, _D_Z + _D_Y], name='z')
    flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

    # Execute generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_z = tf.identity(G_z, name='G_z')

    # Flatten batch
    nch = int(G_z.get_shape()[-1])
    G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
    G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

    # Encode to int16
    def float_to_int16(x, name=None):
        x_int16 = x * 32767.
        x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
        x_int16 = tf.cast(x_int16, tf.int16, name=name)
        return x_int16

    G_z_int16 = float_to_int16(G_z, name='G_z_int16')
    G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

    # Create saver
    G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(G_vars + [global_step])

    # Export graph
    tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

    # Export MetaGraph
    infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
    tf.train.export_meta_graph(filename=infer_metagraph_fp,
                               clear_devices=True,
                               saver_def=saver.as_saver_def())

    # Reset graph (in case training afterwards)
    tf.reset_default_graph()
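
A minimal usage sketch for the MetaGraph exported above (not part of the original example). The 'train' directory and the latent/label widths (100 + 10) are placeholders; only the tensor names 'z:0' and 'G_z:0' come from the code above.

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
saver = tf.train.import_meta_graph('train/infer/infer.meta')  # hypothetical path
graph = tf.get_default_graph()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('train'))
    # 'z' expects latent and label vectors concatenated to width _D_Z + _D_Y
    _z = np.random.uniform(-1., 1., (1, 110)).astype(np.float32)
    _G_z = sess.run(graph.get_tensor_by_name('G_z:0'),
                    {graph.get_tensor_by_name('z:0'): _z})  # (1, slice_len, nch), floats in [-1, 1]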
Example #2
    def load_wavegan(self, slice_len=16384, model_size=32):
        path_to_model = os.path.join(self.wavegan_path)
        self.generator = WaveGANGenerator(slice_len=slice_len,
                                          model_size=model_size,
                                          use_batch_norm=False,
                                          num_channels=1)
        checkpoint = torch.load(path_to_model, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator'])
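
A short generation sketch, written as if inside the same class (illustrative only; the 100-dimensional latent size is an assumption and must match the WaveGANGenerator configuration):

        self.generator.to(self.device).eval()
        z = torch.rand(1, 100, device=self.device) * 2.0 - 1.0  # uniform noise in [-1, 1]
        with torch.no_grad():
            audio = self.generator(z).squeeze().cpu().numpy()  # waveform of roughly slice_len samples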
Example #3
def train(fps, args):
    with tf.name_scope('loader'):
        x, y = loader.get_batch(fps,
                                args.train_batch_size,
                                _WINDOW_LEN,
                                args.data_first_window,
                                labels=True)

    # Make inputs
    y_fill = tf.expand_dims(y, axis=2)
    z = tf.random_uniform([args.train_batch_size, _D_Z],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Concatenate labels onto x and z
    x = tf.concat([x, y_fill], 1)
    z = tf.concat([z, y], 1)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
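
For reference, the 'wgan-gp' branch above is the two-sided gradient penalty of WGAN-GP: D_loss += LAMBDA * E[(||grad D(interpolates)||_2 - 1)^2]. A tiny standalone NumPy sketch of the same reduction, on dummy gradients:

import numpy as np

grads = np.random.randn(64, 16384, 1)                # stand-in for dD/d(interpolates)
slopes = np.sqrt(np.square(grads).sum(axis=(1, 2)))  # per-example gradient L2 norms
penalty = np.mean((slopes - 1.) ** 2)                # mirrors gradient_penalty above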
Example #4
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming convention is different from that used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = WaveGANGenerator(
            model_size=opt.model_size,
            ngpus=opt.ngpus,
            latent_dim=opt.latent_dim,
            alpha=opt.alpha,
            post_proc_filt_len=opt.post_proc_filt_len)
        self.netG_B = WaveGANGenerator(
            model_size=opt.model_size,
            ngpus=opt.ngpus,
            latent_dim=opt.latent_dim,
            alpha=opt.alpha,
            post_proc_filt_len=opt.post_proc_filt_len)

        if self.isTrain:
            use_sigmoid = opt.gan_loss != 'lsgan'
            self.netD_A = WaveGANDiscriminator(model_size=opt.model_size,
                                               ngpus=opt.ngpus,
                                               shift_factor=opt.shift_factor,
                                               alpha=opt.alpha,
                                               batch_shuffle=opt.batch_shuffle)
            self.netD_B = WaveGANDiscriminator(model_size=opt.model_size,
                                               ngpus=opt.ngpus,
                                               shift_factor=opt.shift_factor,
                                               alpha=opt.alpha,
                                               batch_shuffle=opt.batch_shuffle)

        if self.isTrain:
            self.fake_A_pool = AudioPool(opt.pool_size)
            self.fake_B_pool = AudioPool(opt.pool_size)
            # define loss functions
            self.criterionGAN = GANLoss(loss_type=opt.gan_loss,
                                        tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, opt.beta2))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, opt.beta2))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, opt.beta2))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(get_scheduler(optimizer, opt))

        LOGGER.info('---------- Networks initialized -------------')
        print_network(self.netG_A)
        print_network(self.netG_B)
        if self.isTrain:
            print_network(self.netD_A)
            print_network(self.netD_B)
        LOGGER.info('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # 'async' was renamed non_blocking ('async' is a reserved word in Python 3.7+)
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.audio_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get audio paths
    def get_audio_paths(self):
        return self.audio_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5

        # backward
        loss_D.backward()
        return loss_D

    def backward_D_wp(self, netD, real, fake):
        # Gradient penalty loss for WGAN-GP (named 'wgan-wp' in this codebase)
        loss_D_wp = calc_gradient_penalty(netD,
                                          real,
                                          fake,
                                          self.opt.batchSize,
                                          self.opt.lambda_wp,
                                          use_cuda=len(self.gpu_ids) > 0)

        loss_D_wp.backward()
        return loss_D_wp

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_A = self.backward_D_wp(self.netD_A, self.real_B, fake_B)
            self.loss_D_wp_A = loss_D_wp_A.data[0]
        else:
            self.loss_D_wp_A = 0

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]
        if self.opt.gan_loss == 'wgan-wp':
            loss_D_wp_B = self.backward_D_wp(self.netD_B, self.real_A, fake_A)
            self.loss_D_wp_B = loss_D_wp_B.data[0]
        else:
            self.loss_D_wp_B = 0

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B

        if self.opt.gan_loss == 'wgan-wp':
            ret_errors['WP_A'] = self.loss_D_wp_A
            ret_errors['WP_B'] = self.loss_D_wp_B

        return ret_errors

    def get_current_audibles(self):
        real_A = tensor2audio(self.input_A)
        fake_B = tensor2audio(self.fake_B)
        rec_A = tensor2audio(self.rec_A)
        real_B = tensor2audio(self.input_B)
        fake_A = tensor2audio(self.fake_A)
        rec_B = tensor2audio(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = tensor2audio(self.idt_A)
            ret_visuals['idt_B'] = tensor2audio(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
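
A hypothetical outer driver for this model, following the usual CycleGAN training convention; opt.niter and the dataset iterable are assumptions, not part of the example:

model = CycleGANModel()
model.initialize(opt)
for epoch in range(opt.niter):
    for data in dataset:
        model.set_input(data)        # loads input_A / input_B
        model.optimize_parameters()  # one G step, then D_A and D_B steps
    print(epoch, dict(model.get_current_errors()))
    model.save('latest')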
Example #5
def train(fps, args):
    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    # Load adversarial input
    fs, audio = wavread(args.adv_input)
    assert fs == args.data_sample_rate
    assert audio.dtype == np.float32
    assert len(audio.shape) == 1

    # Synthesis
    if audio.shape[0] < args.data_slice_len:
        audio = np.pad(audio, (0, args.data_slice_len - audio.shape[0]),
                       'constant')
    adv_input = tf.constant(
        audio[:args.data_slice_len], dtype=np.float32
    ) + args.adv_magnitude * tf.reshape(G_z,
                                        G_z.get_shape().as_list()[:-1])

    # Calculate MFCCs
    spectrograms = tf.abs(
        tf.signal.stft(adv_input, frame_length=320, frame_step=160))
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        40, spectrograms.shape[-1].value, fs, 20, 4000)
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix,
                                    1)
    mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
    mfccs = tf.expand_dims(
        tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)[
            -1, :99, :40], -1)

    # Load a model for speech command classification
    with tf.gfile.FastGFile(args.adv_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        with tf.variable_scope('Speech'):
            adv_logits, = tf.import_graph_def(graph_def,
                                              input_map={'Mfcc:0': mfccs},
                                              return_elements=['add_2:0'])

    # Load labels for speech command classification
    adv_labels = [line.rstrip() for line in tf.gfile.GFile(args.adv_label)]
    adv_index = adv_labels.index(args.adv_target)

    # Make adversarial loss
    # Adapted from: https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py
    adv_targets = tf.one_hot(
        tf.constant([adv_index] * args.train_batch_size, dtype=tf.int32),
        len(adv_labels))
    adv_target_logit = tf.reduce_sum(adv_targets * adv_logits, 1)
    adv_others_logit = tf.reduce_max(
        (1 - adv_targets) * adv_logits - (adv_targets * 10000), 1)

    adv_loss = tf.reduce_mean(
        tf.maximum(0.0,
                   adv_others_logit - adv_target_logit + args.adv_confidence))

    # Summarize audios
    tf.summary.audio('adv_input',
                     adv_input,
                     fs,
                     max_outputs=args.adv_max_outputs)
    tf.summary.scalar('adv_loss', adv_loss)
    tf.summary.histogram('adv_classes', tf.argmax(adv_logits, axis=1))

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss + args.adv_lambda * adv_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            config=config,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)

            # Train generator
            sess.run(G_train_op)
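
The adversarial objective above is the Carlini and Wagner targeted margin loss, max(0, max_{i != target} logit_i - logit_target + confidence). A worked NumPy check with dummy logits, target class 1, and confidence 0.5:

import numpy as np

logits = np.array([2.0, 5.0, 1.0])
target = 1
others = np.max(np.delete(logits, target))      # 2.0, the best non-target logit
loss = max(0.0, others - logits[target] + 0.5)  # max(0, 2.0 - 5.0 + 0.5) = 0.0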
Example #6
    print('Saving configurations...')
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(args, f)

    # Try on some training data
    print('Loading audio data...')
    audio_filepaths = get_all_audio_filepaths(args['audio_dir'])
    train_gen, valid_data, test_data \
        = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'],
                            batch_size, batch_size, batch_size)

    print('Creating models...')
    model_gen = WaveGANGenerator(model_size=model_size,
                                 ngpus=ngpus,
                                 latent_dim=latent_dim,
                                 post_proc_filt_len=args['post_proc_filt_len'],
                                 upsample=True)
    model_dis = WaveGANDiscriminator(model_size=model_size,
                                     ngpus=ngpus,
                                     alpha=args['alpha'],
                                     shift_factor=args['shift_factor'],
                                     batch_shuffle=args['batch_shuffle'])

    print('Starting training...')
    model_gen, model_dis, history, final_discr_metrics, samples = train_wgan(
        model_gen=model_gen,
        model_dis=model_dis,
        train_gen=train_gen,
        valid_data=valid_data,
        test_data=test_data,
Example #7
def train(fps, args):
  with tf.name_scope('loader'):
    x = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window)

  # Make z vector
  if args.use_sequence:
    z = tf.random_uniform([args.train_batch_size, 16, args.d_z], -1., 1., dtype=tf.float32)
  else:
    z = tf.random_uniform([args.train_batch_size, args.d_z], -1., 1., dtype=tf.float32)

  # Make generator
  with tf.variable_scope('G'):
    gru_layer = tf.keras.layers.CuDNNGRU(args.d_z, return_sequences=True)
    G_z, gru = WaveGANGenerator(z, gru_layer=gru_layer, train=True, return_gru=True, reuse=False,
                                use_sequence=args.use_sequence, **args.wavegan_g_kwargs)
    print('G_z.shape:', G_z.get_shape().as_list())
    if args.wavegan_genr_pp:
      with tf.variable_scope('pp_filt'):
        G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
  G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
  G_var_names = [g_var.name for g_var in G_vars]

  # Print G summary
  print('-' * 80)
  print('Generator vars')
  nparams = 0
  for v in G_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))


  extra_secs = 1
  if not args.use_sequence:
    z_feed_long = z
  else:
    added_noise = tf.random_uniform([args.train_batch_size, 16*extra_secs, args.d_z], -1., 1., dtype=tf.float32)
    z_feed_long = tf.concat([z, added_noise], axis=1)

  with tf.variable_scope('G', reuse=True):
    G_z_long, gru_long = WaveGANGenerator(z_feed_long, gru_layer=gru_layer, train=False, length=16*extra_secs,
                                          return_gru=True,
                                          reuse=True, use_sequence=args.use_sequence, **args.wavegan_g_kwargs)
    print('G_z_long.shape:', G_z_long.get_shape().as_list())
    if args.wavegan_genr_pp:
      with tf.variable_scope('pp_filt', reuse=True):
        G_z_long = tf.layers.conv1d(G_z_long, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')

  # Summarize
  tf.summary.audio('x', x, _FS)
  tf.summary.audio('G_z', G_z, _FS)
  tf.summary.audio('G_z_long', G_z_long, _FS)
  G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
  x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
  tf.summary.histogram('x_rms_batch', x_rms)
  tf.summary.histogram('G_z_rms_batch', G_z_rms)
  tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
  tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

  # Make real discriminator
  with tf.name_scope('D_x'), tf.variable_scope('D'):
    D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
  D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
  print('D_vars:', D_vars)

  # Print D summary
  print('-' * 80)
  print('Discriminator vars')
  nparams = 0
  for v in D_vars:
    v_shape = v.get_shape().as_list()
    v_n = reduce(lambda x, y: x * y, v_shape)
    nparams += v_n
    print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
  print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
  print('-' * 80)

  # Make fake discriminator
  with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
    D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

  # Create loss
  D_clip_weights = None
  if args.wavegan_loss == 'dcgan':
    fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
    real = tf.ones([args.train_batch_size], dtype=tf.float32)

    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_G_z,
      labels=real
    ))

    D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_G_z,
      labels=fake
    ))
    D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      logits=D_x,
      labels=real
    ))

    D_loss /= 2.
  elif args.wavegan_loss == 'lsgan':
    G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
    D_loss = tf.reduce_mean((D_x - 1.) ** 2)
    D_loss += tf.reduce_mean(D_G_z ** 2)
    D_loss /= 2.
  elif args.wavegan_loss == 'wgan':
    G_loss = -tf.reduce_mean(D_G_z)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

    with tf.name_scope('D_clip_weights'):
      clip_ops = []
      for var in D_vars:
        clip_bounds = [-.01, .01]
        clip_ops.append(
          tf.assign(
            var,
            tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
          )
        )
      D_clip_weights = tf.group(*clip_ops)
  elif args.wavegan_loss == 'wgan-gp':
    G_loss = -tf.reduce_mean(D_G_z)
    D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

    alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
    differences = G_z - x
    interpolates = x + (alpha * differences)
    with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
      D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)

    LAMBDA = 10
    gradients = tf.gradients(D_interp, [interpolates])[0]
    print('gradients:', gradients)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
    D_loss += LAMBDA * gradient_penalty
  else:
    raise NotImplementedError()

  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)

  # Create (recommended) optimizer
  if args.wavegan_loss == 'dcgan':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
  elif args.wavegan_loss == 'lsgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  elif args.wavegan_loss == 'wgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
  elif args.wavegan_loss == 'wgan-gp':
    # Decay the learning rate against the global step; tf.get_collection returns
    # a list rather than the step tensor, so use get_or_create_global_step here
    my_learning_rate = tf.train.exponential_decay(1e-4,
                                                  tf.train.get_or_create_global_step(),
                                                  decay_steps=100000,
                                                  decay_rate=0.5)

    G_opt = tf.train.AdamOptimizer(
        learning_rate=my_learning_rate,
        beta1=0.5,
        beta2=0.9)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=my_learning_rate,
        beta1=0.5,
        beta2=0.9)
  else:
    raise NotImplementedError()

  # Create training ops
  G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

  saver = tf.train.Saver(max_to_keep=10)

  global_step = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)

  # Run training
  with tf.train.MonitoredTrainingSession(
      scaffold=tf.train.Scaffold(saver=saver),
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    iterator_count = 0
    while True:
      # Train discriminator
      for i in xrange(args.wavegan_disc_nupdates):
        sess.run(D_train_op)

        # Enforce Lipschitz constraint for WGAN
        if D_clip_weights is not None:
          sess.run(D_clip_weights)

      # Train generator
      _, g_losses, d_losses, global_step_ = sess.run([G_train_op, G_loss, D_loss, global_step])
      print('i:', global_step_[0], 'G_loss:', g_losses, 'D_loss:', d_losses)

      # Snapshot generator weights to disk on the first iteration
      if iterator_count == 0:
        G_var_dict = {}
        G_vars_np = sess.run(G_vars)
        for g_var_name, g_var in zip(G_var_names, G_vars_np):
          G_var_dict[g_var_name] = g_var
        with open('saved_G_vars_iteration-{}.pkl'.format(global_step_[0]), 'wb') as f:
          pickle.dump(G_var_dict, f)
      iterator_count += 1
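
A sketch of restoring the pickled generator weights saved above into a rebuilt 'G' scope. The file name is a placeholder, and the new graph must define variables with identical names:

import pickle
import tensorflow as tf

with open('saved_G_vars_iteration-0.pkl', 'rb') as f:  # placeholder file name
    G_var_dict = pickle.load(f)
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
assign_ops = [tf.assign(v, G_var_dict[v.name]) for v in G_vars if v.name in G_var_dict]
# then, inside an active session: sess.run(assign_ops)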
Example #8
def train(fps, args):
    with tf.name_scope('loader'):
        x = loader.decode_extract_and_batch(
            fps,
            batch_size=args.train_batch_size,
            slice_len=args.data_slice_len,
            decode_fs=args.data_sample_rate,
            decode_num_channels=args.data_num_channels,
            decode_fast_wav=args.data_fast_wav,
            decode_parallel_calls=4,
            slice_randomize_offset=False if args.data_first_slice else True,
            slice_first_only=args.data_first_slice,
            slice_overlap_ratio=0.
            if args.data_first_slice else args.data_overlap_ratio,
            slice_pad_end=True if args.data_first_slice else args.data_pad_end,
            repeat=True,
            shuffle=True,
            shuffle_buffer_size=4096,
            prefetch_size=args.train_batch_size * 4,
            prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]

    # Make z vector
    z = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, args.data_sample_rate)
    tf.summary.audio('G_z', G_z, args.data_sample_rate)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        print('-' * 80)
        print(
            'Training has started. Please use \'tensorboard --logdir={}\' to monitor.'
            .format(args.train_dir))
        counter = 0
        while True:
            # Train discriminator
            for i in xrange(args.wavegan_disc_nupdates):
                sess.run(D_train_op)
                print("Ran D_train_op")

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
                    print("Ran D_clip_weights")

            # Train generator
            sess.run(G_train_op)
            print("Ran G_train_op")
            counter += 1
            print("Iteration: " + str(counter))
        print("DONE TRAINING")
Example #9
def train(fps, args):
    # global train_dataset_size
    # global train_data_percentage
    train_data_percentage = args.train_data_percentage
    with tf.name_scope('loader'):
        # This was actually not necessarily good. However, we can keep it as a point for 115 tfrecords
        # train_fps, _ = loader.split_files_test_val(fps, train_data_percentage, 0)
        # fps = train_fps
        # fps = fps[:gan_train_data_size]

        logging.info("Full training datasize = " +
                     str(find_data_size(fps, None)))
        length = len(fps)
        fps = fps[:(int(train_data_percentage / 100.0 * length))]
        logging.info("GAN training datasize (before exclude) = " +
                     str(find_data_size(fps, None)))

        if args.exclude_class is None:
            # Keep the full set; record its size so the log below works
            train_dataset_size = find_data_size(fps, None)
        elif args.exclude_class != -1:
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info("GAN training datasize (after exclude) = " +
                         str(train_dataset_size))
        elif args.exclude_class == -1:
            fps, _ = loader.split_files_test_val(fps, 0.9, 0)
            train_dataset_size = find_data_size(fps, args.exclude_class)
            logging.info(
                "GAN training datasize (after exclude - random sampling) = " +
                str(train_dataset_size))
        else:
            raise ValueError(
                "args.exclude_class should be either [0, num_class), None, or -1 for random sampling 90%"
            )

        training_iterator = loader.get_batch(fps,
                                             args.train_batch_size,
                                             _WINDOW_LEN,
                                             args.data_first_window,
                                             repeat=True,
                                             initializable=True,
                                             labels=True,
                                             exclude_class=args.exclude_class)
        x, _ = training_iterator.get_next()  # Important: ignore the labels
        print("x_wav.shape = %s" % str(x.shape))

        logging.info("train_dataset_size = " + str(train_dataset_size))
    # Make z vector
    z = tf.random_uniform([args.train_batch_size, _D_Z],
                          -1.,
                          1.,
                          dtype=tf.float32)

    # Make generator
    with tf.variable_scope('G'):
        G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    tf.summary.audio('x', x, _FS)
    tf.summary.audio('G_z', G_z, _FS)
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
        real = tf.ones([args.train_batch_size], dtype=tf.float32)

        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=real))

        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z, labels=fake))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x, labels=real))

        D_loss /= 2.
    elif args.wavegan_loss == 'lsgan':
        G_loss = tf.reduce_mean((D_G_z - 1.)**2)
        D_loss = tf.reduce_mean((D_x - 1.)**2)
        D_loss += tf.reduce_mean(D_G_z**2)
        D_loss /= 2.
    elif args.wavegan_loss == 'wgan':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        G_loss = -tf.reduce_mean(D_G_z)
        D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)

        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        differences = G_z - x
        interpolates = x + (alpha * differences)
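        # x_hat = x + alpha * (G_z - x): random points on straight lines between
        # real and fake samples. WGAN-GP penalizes the critic's gradient norm at
        # these points, pushing it toward 1 (a soft 1-Lipschitz constraint).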
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(interpolates,
                                            **args.wavegan_d_kwargs)

        LAMBDA = 10
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)
        D_loss += LAMBDA * gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=1e-4,
                                       beta1=0.5,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)
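    # Only G_train_op advances the global step, so one "step" below equals one
    # generator update (each preceded by args.wavegan_disc_nupdates D updates).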

    # Run training
    current_step = -1
    scaffold = tf.train.Scaffold(local_init_op=tf.group(
        tf.local_variables_initializer(), training_iterator.initializer),
                                 saver=tf.train.Saver(max_to_keep=5))

    with tf.train.MonitoredTrainingSession(
            hooks=[SaveAtEnd(os.path.join(args.train_dir, 'model'))],
            scaffold=scaffold,
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        while True:
            global_step = sess.run(tf.train.get_or_create_global_step())
            logging.info("Global step: " + str(global_step))

            if args.stop_at_global_step != 0 and global_step >= args.stop_at_global_step:
                logging.info(
                    "Stopping because args.stop_at_global_step is set to  " +
                    str(args.stop_at_global_step))
                break
            # Train discriminator
            # for i in range(args.wavegan_disc_nupdates):
            #   try:
            #     sess.run(D_train_op)
            #     current_step += 1
            #
            #     # Stop training after x% of training data seen
            #     if current_step * args.train_batch_size > math.ceil(train_dataset_size * train_data_percentage / 100.0):
            #       print("Stopping at batch: " + str(current_step))
            #       current_step = -1
            #       sess.run(training_iterator.initializer)
            #
            #   except tf.errors.OutOfRangeError:
            #     # End of training dataset
            #     if train_data_percentage != 100:
            #       print(
            #         "ERROR: end of dataset for only part of data! Achieved end of training dataset with train_data_percentage = " + str(
            #           train_data_percentage))
            #     else:
            #       current_step = -1
            #       sess.run(training_iterator.initializer)

            try:
                for i in range(args.wavegan_disc_nupdates):
                    sess.run(D_train_op)

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights)
            except tf.errors.OutOfRangeError:
                sess.run(training_iterator.initializer)

            # Train generator
            sess.run(G_train_op)
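
For reference, the loss options wired up above correspond to the following objectives (a summary of the code, with $x$ real audio, $G(z)$ generated audio, and $\hat{x}$ the random interpolates used for the gradient penalty):

\begin{align*}
\text{dcgan:}\quad & L_D = \tfrac{1}{2}\big(\mathrm{BCE}(D(x), 1) + \mathrm{BCE}(D(G(z)), 0)\big), & L_G &= \mathrm{BCE}(D(G(z)), 1) \\
\text{lsgan:}\quad & L_D = \tfrac{1}{2}\big(\mathbb{E}[(D(x)-1)^2] + \mathbb{E}[D(G(z))^2]\big), & L_G &= \mathbb{E}[(D(G(z))-1)^2] \\
\text{wgan(-gp):}\quad & L_D = \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda\,\mathbb{E}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big], & L_G &= -\mathbb{E}[D(G(z))]
\end{align*}

(For plain wgan, the penalty term is replaced by clipping the critic weights to $[-0.01, 0.01]$.)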
Example #11
args['model_dir'] = model_dir
# Save generated samples every N epochs.
epochs_per_sample = args['epochs_per_sample']
# gradient penalty regularization factor.
lmbda = args['lmbda']


# Directories
audio_dir = args['audio_dir']
output_dir = args['output_dir']

# =============Network===============
netG = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, upsample=True)
netD = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus)

if cuda:
    netG = torch.nn.DataParallel(netG).cuda()
    netD = torch.nn.DataParallel(netD).cuda()

# "Two time-scale update rule"(TTUR) to update netD 4x faster than netG.
optimizerG = optim.Adam(netG.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2']))
optimizerD = optim.Adam(netD.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2']))

# Sample noise used for generated output.
sample_noise = torch.randn(args['sample_size'], latent_dim)
if cuda:
    sample_noise = sample_noise.cuda()
sample_noise_Var = autograd.Variable(sample_noise, requires_grad=False)  # Variable is a no-op wrapper since PyTorch 0.4
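
The snippet above defines `lmbda` as the gradient-penalty factor, but the penalty itself is computed elsewhere. As a minimal sketch (not this snippet's own implementation, and assuming waveforms shaped (batch, channels, length)), a standard WGAN-GP penalty in PyTorch looks like this:

import torch
from torch import autograd

def gradient_penalty(netD, real, fake, lmbda):
    # Interpolate between real and fake waveforms: x_hat = a*real + (1-a)*fake
    alpha = torch.rand(real.size(0), 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_hat = netD(x_hat)
    # Gradient of the critic score w.r.t. the interpolates
    grads = autograd.grad(outputs=d_hat, inputs=x_hat,
                          grad_outputs=torch.ones_like(d_hat),
                          create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    # Penalize deviation of the gradient norm from 1 (1-Lipschitz target)
    return lmbda * ((grad_norm - 1) ** 2).mean()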
Example #12
class Word2Wave(nn.Module):
    def __init__(self, args):
        super(Word2Wave, self).__init__()
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.coala_model_name = args.coala_model_name
        self.wavegan_path = args.wavegan_path

        self.load_wavegan()
        self.load_coala()
        self.init_latents()

    def load_wavegan(self, slice_len=16384, model_size=32):
        path_to_model = self.wavegan_path  # single-arg os.path.join was a no-op
        self.generator = WaveGANGenerator(slice_len=slice_len,
                                          model_size=model_size,
                                          use_batch_norm=False,
                                          num_channels=1)
        checkpoint = torch.load(path_to_model, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator'])

    def load_coala(self):
        coala_path = os.path.join("coala/models", self.coala_model_name)
        tag_encoder_url = "https://github.com/xavierfav/coala/blob/master/saved_models/{}/tag_encoder_epoch_200.pt".format(
            self.coala_model_name)
        audio_encoder_url = "https://github.com/xavierfav/coala/blob/master/saved_models/{}/audio_encoder_epoch_200.pt".format(
            self.coala_model_name)
        tag_encoder_path = os.path.join(coala_path,
                                        os.path.basename(tag_encoder_url))
        audio_encoder_path = os.path.join(coala_path,
                                          os.path.basename(audio_encoder_url))
        # NOTE: the blob URLs above point at GitHub HTML pages rather than the raw
        # weight files, which is why a plain download yields corrupted checkpoints.
        # Appending '?raw=true' redirects to the raw content (or download manually).
        if not os.path.exists(coala_path):
            os.makedirs(coala_path)
            logging.info("Downloading COALA model weights from {}".format(
                audio_encoder_url))
            urlretrieve(tag_encoder_url + "?raw=true", tag_encoder_path)
            urlretrieve(audio_encoder_url + "?raw=true", audio_encoder_path)

        self.tag_encoder = TagEncoder()
        self.tag_encoder.load_state_dict(torch.load(tag_encoder_path))
        self.tag_encoder.eval()

        self.audio_encoder = AudioEncoder()
        self.audio_encoder.load_state_dict(torch.load(audio_encoder_path))
        self.audio_encoder.eval()

        with open('coala/id2token_top_1000.json') as f:
            id2tag = json.load(f)
        self.tag2id = {tag: id for id, tag in id2tag.items()}

    def init_latents(self, size=1, latent_dim=100):
        noise = torch.FloatTensor(size, latent_dim)
        noise.data.normal_()
        self.latents = torch.nn.Parameter(noise)

    def tokenize_text(self, text_prompt):
        words_not_in_dict = [
            word for word in text_prompt.split(" ")
            if word not in self.tag2id.keys()
        ]
        words_in_dict = [
            word for word in text_prompt.split(" ")
            if word in self.tag2id.keys()
        ]
        tokenized_text = [int(self.tag2id[word]) for word in words_in_dict]
        return tokenized_text, words_in_dict, words_not_in_dict

    def encode_text(self, text_prompt):
        word_ids, _, _ = self.tokenize_text(text_prompt)
        sentence_embedding = torch.zeros(1152).to(self.device)

        tag_vector = torch.zeros(len(word_ids), 1000).to(self.device)
        for index, word in enumerate(word_ids):
            tag_vector[index, word] = 1

        embedding, embedding_d = self.tag_encoder(tag_vector)
        sentence_embedding = embedding_d.mean(dim=0)
        return sentence_embedding

    def encode_audio(self, audio):
        x = preprocess_audio(audio).to(self.device)
        scaler = pickle.load(open('coala/scaler_top_1000.pkl', 'rb'))
        x *= torch.tensor(scaler.scale_).to(self.device)
        x += torch.tensor(scaler.min_).to(self.device)
        x = torch.clamp(x, scaler.feature_range[0], scaler.feature_range[1])
        embedding, embedding_d = self.audio_encoder(
            x.unsqueeze(0).unsqueeze(0))
        return embedding_d

    def latent_space_interpolation(self, latents=None, n_samples=1):
        if latents is None:
            z_test = sample_noise(2)
        else:
            z_test = latents
        interpolates = []
        for alpha in np.linspace(0, 1, n_samples):
            interpolate_vec = alpha * z_test[0] + ((1 - alpha) * z_test[1])
            interpolates.append(interpolate_vec)
        interpolates = torch.stack(interpolates)
        generated_audio = self.generator(interpolates)
        return generated_audio

    def synthesise_audio(self, noise):
        generated_audio = self.generator(noise).view(-1)
        return generated_audio

    def coala_loss(self, audio, text):
        text_embedding = self.encode_text(text)
        audio_embedding = self.encode_audio(audio)

        text_embedding = text_embedding / text_embedding.norm()
        audio_embedding = audio_embedding / audio_embedding.norm()

        cos_dist = (1 - audio_embedding @ text_embedding.t()) / 2
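        # With unit-norm embeddings, audio @ text is the cosine similarity in
        # [-1, 1]; (1 - sim) / 2 maps it to a distance in [0, 1].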

        return cos_dist

    def forward(self, text):
        audio = self.generator(self.latents).view(-1)
        loss = self.coala_loss(audio, text)
        return audio, loss
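
Word2Wave exposes `latents` as its only trainable parameter, so text-to-audio search reduces to gradient descent on the latent vector. A minimal usage sketch (the prompt, learning rate, and step count here are illustrative assumptions, not values from the snippet):

model = Word2Wave(args)
model.to(model.device)
optimizer = torch.optim.Adam([model.latents], lr=0.04)

for step in range(500):
    audio, loss = model("dog bark")  # loss = COALA cosine distance
    optimizer.zero_grad()
    loss.backward()                  # only model.latents is updated below
    optimizer.step()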
Example #13
def train(fps, args):
    with tf.name_scope('loader'):
        x, cond_text, _ = loader.get_batch(fps,
                                           args.train_batch_size,
                                           _WINDOW_LEN,
                                           args.data_first_window,
                                           conditionals=True,
                                           name='batch')
        wrong_audio = loader.get_batch(fps,
                                       args.train_batch_size,
                                       _WINDOW_LEN,
                                       args.data_first_window,
                                       conditionals=False,
                                       name='wrong_batch')
    # wrong_cond_text, wrong_cond_text_embed = loader.get_batch(fps, args.train_batch_size, _WINDOW_LEN, args.data_first_window, wavs=False, conditionals=True, name='batch')

    # Make z vector
    z = tf.random_normal([args.train_batch_size, _D_Z])

    embed = hub.Module('https://tfhub.dev/google/elmo/2',
                       trainable=False,
                       name='embed')
    cond_text_embed = embed(cond_text)
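    # cond_text_embed: one fixed-size ELMo sentence embedding per caption
    # (the hub module's default pooled output, 1024-d).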

    # Add conditioning input to the model
    args.wavegan_g_kwargs['context_embedding'] = cond_text_embed
    args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs[
        'context_embedding']

    lod = tf.placeholder(tf.float32, shape=[])

    with tf.variable_scope('G'):
        # Make generator
        G_z, c_kl_loss = WaveGANGenerator(z,
                                          lod,
                                          train=True,
                                          **args.wavegan_g_kwargs)
        if args.wavegan_genr_pp:
            with tf.variable_scope('pp_filt'):
                G_z = tf.layers.conv1d(G_z,
                                       1,
                                       args.wavegan_genr_pp_len,
                                       use_bias=False,
                                       padding='same')

    # Summarize
    G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
    x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
    x_rms_lod_4 = tf.sqrt(
        tf.reduce_mean(tf.square(avg_downsample(x)[:, :, 0]), axis=1))
    x_rms_lod_3 = tf.sqrt(
        tf.reduce_mean(tf.square(avg_downsample(avg_downsample(x))[:, :, 0]),
                       axis=1))
    x_rms_lod_2 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(avg_downsample(avg_downsample(x)))[:, :, 0]),
                       axis=1))
    x_rms_lod_1 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(avg_downsample(avg_downsample(
                avg_downsample(x))))[:, :, 0]),
                       axis=1))
    x_rms_lod_0 = tf.sqrt(
        tf.reduce_mean(tf.square(
            avg_downsample(
                avg_downsample(
                    avg_downsample(avg_downsample(avg_downsample(x)))))[:, :,
                                                                        0]),
                       axis=1))
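    # x_rms_lod_k: RMS of x after (5 - k) successive average-downsampling passes,
    # i.e. at each resolution the progressively grown generator passes through.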
    tf.summary.histogram('x_rms_batch', x_rms)
    tf.summary.histogram('G_z_rms_batch', G_z_rms)
    tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
    tf.summary.scalar('x_rms_lod_4', tf.reduce_mean(x_rms_lod_4))
    tf.summary.scalar('x_rms_lod_3', tf.reduce_mean(x_rms_lod_3))
    tf.summary.scalar('x_rms_lod_2', tf.reduce_mean(x_rms_lod_2))
    tf.summary.scalar('x_rms_lod_1', tf.reduce_mean(x_rms_lod_1))
    tf.summary.scalar('x_rms_lod_0', tf.reduce_mean(x_rms_lod_0))
    tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
    tf.summary.audio('x', x, _FS, max_outputs=10)
    tf.summary.audio('G_z', G_z, _FS, max_outputs=10)
    tf.summary.text('Conditioning Text', cond_text[:10])

    # with tf.variable_scope('G'):
    #   # Make history buffer
    #   history_buffer = HistoryBuffer(_WINDOW_LEN, args.train_batch_size * 100, args.train_batch_size)

    #   # Select half of batch from history buffer
    #   g_from_history, r_from_history, embeds_from_history = history_buffer.get_from_history_buffer()
    #   new_fake_batch = tf.concat([G_z[:tf.shape(G_z)[0] - tf.shape(g_from_history)[0]], g_from_history], 0) # Use tf.shape to handle case when g_from_history is empty
    #   new_cond_embeds = tf.concat([cond_text_embed[:tf.shape(cond_text_embed)[0] - tf.shape(embeds_from_history)[0]], embeds_from_history], 0)
    #   new_real_batch = tf.concat([x[:tf.shape(x)[0] - tf.shape(r_from_history)[0]], r_from_history], 0)
    #   with tf.control_dependencies([new_fake_batch, new_real_batch, new_cond_embeds]):
    #     with tf.control_dependencies([history_buffer.add_to_history_buffer(G_z, x, cond_text_embed)]):
    #       G_z = tf.identity(new_fake_batch)
    #       x = tf.identity(new_real_batch)
    #       args.wavegan_g_kwargs['context_embedding'] = tf.identity(new_cond_embeds)
    #       args.wavegan_d_kwargs['context_embedding'] = args.wavegan_g_kwargs['context_embedding']
    #   G_z.set_shape([args.train_batch_size, _WINDOW_LEN, 1])
    #   x.set_shape([args.train_batch_size, _WINDOW_LEN, 1])

    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')

    # Print G summary
    print('-' * 80)
    print('Generator vars')
    nparams = 0
    for v in G_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))

    # Summarize
    # tf.summary.scalar('history_buffer_size', history_buffer.current_size)
    # tf.summary.scalar('g_from_history_size', tf.shape(g_from_history)[0])
    # tf.summary.scalar('r_from_history_size', tf.shape(r_from_history)[0])
    # tf.summary.scalar('embeds_from_history_size', tf.shape(embeds_from_history)[0])
    # tf.summary.audio('G_z_history', g_from_history, _FS, max_outputs=10)
    # tf.summary.audio('x_history', r_from_history, _FS, max_outputs=10)
    tf.summary.audio('wrong_audio', wrong_audio, _FS, max_outputs=10)
    tf.summary.scalar('Conditional Resample - KL-Loss', c_kl_loss)
    # tf.summary.scalar('embed_error_cosine', tf.reduce_sum(tf.multiply(cond_text_embed, expected_embed)) / (tf.norm(cond_text_embed) * tf.norm(expected_embed)))
    # tf.summary.scalar('embed_error_cosine_wrong', tf.reduce_sum(tf.multiply(wrong_cond_text_embed, expected_embed)) / (tf.norm(wrong_cond_text_embed) * tf.norm(expected_embed)))

    # Make real discriminator
    with tf.name_scope('D_x'), tf.variable_scope('D'):
        D_x = WaveGANDiscriminator(x, lod, **args.wavegan_d_kwargs)
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')

    # Print D summary
    print('-' * 80)
    print('Discriminator vars')
    nparams = 0
    for v in D_vars:
        v_shape = v.get_shape().as_list()
        v_n = reduce(lambda x, y: x * y, v_shape)
        nparams += v_n
        print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) /
                                                (1024 * 1024)))
    print('-' * 80)

    # Make fake / wrong discriminator
    with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
        D_G_z = WaveGANDiscriminator(G_z, lod, **args.wavegan_d_kwargs)
    with tf.name_scope('D_w'), tf.variable_scope('D', reuse=True):
        D_w = WaveGANDiscriminator(wrong_audio, lod, **args.wavegan_d_kwargs)
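    # Matching-aware discriminator: score real audio with matching text (D_x),
    # generated audio with its text (D_G_z), and real audio with mismatched text
    # (D_w), so D must judge both realism and audio/text consistency.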

    # Create loss
    D_clip_weights = None
    if args.wavegan_loss == 'dcgan':
        fake = tf.zeros([args.train_batch_size, 1], dtype=tf.float32)
        real = tf.ones([args.train_batch_size, 1], dtype=tf.float32)

        # Conditional G Loss
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0],
                                                    labels=real))
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1],
                                                        labels=real))
            G_loss /= 2

        # Conditional D Losses
        D_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[0],
                                                    labels=fake))
        D_loss_wrong = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[0],
                                                    labels=fake))
        D_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[0],
                                                    labels=real))

        # Unconditional D Losses
        if args.use_extra_uncond_loss:
            D_loss_fake_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_G_z[1],
                                                        labels=fake))
            D_loss_wrong_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_w[1],
                                                        labels=real))
            D_loss_real_uncond = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_x[1],
                                                        labels=real))

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong
    elif args.wavegan_loss == 'lsgan':
        # Conditional G Loss
        G_loss = tf.reduce_mean((D_G_z[0] - 1.)**2)
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += tf.reduce_mean((D_G_z[1] - 1.)**2)
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = tf.reduce_mean((D_x[0] - 1.)**2)
        D_loss_wrong = tf.reduce_mean(D_w[0]**2)
        D_loss_fake = tf.reduce_mean(D_G_z[0]**2)

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = tf.reduce_mean((D_x[1] - 1.)**2)
            D_loss_wrong_uncond = tf.reduce_mean((D_w[1] - 1.)**2)
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1]**2)

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong
    elif args.wavegan_loss == 'wgan':
        # Conditional G Loss
        G_loss = -tf.reduce_mean(D_G_z[0])
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += -tf.reduce_mean(D_G_z[1])
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = -tf.reduce_mean(D_x[0])
        D_loss_wrong = tf.reduce_mean(D_w[0])
        D_loss_fake = tf.reduce_mean(D_G_z[0])

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = -tf.reduce_mean(D_x[1])
            D_loss_wrong_uncond = -tf.reduce_mean(D_w[1])
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1])

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong

        with tf.name_scope('D_clip_weights'):
            clip_ops = []
            for var in D_vars:
                clip_bounds = [-.01, .01]
                clip_ops.append(
                    tf.assign(
                        var,
                        tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
            D_clip_weights = tf.group(*clip_ops)
    elif args.wavegan_loss == 'wgan-gp':
        # Conditional G Loss
        G_loss = -tf.reduce_mean(D_G_z[0])
        G_loss += c_kl_loss

        # Unconditional G Loss
        if args.use_extra_uncond_loss:
            G_loss += -tf.reduce_mean(D_G_z[1])
            G_loss /= 2

        # Conditional D Loss
        D_loss_real = -tf.reduce_mean(D_x[0])
        D_loss_wrong = tf.reduce_mean(D_w[0])
        D_loss_fake = tf.reduce_mean(D_G_z[0])

        # Unconditional D Loss
        if args.use_extra_uncond_loss:
            D_loss_real_uncond = -tf.reduce_mean(D_x[1])
            D_loss_wrong_uncond = -tf.reduce_mean(D_w[1])
            D_loss_fake_uncond = tf.reduce_mean(D_G_z[1])

            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                   + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond
            D_loss /= 2
        else:
            D_loss = D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)

        # Warmup Conditional Loss
        # D_warmup_loss = D_loss_real + D_loss_wrong

        # Conditional Gradient Penalty
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        real = x
        fake = tf.concat([
            G_z[:args.train_batch_size // 2],
            wrong_audio[:args.train_batch_size // 2]
        ], 0)
        differences = fake - real
        interpolates = real + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(
                interpolates, lod,
                **args.wavegan_d_kwargs)[0]  # Only want conditional output
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        cond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)

        # Unconditional Gradient Penalty
        alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1],
                                  minval=0.,
                                  maxval=1.)
        real = tf.concat([
            x[:args.train_batch_size // 2],
            wrong_audio[:args.train_batch_size // 2]
        ], 0)
        fake = G_z
        differences = fake - real
        interpolates = real + (alpha * differences)
        with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
            D_interp = WaveGANDiscriminator(
                interpolates, lod,
                **args.wavegan_d_kwargs)[1]  # Only want unconditional output
        gradients = tf.gradients(D_interp, [interpolates])[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        uncond_gradient_penalty = tf.reduce_mean((slopes - 1.)**2.)

        # Warmup Gradient Penalty
        # alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
        # real = x
        # fake = wrong_audio
        # differences = fake - real
        # interpolates = real + (alpha * differences)
        # with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
        #   D_interp = WaveGANDiscriminator(interpolates, lod, **args.wavegan_d_kwargs)[0] # Only want conditional output
        # gradients = tf.gradients(D_interp, [interpolates])[0]
        # slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
        # warmup_gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)

        gradient_penalty = (cond_gradient_penalty +
                            uncond_gradient_penalty) / 2

        LAMBDA = 10
        D_loss += LAMBDA * gradient_penalty
        # D_warmup_loss += LAMBDA * warmup_gradient_penalty
    else:
        raise NotImplementedError()

    tf.summary.scalar('G_loss', G_loss)
    if (args.wavegan_loss == 'wgan-gp'):
        tf.summary.scalar('Gradient Penalty', LAMBDA * gradient_penalty)
    if (args.wavegan_loss == 'wgan' or args.wavegan_loss == 'wgan-gp'):
        if args.use_extra_uncond_loss:
            tf.summary.scalar('Critic Score - Real Data - Condition Match',
                              -D_loss_real)
            tf.summary.scalar('Critic Score - Fake Data - Condition Match',
                              D_loss_fake)
            tf.summary.scalar('Critic Score - Wrong Data - Condition Match',
                              D_loss_wrong)
            tf.summary.scalar('Critic Score - Real Data', -D_loss_real_uncond)
            tf.summary.scalar('Critic Score - Wrong Data',
                              -D_loss_wrong_uncond)
            tf.summary.scalar('Critic Score - Fake Data', D_loss_fake_uncond)
            tf.summary.scalar('Wasserstein Distance - No Regularization Term',
                              -((D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                               + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2))
            tf.summary.scalar('Wasserstein Distance - Real-Wrong Only',
                              -(D_loss_real + D_loss_wrong))
            tf.summary.scalar('Wasserstein Distance - Real-Fake Only',
                              -((D_loss_real + D_loss_fake \
                               + D_loss_real_uncond + D_loss_fake_uncond) / 2))
        else:
            tf.summary.scalar('Critic Score - Real Data', -D_loss_real)
            tf.summary.scalar('Critic Score - Wrong Data', D_loss_wrong)
            tf.summary.scalar('Critic Score - Fake Data', D_loss_fake)
            tf.summary.scalar(
                'Wasserstein Distance - No Regularization Term',
                -(D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake)))
        tf.summary.scalar('Wasserstein Distance - With Regularization Term',
                          -D_loss)
    else:
        if args.use_extra_uncond_loss:
            tf.summary.scalar('D_acc_uncond', 0.5 * ((0.5 * (tf.reduce_mean(tf.sigmoid(D_x[1])) + tf.reduce_mean(tf.sigmoid(D_w[1])))) \
                                                   + tf.reduce_mean(1 - tf.sigmoid(D_G_z[1]))))
            tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                            + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0])))))
            tf.summary.scalar('D_acc_real_wrong_only', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                                            + tf.reduce_mean(1 - tf.sigmoid(D_w[0]))))
            tf.summary.scalar('D_loss_cond_real', D_loss_real)
            tf.summary.scalar('D_loss_uncond_real', D_loss_real_uncond)
            tf.summary.scalar('D_loss_cond_wrong', D_loss_wrong)
            tf.summary.scalar('D_loss_uncond_wrong', D_loss_wrong_uncond)
            tf.summary.scalar('D_loss_cond_fake', D_loss_fake)
            tf.summary.scalar('D_loss_uncond_fake', D_loss_fake_uncond)
            tf.summary.scalar('D_loss_unregularized',
                               (D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake) \
                              + 0.5 * (D_loss_real_uncond + D_loss_wrong_uncond) + D_loss_fake_uncond) / 2)
        else:
            tf.summary.scalar('D_acc', 0.5 * (tf.reduce_mean(tf.sigmoid(D_x[0])) \
                                            + 0.5 * (tf.reduce_mean(1 - tf.sigmoid(D_w[0])) + tf.reduce_mean(1 - tf.sigmoid(D_G_z[0])))))
            tf.summary.scalar('D_loss_real', D_loss_real)
            tf.summary.scalar('D_loss_wrong', D_loss_wrong)
            tf.summary.scalar('D_loss_fake', D_loss_fake)
            tf.summary.scalar('D_loss_unregularized',
                              D_loss_real + 0.5 * (D_loss_wrong + D_loss_fake))
        tf.summary.scalar('D_loss', D_loss)

    # Create (recommended) optimizer
    if args.wavegan_loss == 'dcgan':
        G_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
        D_opt = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
    elif args.wavegan_loss == 'lsgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=1e-4)
    elif args.wavegan_loss == 'wgan':
        G_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        D_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
    elif args.wavegan_loss == 'wgan-gp':
        G_opt = tf.train.AdamOptimizer(learning_rate=4e-4,
                                       beta1=0.0,
                                       beta2=0.9)
        D_opt = tf.train.AdamOptimizer(learning_rate=4e-4,
                                       beta1=0.0,
                                       beta2=0.9)
    else:
        raise NotImplementedError()

    # Optimizer internal state reset ops
    reset_G_opt_op = tf.variables_initializer(G_opt.variables())
    reset_D_opt_op = tf.variables_initializer(D_opt.variables())

    # Create training ops
    G_train_op = G_opt.minimize(
        G_loss,
        var_list=G_vars,
        global_step=tf.train.get_or_create_global_step())
    D_train_op = D_opt.minimize(D_loss, var_list=D_vars)

    def smoothstep(x, mi, mx):
        # Cubic smoothstep 3t^2 - 2t^3 on [0, 1], rescaled to [mi, mx].
        t = np.where(x < 0, 0, np.where(x <= 1, 3 * x**2 - 2 * x**3, 1))
        return mi + (mx - mi) * t

    def np_lerp_clip(t, a, b):
        return a + (b - a) * np.clip(t, 0.0, 1.0)

    def get_lod_at_step(step):
        return np.piecewise(float(step), [
            step < 10000, 10000 <= step < 20000, 20000 <= step < 30000,
            30000 <= step < 40000, 40000 <= step < 50000,
            50000 <= step < 60000, 60000 <= step < 70000,
            70000 <= step < 80000, 80000 <= step < 90000,
            90000 <= step < 100000
        ], [
            0, lambda x: np_lerp_clip((x - 10000) / 10000, 0, 1), 1,
            lambda x: np_lerp_clip(
                (x - 30000) / 10000, 1, 2), 2, lambda x: np_lerp_clip(
                    (x - 50000) / 10000, 2, 3), 3, lambda x: np_lerp_clip(
                        (x - 70000) / 10000, 3, 4), 4, lambda x: np_lerp_clip(
                            (x - 90000) / 10000, 4, 5), 5
        ])
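    # Spot checks of the schedule above (derived from the piecewise definition):
    #   get_lod_at_step(5000)   -> 0.0   (lowest resolution)
    #   get_lod_at_step(15000)  -> 0.5   (fading in the next band)
    #   get_lod_at_step(25000)  -> 1.0
    #   get_lod_at_step(95000)  -> 4.5
    #   get_lod_at_step(100000) -> 5.0   (full resolution; clamps at 5 afterwards)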

    def my_filter_callable(datum, tensor):
        if (not isinstance(tensor, debug_data.InconvertibleTensorProto)) and (
                tensor.dtype == np.float32 or tensor.dtype == np.float64):
            return np.any([
                np.any(np.greater_equal(tensor, 50.0)),
                np.any(np.less_equal(tensor, -50.0))
            ])
        else:
            return False

    # Create a LocalCLIDebugHook and use it as a monitor
    # debug_hook = tf_debug.LocalCLIDebugHook(dump_root='C:/d/t/')
    # debug_hook.add_tensor_filter('large_values', my_filter_callable)
    # hooks = [debug_hook]

    # Run training
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=args.train_dir,
            save_checkpoint_secs=args.train_save_secs,
            save_summaries_secs=args.train_summary_secs) as sess:
        # Get the summary writer for writing extra summary statistics
        summary_writer = SummaryWriterCache.get(args.train_dir)

        cur_lod = 0
        while True:
            # Calculate Maximum LOD to train
            step = sess.run(tf.train.get_or_create_global_step(),
                            feed_dict={lod: cur_lod})
            cur_lod = get_lod_at_step(step)
            prev_lod = get_lod_at_step(step - 1)

            # Reset optimizer internal state when new layers are introduced
            if np.floor(cur_lod) != np.floor(prev_lod) or np.ceil(
                    cur_lod) != np.ceil(prev_lod):
                print(
                    "Resetting optimizers' internal states at step {}".format(
                        step))
                sess.run([reset_G_opt_op, reset_D_opt_op],
                         feed_dict={lod: cur_lod})

            # Output current LOD and 'steps at current LOD' to tensorboard
            step = float(
                sess.run(tf.train.get_or_create_global_step(),
                         feed_dict={lod: cur_lod}))
            lod_summary = tf.Summary(value=[
                tf.Summary.Value(tag="current_lod",
                                 simple_value=float(cur_lod)),
            ])
            summary_writer.add_summary(lod_summary, step)

            # Train discriminator
            for i in range(args.wavegan_disc_nupdates):
                sess.run(D_train_op, feed_dict={lod: cur_lod})

                # Enforce Lipschitz constraint for WGAN
                if D_clip_weights is not None:
                    sess.run(D_clip_weights, feed_dict={lod: cur_lod})

            # Train generator
            sess.run(G_train_op, feed_dict={lod: cur_lod})
Example #14
        os.makedirs(model_dir)

    LOGGER.info('Saving configurations...')
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(args, f)

    # Try on some training data
    LOGGER.info('Loading audio data...')
    audio_filepaths = get_all_audio_filepaths(args['audio_dir'])
    train_gen, valid_data, test_data \
        = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'],
                            batch_size, batch_size, batch_size)

    LOGGER.info('Creating models...')
    model_gen = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim,
                                 post_proc_filt_len=args['post_proc_filt_len'], upsample=True, num_class=10)
    model_dis = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus,
                                     alpha=args['alpha'], shift_factor=args['shift_factor'],
                                     batch_shuffle=args['batch_shuffle'], num_class=10)

    LOGGER.info('Starting training...')
    model_gen, model_dis, history, final_discr_metrics, samples = train_wgan(
        model_gen=model_gen,
        model_dis=model_dis,
        train_gen=train_gen,
        valid_data=valid_data,
        test_data=test_data,
        num_epochs=args['num_epochs'],
        batches_per_epoch=args['batches_per_epoch'],
        batch_size=batch_size,
        output_dir=model_dir,
Example #15
        os.makedirs(model_dir)

    LOGGER.info('Saving configurations...')
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(args, f)

    # Try on some training data
    LOGGER.info('Loading audio data...')
    audio_filepaths = get_all_audio_filepaths(args['audio_dir'])
    train_gen, valid_data, test_data \
        = create_data_split(audio_filepaths, args['valid_ratio'], args['test_ratio'],
                            batch_size, batch_size, batch_size)

    LOGGER.info('Creating models...')
    model_gen = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim,
                                 post_proc_filt_len=args['post_proc_filt_len'])
    model_dis = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus,
                                     alpha=args['alpha'])

    LOGGER.info('Starting training...')
    model_gen, model_dis, history, final_discr_metrics, samples = train_wgan(
        model_gen=model_gen,
        model_dis=model_dis,
        train_gen=train_gen,
        valid_data=valid_data,
        test_data=test_data,
        num_epochs=args['num_epochs'],
        batches_per_epoch=args['batches_per_epoch'],
        batch_size=batch_size,
        output_dir=model_dir,
        lmbda=args['lmbda'],