# NOTE: these snippets assume TensorFlow 1.x (`import tensorflow as tf`),
# NumPy (`import numpy as np`), and the repo's WaveNetModel, AudioReader
# and variable helpers (create_bias_variable, ...) are in scope.
def __init__(self, hp, max_to_keep=5):
    self.hp = hp
    # Build n_stacks repetitions of dilations 1, 2, 4, ..., 2**(factor - 1).
    dilations_factor = hp.layers // hp.stacks
    dilations = [2**i for j in range(hp.stacks)
                 for i in range(dilations_factor)]
    self.upsample_factor = hp.upsample_factor
    self.gc_enable = hp.gc_enable
    global_condition_channels = None
    global_condition_cardinality = None
    if hp.gc_enable:
        global_condition_channels = hp.global_channel
        global_condition_cardinality = hp.global_cardinality
    # Raw (scalar) input bypasses mu-law quantization.
    scalar_input = hp.input_type == "raw"
    quantization_channels = hp.quantize_channels[hp.input_type]
    if scalar_input:
        quantization_channels = None
    with tf.variable_scope('vocoder'):
        self.net = WaveNetModel(
            batch_size=hp.batch_size,
            dilations=dilations,
            filter_width=hp.filter_width,
            scalar_input=scalar_input,
            initial_filter_width=hp.initial_filter_width,
            residual_channels=hp.residual_channels,
            dilation_channels=hp.dilation_channels,
            quantization_channels=quantization_channels,
            out_channels=hp.out_channels,
            skip_channels=hp.skip_channels,
            global_condition_channels=global_condition_channels,
            global_condition_cardinality=global_condition_cardinality,
            use_biases=True,
            local_condition_channels=hp.n_mel_bins)
    if hp.upsample_conditional_features:
        with tf.variable_scope('upsample_layer') as upsample_scope:
            layer = dict()
            for i in range(len(hp.upsample_factor)):
                shape = [hp.upsample_factor[i], hp.filter_width, 1, 1]
                # Initialize each transposed-conv filter to an averaging
                # kernel so upsampling starts as linear interpolation.
                weights = np.ones(shape) / float(hp.upsample_factor[i])
                init = tf.constant_initializer(value=weights,
                                               dtype=tf.float32)
                variable = tf.get_variable(name='upsample{}'.format(i),
                                           initializer=init,
                                           shape=weights.shape)
                layer['upsample{}_filter'.format(i)] = variable
                layer['upsample{}_bias'.format(i)] = create_bias_variable(
                    'upsample{}_bias'.format(i), [1])
            self.upsample_var = layer
            self.upsample_scope = upsample_scope
    self.saver = tf.train.Saver(var_list=tf.trainable_variables(),
                                max_to_keep=max_to_keep)
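# Quick sanity check of the dilation schedule computed above. The values
# here are illustrative, not taken from any particular hparams file: with
# layers=20 and stacks=2, each stack repeats dilations 1, 2, ..., 512.
def _example_dilations(layers=20, stacks=2):
    factor = layers // stacks
    return [2**i for _ in range(stacks) for i in range(factor)]

assert _example_dilations() == [2**i for i in range(10)] * 2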
def _create_audio_reader(self):
    # Calculate the receptive field so the reader pads each example with
    # enough past context for the first predicted sample.
    receptive_field = WaveNetModel.calculate_receptive_field(
        self.wavenet_params["filter_width"],
        self.wavenet_params["dilations"],
        self.wavenet_params["scalar_input"],
        self.wavenet_params["initial_filter_width"])
    return AudioReader(self.args.audio_dir,
                       self.coord,
                       self.args.sample_rate,
                       self.args.gc_enabled,
                       receptive_field,
                       sample_size=self.args.sample_size,
                       silence_threshold=self.args.silence_threshold,
                       queue_size=32)
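# Worked example of the receptive-field arithmetic (illustrative numbers,
# mirroring the formula used in VQVAE.__init__ later in this file): with
# filter_width=2 and two stacks of dilations 1..512,
# (2 - 1) * 2046 + 1 + (2 - 1) = 2048 samples of context.
dilations = [2**i for _ in range(2) for i in range(10)]
receptive_field = (2 - 1) * sum(dilations) + 1 + (2 - 1)
assert receptive_field == 2048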
def create_wavenet(args, wavenet_params):
    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"])
    # A zero L2 strength means "regularization disabled".
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    return net
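# Hypothetical usage sketch for create_wavenet: the parameters normally
# come from a JSON file and argparse. The file name and argument values
# below are illustrative assumptions, not part of this repo's CLI.
import argparse
import json

args = argparse.Namespace(batch_size=1, l2_regularization_strength=0)
with open('wavenet_params.json') as f:  # file name assumed
    wavenet_params = json.load(f)
net = create_wavenet(args, wavenet_params)
# After the call, args.l2_regularization_strength has been normalized
# from 0 to None, ready to pass on to the loss op.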
class Vocoder(object):
    def __init__(self, max_to_keep=5):
        dilations_factor = hparams.layers // hparams.stacks
        dilations = [2**i for j in range(hparams.stacks)
                     for i in range(dilations_factor)]
        self.upsample_factor = hparams.upsample_factor
        global_condition_channels = None
        global_condition_cardinality = None
        if hparams.gc_enable:
            global_condition_channels = hparams.global_channel
            global_condition_cardinality = hparams.global_cardinality
        scalar_input = hparams.input_type == "raw"
        quantization_channels = hparams.quantize_channels[hparams.input_type]
        if scalar_input:
            quantization_channels = None
        with tf.variable_scope('vocoder'):
            self.net = WaveNetModel(
                batch_size=hparams.batch_size,
                dilations=dilations,
                filter_width=hparams.filter_width,
                scalar_input=scalar_input,
                initial_filter_width=hparams.initial_filter_width,
                residual_channels=hparams.residual_channels,
                dilation_channels=hparams.dilation_channels,
                quantization_channels=quantization_channels,
                out_channels=hparams.out_channels,
                skip_channels=hparams.skip_channels,
                global_condition_channels=global_condition_channels,
                global_condition_cardinality=global_condition_cardinality,
                use_biases=True,
                local_condition_channels=hparams.num_mels)
        if hparams.upsample_conditional_features:
            with tf.variable_scope('upsample_layer') as upsample_scope:
                layer = dict()
                for i in range(len(hparams.upsample_factor)):
                    shape = [hparams.upsample_factor[i],
                             hparams.filter_width, 1, 1]
                    # Averaging-kernel initialization for each upsampling
                    # filter.
                    weights = np.ones(shape) / float(
                        hparams.upsample_factor[i])
                    init = tf.constant_initializer(value=weights,
                                                   dtype=tf.float32)
                    variable = tf.get_variable(name='upsample{}'.format(i),
                                               initializer=init,
                                               shape=weights.shape)
                    layer['upsample{}_filter'.format(i)] = variable
                    layer['upsample{}_bias'.format(i)] = create_bias_variable(
                        'upsample{}_bias'.format(i), [1])
                self.upsample_var = layer
                self.upsample_scope = upsample_scope
        self.saver = tf.train.Saver(var_list=tf.trainable_variables(),
                                    max_to_keep=max_to_keep)

    def create_upsample(self, l):
        layer_filter = self.upsample_var
        # Treat the local condition batch as an NHWC image: [N, T, n_mels, 1].
        local_condition_batch = tf.expand_dims(l, [3])
        batch_size = tf.shape(local_condition_batch)[0]
        upsample_dim = tf.shape(local_condition_batch)[1]
        for i in range(len(self.upsample_factor)):
            upsample_dim = upsample_dim * self.upsample_factor[i]
            output_shape = tf.stack([batch_size, upsample_dim,
                                     tf.shape(local_condition_batch)[2], 1])
            local_condition_batch = tf.nn.conv2d_transpose(
                local_condition_batch,
                layer_filter['upsample{}_filter'.format(i)],
                strides=[1, self.upsample_factor[i], 1, 1],
                output_shape=output_shape)
            local_condition_batch += layer_filter['upsample{}_bias'.format(i)]
            local_condition_batch = tf.nn.relu(local_condition_batch)
        local_condition_batch = tf.squeeze(local_condition_batch, [3])
        return local_condition_batch

    def loss(self, x, l, g):
        self.upsampled_lc = self.create_upsample(l)
        loss = self.net.loss(
            x, self.upsampled_lc, g,
            l2_regularization_strength=hparams.l2_regularization_strength)
        return loss

    def save(self, sess, logdir, step):
        model_name = 'model.ckpt'
        checkpoint_path = os.path.join(logdir, model_name)
        print('Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        self.saver.save(sess, checkpoint_path, global_step=step)
        print(' Done.')

    def load(self, sess, logdir):
        print("Trying to restore saved checkpoints from {} ...".format(logdir),
              end="")
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt:
            print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))
            # The global step is the suffix of the checkpoint file name.
            global_step = int(ckpt.model_checkpoint_path
                              .split('/')[-1]
                              .split('-')[-1])
            print(" Global step was: {}".format(global_step))
            print(" Restoring...", end="")
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            print(" Done.")
            return global_step, sess
        else:
            print(" No checkpoint found.")
            return None, sess

    def init_synthesizer(self, batch_size, gc_enable=True):
        self.batch_size = batch_size
        if self.net.scalar_input:
            self.sample_placeholder = tf.placeholder(tf.float32)
        else:
            self.sample_placeholder = tf.placeholder(tf.int32)
        self.lc_placeholder = tf.placeholder(tf.float32)
        self.gc_placeholder = tf.placeholder(tf.int32) if gc_enable else None
        self.gen_num = tf.placeholder(tf.int32)
        self.next_sample_prob, self.layers_out, self.qs = \
            self.net.predict_proba_incremental(
                self.sample_placeholder,
                self.gen_num,
                batch_size=batch_size,
                local_condition=self.lc_placeholder,
                global_condition=self.gc_placeholder)
        self.initial = tf.placeholder(tf.float32)
        self.others = tf.placeholder(tf.float32)
        self.update_q_ops = self.net.create_update_q_ops(
            self.qs,
            self.initial,
            self.others,
            self.gen_num,
            batch_size=batch_size)
        self.var_q = self.net.get_vars_q()

    def synthesize(self, sess, n_samples, lc, gc):
        sess.run(tf.variables_initializer(self.var_q))
        # Seed generation with silence: 0.0 for raw input, the mu-law
        # midpoint (128) otherwise.
        if self.net.scalar_input:
            seeds = [0]
        else:
            seeds = [128]
        seeds = [seeds]
        seeds = np.repeat(seeds, self.batch_size, axis=0)
        generated = [seeds]
        if isinstance(n_samples, list):
            n_sample = max(n_samples)
        else:
            n_sample = n_samples
        for j in tqdm(range(n_sample)):
            sample = generated[-1]
            current_lc = lc[:, j, :]
            # Generation phase
            feed_dict = {
                self.sample_placeholder: sample,
                self.lc_placeholder: current_lc,
                self.gen_num: j
            }
            if self.gc_placeholder is not None:
                feed_dict.update({self.gc_placeholder: gc})
            prob, _layers = sess.run([self.next_sample_prob, self.layers_out],
                                     feed_dict=feed_dict)
            # Update phase
            feed_dict = {
                self.initial: _layers[0],
                self.others: np.array(_layers[1:]),
                self.gen_num: j
            }
            sess.run(self.update_q_ops, feed_dict=feed_dict)
            if self.net.scalar_input:
                generated_sample = prob
            else:
                # TODO: random choice
                generated_sample = np.argmax(prob, axis=-1)
            generated.append(generated_sample)
        result = np.hstack(generated)
        if not self.net.scalar_input:
            result = P.inv_mulaw_quantize(result.astype(np.int16),
                                          self.net.quantization_channels)
        # Trim each batch entry to its requested length.
        if isinstance(n_samples, list):
            result = [x[:n_samples[i]] for i, x in enumerate(result)]
        return result
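# A minimal, hypothetical driver for the Vocoder class above: restore a
# checkpoint, then synthesize from pre-computed conditioning features.
# `logdir`, `lc` (shape [batch, n_samples, num_mels], already upsampled
# to one frame per output sample) and `gc` are assumptions.
vocoder = Vocoder()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    global_step, sess = vocoder.load(sess, logdir)
    vocoder.init_synthesizer(batch_size=lc.shape[0],
                             gc_enable=hparams.gc_enable)
    audio = vocoder.synthesize(sess, n_samples=lc.shape[1], lc=lc, gc=gc)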
with tf.name_scope('create_inputs'):
    # Allow silence trimming to be skipped by specifying a threshold near
    # zero.
    silence_threshold = None
    # AUDIO_FILE_PATH = '/home/sriramso/data/VCTK-Corpus'
    AUDIO_FILE_PATH = '/home/andrewszot/VCTK-Corpus'
    gc_enabled = False
    reader = AudioReader(
        AUDIO_FILE_PATH,
        coord,
        sample_rate=wavenet_params['sample_rate'],
        gc_enabled=gc_enabled,
        receptive_field=WaveNetModel.calculate_receptive_field(
            wavenet_params["filter_width"],
            wavenet_params["dilations"],
            wavenet_params["scalar_input"],
            wavenet_params["initial_filter_width"]),
        sample_size=39939,
        silence_threshold=silence_threshold)
    audio_batch = reader.dequeue(1)
    if gc_enabled:
        gc_id_batch = reader.dequeue_gc(1)
    else:
        gc_id_batch = None

global_step = tf.Variable(0, trainable=False)
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
reader.start_threads(sess)
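# Shutdown sketch for the input pipeline above: the coordinator created
# earlier (assumed: coord = tf.train.Coordinator()) stops both the
# tf.train queue runners and the AudioReader threads when work is done.
try:
    _ = sess.run(audio_batch)  # consume one batch as a smoke test
finally:
    coord.request_stop()
    coord.join(threads)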
class VQVAE:
    def __init__(self, batch_size=None, sample_size=None, q_factor=1,
                 n_stack=2, max_dilation=10, K=512, D=128, lr=0.001,
                 use_gc=False, gc_cardinality=None, is_training=True,
                 global_step=None, scope='params', residual_channels=256,
                 dilation_channels=512, skip_channels=256, use_biases=False,
                 upsampling_method='deconv',
                 encoding_channels=[2, 4, 8, 16, 32, 1]):
        assert sample_size is not None
        assert q_factor == 1 or (q_factor % 2) == 0
        self.filter_width = 2
        # n_stack repetitions of dilations 1, 2, ..., 2**(max_dilation - 1).
        self.dilations = [2**i for j in range(n_stack)
                          for i in range(max_dilation)]
        self.receptive_field = (self.filter_width - 1) * sum(
            self.dilations) + 1
        self.receptive_field += self.filter_width - 1
        self.q_factor = q_factor
        self.quantization_channels = 256 * q_factor
        self.K = K  # codebook size
        self.D = D  # embedding dimension
        self.use_gc = use_gc
        self.gc_cardinality = gc_cardinality
        self.use_biases = use_biases

        # encoding spec
        self.encode_level = 6
        self.encoding_channels = encoding_channels

        # model spec
        self.upsampling_method = upsampling_method
        self.is_training = is_training
        self.train_op = None
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.reduced_timestep = None
        self.initialized = False
        if batch_size is not None and sample_size is not None:
            # The encoder halves the time axis once per level.
            self.reduced_timestep = int(
                np.ceil(self.sample_size / 2**self.encode_level))
            self.initialized = True

        # etc
        self.drop_rate = 0.5
        self.global_step = global_step
        self.lr = lr

        with tf.variable_scope(scope) as params:
            self.enc_var, self.enc_scope = self.create_encoder_variables()
            with tf.variable_scope('decoder') as dec_param_scope:
                self.deconv_var = self.create_deconv_variables()
                self.wavenet = WaveNetModel(
                    batch_size=batch_size,
                    dilations=self.dilations,
                    filter_width=self.filter_width,
                    residual_channels=residual_channels,
                    dilation_channels=dilation_channels,
                    quantization_channels=self.quantization_channels,
                    skip_channels=skip_channels,
                    global_condition_channels=gc_cardinality,
                    global_condition_cardinality=gc_cardinality,
                    use_biases=use_biases)
            self.dec_scope = dec_param_scope
            with tf.variable_scope('embed'):
                init = tf.truncated_normal_initializer(stddev=0.01)
                self.embeds = tf.get_variable('embedding', [self.K, self.D],
                                              dtype=tf.float32,
                                              initializer=init)
        self.param_scope = params
        self.saver = None
        self.set_saver()

    def create_deconv_variables(self):
        var = None
        if self.upsampling_method.startswith('deconv'):
            var = list()
            # The method string encodes step count and output channels,
            # e.g. 'deconv2-4' means two deconv steps, 4 output channels.
            tokens = self.upsampling_method.split('-')
            n_step = tokens[0].split('deconv')[1]
            out_channel = int(tokens[1]) if len(tokens) > 1 else 1
            if not n_step:
                n_step = 1
            else:
                n_step = int(n_step)
            assert n_step < 4
            height, width = self.reduced_timestep, self.D
            # Split the total upscale factor evenly across the steps.
            upscale_factor = 2**self.encode_level
            if n_step == 1:
                upscale_per_step = upscale_factor
            elif n_step == 2:
                upscale_per_step = int(np.sqrt(upscale_factor))
            elif n_step == 3:
                upscale_per_step = int(np.cbrt(upscale_factor))
            h = height
            in_channel = 1
            for step in range(n_step):
                with tf.variable_scope('deconv_layer_{}'.format(step)):
                    layer = dict()
                    h *= upscale_per_step
                    kernel_size = 2 * upscale_per_step - upscale_per_step % 2
                    layer['filter'] = get_bilinear_filter(
                        [kernel_size, 1, out_channel, in_channel],
                        upscale_per_step,
                        name='deconv_layer_filter')
                    layer['strides'] = [1, upscale_per_step, 1, 1]
                    layer['shape'] = [self.batch_size, h, width, out_channel]
                    if self.use_biases:
                        layer['bias'] = create_bias_variable(
                            'deconv_bias', [out_channel])
                    var.append(layer)
                in_channel = out_channel
                out_channel = out_channel * 2
        return var

    def initialize(self, input_batch, sample_size=40960):
        # TODO
        self.batch_size = tf.shape(input_batch)[0]
        self.sample_size = sample_size
        self.reduced_timestep = int(
            np.ceil(self.sample_size / 2**self.encode_level))
        self.initialized = True

    def set_saver(self):
        if self.saver is None:
            # Remap variable names under a 'train/' prefix so checkpoints
            # can be restored into a differently scoped graph.
            save_vars = {
                ('train/' + '/'.join(var.name.split('/')[1:])).split(':')[0]:
                var
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             self.param_scope.name)
            }
            self.saver = tf.train.Saver(var_list=save_vars, max_to_keep=10)

    def _gc_embedding(self):
        return create_embedding_table(
            'gc_embedding', [self.gc_cardinality, self.gc_cardinality])

    def create_encoder_variables(self):
        with tf.variable_scope('enc') as enc_param_scope:
            var = dict()
            input_channel = 1
            output_channel = self.encoding_channels
            var['enc_conv_stack'] = list()
            for i in range(self.encode_level):
                with tf.variable_scope('encoder_conv_{}'.format(i)):
                    current = dict()
                    # The first q_factor layers also stride over the
                    # channel (width) axis.
                    if i < self.q_factor:
                        current['filter'] = create_variable(
                            'filter', [4, 4, input_channel, output_channel[i]])
                    else:
                        current['filter'] = create_variable(
                            'filter', [4, 1, input_channel, output_channel[i]])
                    if self.use_biases:
                        current['bias'] = create_bias_variable(
                            'bias', [output_channel[i]])
                    input_channel = output_channel[i]
                    var['enc_conv_stack'].append(current)
        return var, enc_param_scope

    def encode(self, encoded_input_batch):
        encoded_input_batch = tf.expand_dims(encoded_input_batch, -1)
        out = encoded_input_batch
        for i, layer in enumerate(self.enc_var['enc_conv_stack']):
            kernel = layer['filter']
            if i < self.q_factor:
                out = tf.nn.conv2d(out, kernel, [1, 2, 2, 1], padding='SAME')
            else:
                out = tf.nn.conv2d(out, kernel, [1, 2, 1, 1], padding='SAME')
            if self.use_biases:
                out = tf.nn.bias_add(out, layer['bias'])
            if i < (self.encode_level - 1):
                out = tf.nn.elu(out)
        if self.encoding_channels[-1] > 1:
            z_e = tf.reduce_sum(out, -1)
        else:
            z_e = tf.squeeze(out, axis=-1, name='encode_squeeze')
        z_e = tf.nn.tanh(z_e)
        return z_e

    def upsampling(self, z_q):
        dec_input = tf.expand_dims(z_q, -1)
        # Nearest-neighbour upsampling serves as a residual baseline that
        # the learned deconv stack refines.
        initial = tf.image.resize_nearest_neighbor(
            dec_input, [self.sample_size, self.D])
        initial = tf.squeeze(initial, axis=-1, name='dec_input_squeeze')
        if self.deconv_var is not None:
            for i, layer in enumerate(self.deconv_var):
                dec_input = tf.nn.conv2d_transpose(dec_input,
                                                   layer['filter'],
                                                   layer['shape'],
                                                   layer['strides'],
                                                   padding='SAME',
                                                   data_format='NHWC')
                if self.use_biases:
                    dec_input = tf.nn.bias_add(dec_input, layer['bias'])
                if i < len(self.deconv_var) - 1:
                    dec_input = tf.layers.batch_normalization(
                        dec_input, training=self.is_training)
                    dec_input = tf.nn.tanh(dec_input)
            dec_input = tf.reduce_sum(dec_input, -1)
            dec_input = tf.add(dec_input, initial)
        else:
            dec_input = initial
        return dec_input

    def vq(self, z_e):
        # Find the nearest codebook vector for each encoder output frame.
        _e = tf.reshape(self.embeds, [1, self.K, self.D])
        _e = tf.tile(_e, [self.batch_size, self.reduced_timestep, 1])
        _t = tf.tile(z_e, [1, 1, self.K])
        _t = tf.reshape(
            _t, [self.batch_size, self.reduced_timestep * self.K, self.D])
        dist = tf.norm(_t - _e, axis=-1)
        dist = tf.reshape(dist, [self.batch_size, -1, self.K])
        k = tf.argmin(dist, axis=-1)
        z_q = tf.gather(self.embeds, k)
        return z_q

    def get_condition(self, input_batch, gc=None):
        with tf.variable_scope('forward'):
            encoded_input_batch, gc = self.preprocess(input_batch, gc=gc)
            self.encoded_input_batch = encoded_input_batch
            self.gc = gc
            # encoding
            z_e = self.encode(encoded_input_batch)
            # VQ-embedding
            z_q = self.vq(z_e)
            # decoding
            lc = self.upsampling(z_q)
        return lc, gc

    def create_model(self, padded_input, gc=None):
        with tf.variable_scope('forward'):
            padded_encoded_input, gc = self.preprocess(padded_input, gc=gc)
            self.gc = gc
            # Cut off the last sample of network input to preserve causality.
            wavenet_input_width = tf.shape(padded_encoded_input)[1] - 1
            wavenet_input = tf.slice(padded_encoded_input, [0, 0, 0],
                                     [-1, wavenet_input_width, -1])
            encoded_input = tf.slice(padded_encoded_input,
                                     [0, self.receptive_field, 0],
                                     [-1, -1, -1],
                                     name="remove_pad")
            self.encoded_input = encoded_input
            # encoding
            self.z_e = self.encode(encoded_input)
            # VQ-embedding
            self.z_q = self.vq(self.z_e)
            # decoding
            lc = self.upsampling(self.z_q)
            self.lc = lc
            # Left-pad the local conditioning so it lines up with the
            # receptive-field padding of the WaveNet input.
            paddings = tf.constant([[0, 0], [self.receptive_field - 1, 0],
                                    [0, 0]])
            lc = tf.pad(lc, paddings, "CONSTANT")
            output = self.wavenet._create_network(wavenet_input, lc, gc)
        return output

    def generate_waveform(self, sess, n_samples, lc, gc, seed=None,
                          use_randomness=True):
        sample_placeholder = tf.placeholder(tf.int32)
        lc_placeholder = tf.placeholder(tf.float32)
        gc_placeholder = tf.placeholder(tf.float32)
        next_sample_probs = self.wavenet.predict_proba_incremental(
            sample_placeholder, lc_placeholder, gc_placeholder)
        sess.run(self.wavenet.init_ops)
        operations = [next_sample_probs]
        operations.extend(self.wavenet.push_ops)

        # Prime the incremental-generation queues with silence (mu-law
        # midpoint 128), followed by a random or fixed seed sample.
        waveform = [128] * (self.receptive_field - 2)
        waveform = np.tile(waveform, (self.batch_size, 1))
        if seed is None:
            seed = []
            for i in range(self.batch_size):
                _seed = (np.random.randint(self.quantization_channels)
                         if use_randomness else 128)
                seed.append([_seed])
        waveform = np.hstack([waveform, seed])
        for i in range(waveform.shape[1] - 1):
            sample = waveform[:, i]
            lc_sample = np.zeros((self.batch_size, 128))
            sess.run(operations,
                     feed_dict={
                         sample_placeholder: sample,
                         lc_placeholder: lc_sample,
                         gc_placeholder: gc
                     })

        softmax_result = []
        for i in range(n_samples):
            if i > 0 and i % 10000 == 0:
                print("Generating {} of {}.".format(i, n_samples))
                sys.stdout.flush()
            sample = waveform[:, -1]
            lc_sample = lc[:, i, :].reshape(self.batch_size, -1)
            results = sess.run(operations,
                               feed_dict={
                                   sample_placeholder: sample,
                                   lc_placeholder: lc_sample,
                                   gc_placeholder: gc
                               })
            softmax_result.append(np.expand_dims(results[0], 1))
            if use_randomness:
                # Sample from the predicted distribution.
                sample = []
                for k in range(self.batch_size):
                    _sample = np.random.choice(
                        np.arange(self.quantization_channels),
                        p=results[0][k, :])
                    sample.append([_sample])
            else:
                sample = np.argmax(results[0], axis=1).reshape(-1, 1)
            waveform = np.hstack([waveform, sample])
        waveform = waveform[:, self.receptive_field:]
        softmax_result = np.hstack(softmax_result)
        return waveform, softmax_result

    def _one_hot_encode(self, input_batch):
        with tf.name_scope('one_hot_encode'):
            encoded = tf.one_hot(input_batch,
                                 depth=self.quantization_channels)
            encoded = tf.reshape(
                encoded, [self.batch_size, -1, self.quantization_channels])
        return encoded

    def preprocess(self, input_batch, gc=None):
        if not self.initialized:
            self.initialize(input_batch)
        encoded = mu_law(input_batch,
                         quantization_channels=self.quantization_channels)
        encoded = self._one_hot_encode(encoded)
        # gc-embedding
        if self.use_gc and gc is not None:
            gc_embedding_table = self._gc_embedding()
            gc = tf.nn.embedding_lookup(gc_embedding_table, gc)
            gc = tf.reshape(gc, [self.batch_size, 1, self.gc_cardinality],
                            name="gc_embedding_resize")
        return encoded, gc

    def loss_recon(self, mu_law_output, encoded_target, beta=0.25):
        encoded_output = self._one_hot_encode(mu_law_output)
        output = encoded_output
        target = encoded_target
        # Shift the target one step so the network predicts the next sample.
        target = tf.slice(target, [0, 1, 0], [-1, -1, -1],
                          name="loss_recon_slice_target")
        recon = tf.nn.softmax_cross_entropy_with_logits(logits=output,
                                                        labels=target)
        recon = tf.reduce_mean(recon)
        return recon

    def loss(self, output, beta=0.25):
        # Reconstruction term.
        recon = tf.nn.softmax_cross_entropy_with_logits(
            logits=output, labels=self.encoded_input)
        recon = tf.reduce_mean(recon)
        z_q = self.z_q
        z_e = self.z_e
        # Codebook (vq) and commitment terms; the stop_gradient calls
        # implement the gradient split from the VQ-VAE objective.
        vq = tf.reduce_mean(tf.norm(tf.stop_gradient(z_e) - z_q, axis=-1)**2)
        commit = tf.reduce_mean(
            tf.norm(z_e - tf.stop_gradient(z_q), axis=-1)**2)
        loss = recon + vq + beta * commit
        if self.is_training:
            with tf.variable_scope('backward'):
                # Decoder Grads
                decoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, self.dec_scope.name)
                decoder_grads = list(
                    zip(tf.gradients(loss, decoder_vars), decoder_vars))
                # Encoder Grads: copy the reconstruction gradient straight
                # through the quantizer, plus the commitment gradient.
                encoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, self.enc_scope.name)
                grad_z = tf.gradients(recon, z_q)
                encoder_grads = [(tf.gradients(z_e, _var, grad_z)[0] +
                                  beta * tf.gradients(commit, _var)[0], _var)
                                 for _var in encoder_vars]
                # Embedding Grads
                embed_grads = list(
                    zip(tf.gradients(vq, self.embeds), [self.embeds]))
                optimizer = tf.train.AdamOptimizer(self.lr)
                self.train_op = optimizer.apply_gradients(
                    decoder_grads + encoder_grads + embed_grads,
                    global_step=self.global_step)
        return loss, recon

    def load(self, sess, model):
        self.saver.restore(sess, model)

    def save(self, sess, logdir, step):
        model_name = 'model.ckpt'
        checkpoint_path = os.path.join(logdir, model_name)
        print('Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        self.saver.save(sess, checkpoint_path, global_step=step)
        print(' Done.')
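# Hypothetical training-loop sketch for the VQVAE class above. The
# padded_audio_batch tensor (left-padded by receptive_field samples) and
# the step count are assumptions, not values from this repo.
model = VQVAE(batch_size=4, sample_size=40960)
output = model.create_model(padded_audio_batch)
loss, recon = model.loss(output)  # also builds model.train_op
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(10000):
        _, loss_val = sess.run([model.train_op, loss])
        if step % 100 == 0:
            print('step {}: loss {:.4f}'.format(step, loss_val))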