Example #1
File: vocoder.py  Project: dyelax/wavenet
    def __init__(self, hp, max_to_keep=5):
        self.hp = hp

        # One stack doubles the dilation each layer (1, 2, 4, ...),
        # and the pattern repeats hp.stacks times.
        dilations_factor = hp.layers // hp.stacks
        dilations = [
            2**i for _ in range(hp.stacks) for i in range(dilations_factor)
        ]

        self.upsample_factor = hp.upsample_factor
        self.gc_enable = hp.gc_enable
        global_condition_channels = None
        global_condition_cardinality = None
        if hp.gc_enable:
            global_condition_channels = hp.global_channel
            global_condition_cardinality = hp.global_cardinality

        scalar_input = hp.input_type == "raw"
        quantization_channels = hp.quantize_channels[hp.input_type]
        if scalar_input:
            quantization_channels = None

        with tf.variable_scope('vocoder'):
            self.net = WaveNetModel(
                batch_size=hp.batch_size,
                dilations=dilations,
                filter_width=hp.filter_width,
                scalar_input=scalar_input,
                initial_filter_width=hp.initial_filter_width,
                residual_channels=hp.residual_channels,
                dilation_channels=hp.dilation_channels,
                quantization_channels=quantization_channels,
                out_channels=hp.out_channels,
                skip_channels=hp.skip_channels,
                global_condition_channels=global_condition_channels,
                global_condition_cardinality=global_condition_cardinality,
                use_biases=True,
                local_condition_channels=hp.n_mel_bins)

            if hp.upsample_conditional_features:
                with tf.variable_scope('upsample_layer') as upsample_scope:
                    layer = dict()
                    for i in range(len(hp.upsample_factor)):
                        shape = [hp.upsample_factor[i], hp.filter_width, 1, 1]
                        # Initialize each upsampling filter as an
                        # averaging kernel.
                        weights = np.ones(shape) / float(
                            hp.upsample_factor[i])
                        init = tf.constant_initializer(value=weights,
                                                       dtype=tf.float32)
                        variable = tf.get_variable(name='upsample{}'.format(i),
                                                   initializer=init,
                                                   shape=weights.shape)
                        layer['upsample{}_filter'.format(i)] = variable
                        layer['upsample{}_bias'.format(
                            i)] = create_bias_variable(
                                'upsample{}_bias'.format(i), [1])

                    self.upsample_var = layer
                    self.upsample_scope = upsample_scope

        self.saver = tf.train.Saver(var_list=tf.trainable_variables(),
                                    max_to_keep=max_to_keep)
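
For illustration, the dilation comprehension above expands to a repeated doubling pattern. A minimal sketch, assuming the hypothetical values hp.layers = 20 and hp.stacks = 2 (not taken from the project):

# Hypothetical values, for illustration only.
layers, stacks = 20, 2
dilations_factor = layers // stacks
dilations = [2 ** i for _ in range(stacks) for i in range(dilations_factor)]
assert dilations == [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] * 2
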
Example #2
    def _create_audio_reader(self):
        # Compute the receptive field implied by the configured filter
        # width and dilations.
        receptive_field = WaveNetModel.calculate_receptive_field(
            self.wavenet_params["filter_width"],
            self.wavenet_params["dilations"],
            self.wavenet_params["scalar_input"],
            self.wavenet_params["initial_filter_width"])

        return AudioReader(self.args.audio_dir,
                           self.coord,
                           self.args.sample_rate,
                           self.args.gc_enabled,
                           receptive_field,
                           sample_size=self.args.sample_size,
                           silence_threshold=self.args.silence_threshold,
                           queue_size=32)
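
For context, a standalone sketch of the arithmetic calculate_receptive_field performs; this mirrors the inline formula in the VQVAE example (Example #6) below, though the library method also adjusts for scalar input, so treat it as an approximation:

# Receptive-field arithmetic sketch: each dilated layer contributes
# (filter_width - 1) * dilation samples of left context.
filter_width = 2
dilations = [2 ** i for _ in range(2) for i in range(10)]
receptive_field = (filter_width - 1) * sum(dilations) + 1
receptive_field += filter_width - 1  # initial causal convolution
print(receptive_field)  # 2048 for this configuration
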
Example #3
def create_wavenet(args, wavenet_params):
    # Create network.
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"],
    )

    # A strength of zero means "disabled"; downstream code expects None.
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    return net
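
A hypothetical call site for create_wavenet; the file name and argument values below are assumptions, not taken from the project:

import argparse
import json

# Illustrative arguments only; a real script would parse these.
args = argparse.Namespace(batch_size=1, l2_regularization_strength=0)
with open('wavenet_params.json') as f:  # assumed parameter file
    wavenet_params = json.load(f)
net = create_wavenet(args, wavenet_params)
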
Example #4
class Vocoder(object):
    def __init__(self, max_to_keep=5):
        dilations_factor = hparams.layers // hparams.stacks
        dilations = [2 ** i for _ in range(hparams.stacks) for i in range(dilations_factor)]

        self.upsample_factor = hparams.upsample_factor
        global_condition_channels = None
        global_condition_cardinality = None
        if hparams.gc_enable:
            global_condition_channels = hparams.global_channel
            global_condition_cardinality = hparams.global_cardinality

        scalar_input = hparams.input_type == "raw"
        quantization_channels = hparams.quantize_channels[hparams.input_type]
        if scalar_input:
            quantization_channels = None

        with tf.variable_scope('vocoder'):
            self.net = WaveNetModel(batch_size=hparams.batch_size,
                                    dilations=dilations,
                                    filter_width=hparams.filter_width,
                                    scalar_input=scalar_input,
                                    initial_filter_width=hparams.initial_filter_width,
                                    residual_channels=hparams.residual_channels,
                                    dilation_channels=hparams.dilation_channels,
                                    quantization_channels=quantization_channels,
                                    out_channels=hparams.out_channels,
                                    skip_channels=hparams.skip_channels,
                                    global_condition_channels=global_condition_channels,
                                    global_condition_cardinality=global_condition_cardinality,
                                    use_biases=True,
                                    local_condition_channels=hparams.num_mels)

            if hparams.upsample_conditional_features:
                with tf.variable_scope('upsample_layer') as upsample_scope:
                    layer = dict()
                    for i in range(len(hparams.upsample_factor)):
                        shape = [hparams.upsample_factor[i], hparams.filter_width, 1, 1]
                        weights = np.ones(shape) / float(hparams.upsample_factor[i])
                        init = tf.constant_initializer(value=weights, dtype=tf.float32)
                        variable = tf.get_variable(name='upsample{}'.format(i), initializer=init, shape=weights.shape)
                        layer['upsample{}_filter'.format(i)] = variable
                        layer['upsample{}_bias'.format(i)] = create_bias_variable('upsample{}_bias'.format(i), [1])

                    self.upsample_var = layer
                    self.upsample_scope = upsample_scope

        self.saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=max_to_keep)

    def create_upsample(self, l):
        layer_filter = self.upsample_var
        local_condition_batch = tf.expand_dims(l, [3])

        # local condition batch N H W C
        batch_size = tf.shape(local_condition_batch)[0]
        upsample_dim = tf.shape(local_condition_batch)[1]

        for i in range(len(self.upsample_factor)):
            upsample_dim = upsample_dim * self.upsample_factor[i]
            output_shape = tf.stack([batch_size, upsample_dim, tf.shape(local_condition_batch)[2], 1])
            local_condition_batch = tf.nn.conv2d_transpose(
                local_condition_batch,
                layer_filter['upsample{}_filter'.format(i)],
                strides=[1, self.upsample_factor[i], 1, 1],
                output_shape=output_shape
            )
            local_condition_batch += layer_filter['upsample{}_bias'.format(i)]
            local_condition_batch = tf.nn.relu(local_condition_batch)

        local_condition_batch = tf.squeeze(local_condition_batch, [3])
        return local_condition_batch

    def loss(self, x, l, g):
        self.upsampled_lc = self.create_upsample(l)
        loss = self.net.loss(x, self.upsampled_lc, g, l2_regularization_strength=hparams.l2_regularization_strength)

        return loss

    def save(self, sess, logdir, step):
        model_name = 'model.ckpt'
        checkpoint_path = os.path.join(logdir, model_name)
        print('Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()

        if not os.path.exists(logdir):
            os.makedirs(logdir)

        self.saver.save(sess, checkpoint_path, global_step=step)
        print(' Done.')

    def load(self, sess, logdir):
        print("Trying to restore saved checkpoints from {} ...".format(logdir),
              end="")

        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt:
            print("  Checkpoint found: {}".format(ckpt.model_checkpoint_path))
            global_step = int(ckpt.model_checkpoint_path
                              .split('/')[-1]
                              .split('-')[-1])
            print("  Global step was: {}".format(global_step))
            print("  Restoring...", end="")
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            print(" Done.")
            return global_step, sess
        else:
            print(" No checkpoint found.")
            return None, sess

    def init_synthesizer(self, batch_size, gc_enable=True):
        self.batch_size = batch_size
        if self.net.scalar_input:
            self.sample_placeholder = tf.placeholder(tf.float32)
        else:
            self.sample_placeholder = tf.placeholder(tf.int32)

        self.lc_placeholder = tf.placeholder(tf.float32)
        self.gc_placeholder = tf.placeholder(tf.int32) if gc_enable else None

        self.gen_num = tf.placeholder(tf.int32)

        self.next_sample_prob, self.layers_out, self.qs = \
            self.net.predict_proba_incremental(
                self.sample_placeholder,
                self.gen_num,
                batch_size=batch_size,
                local_condition=self.lc_placeholder,
                global_condition=self.gc_placeholder)
        self.initial = tf.placeholder(tf.float32)
        self.others = tf.placeholder(tf.float32)
        self.update_q_ops = \
            self.net.create_update_q_ops(self.qs,
                                         self.initial,
                                         self.others,
                                         self.gen_num,
                                         batch_size=batch_size)

        self.var_q = self.net.get_vars_q()

    def synthesize(self, sess, n_samples, lc, gc):
        sess.run(tf.variables_initializer(self.var_q))

        if self.net.scalar_input:
            seeds = [0]
        else:
            seeds = [128]

        seeds = np.repeat([seeds], self.batch_size, axis=0)
        generated = [seeds]

        if isinstance(n_samples, list):
            n_sample = max(n_samples)
        else:
            n_sample = n_samples

        for j in tqdm(range(n_sample)):
            sample = generated[-1]
            current_lc = lc[:, j, :]

            # Generation phase
            feed_dict = {
                self.sample_placeholder: sample,
                self.lc_placeholder: current_lc,
                self.gen_num: j}

            if self.gc_placeholder is not None:
                feed_dict.update({self.gc_placeholder: gc})

            prob, _layers = sess.run([self.next_sample_prob, self.layers_out], feed_dict=feed_dict)

            # Update phase
            feed_dict = {
                self.initial: _layers[0],
                self.others: np.array(_layers[1:]),
                self.gen_num: j}

            sess.run(self.update_q_ops, feed_dict=feed_dict)

            if self.net.scalar_input:
                generated_sample = prob
            else:
                # TODO: random choice
                generated_sample = np.argmax(prob, axis=-1)

            generated.append(generated_sample)

        result = np.hstack(generated)
        if not self.net.scalar_input:
            result = P.inv_mulaw_quantize(result.astype(np.int16), self.net.quantization_channels)

        if isinstance(n_samples, list):
            result = [x[:n_samples[i]] for i, x in enumerate(result)]

        return result
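
A hypothetical end-to-end training sketch for this Vocoder class; the placeholder shapes, learning rate, and logdir are assumptions, not taken from the project:

# Hypothetical driver code; shapes and hparams fields are illustrative.
x = tf.placeholder(tf.float32, [hparams.batch_size, None, 1])  # audio
l = tf.placeholder(tf.float32,
                   [hparams.batch_size, None, hparams.num_mels])  # mel frames
g = tf.placeholder(tf.int32, [hparams.batch_size])  # speaker ids

vocoder = Vocoder()
loss = vocoder.loss(x, l, g)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

logdir = './logdir'  # assumed
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step, sess = vocoder.load(sess, logdir)  # resumes if a checkpoint exists
    if step is None:
        step = 0
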
Example #5
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = None

        # Hard-coded corpus location; adjust for your environment.
        AUDIO_FILE_PATH = '/home/andrewszot/VCTK-Corpus'

        gc_enabled = False
        reader = AudioReader(
            AUDIO_FILE_PATH,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params["filter_width"], wavenet_params["dilations"],
                wavenet_params["scalar_input"],
                wavenet_params["initial_filter_width"]),
            sample_size=39939,
            silence_threshold=silence_threshold)

        audio_batch = reader.dequeue(1)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(1)
        else:
            gc_id_batch = None

    global_step = tf.Variable(0, trainable=False)

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)
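
A hedged sketch of the training loop that typically follows this setup; train_op, loss, and num_steps are assumptions, while the coordinator shutdown uses the standard tf.train.Coordinator API:

# Hypothetical continuation of the snippet above.
try:
    for step in range(num_steps):  # num_steps is assumed
        loss_value, _ = sess.run([loss, train_op])
        print('step {} - loss = {:.3f}'.format(step, loss_value))
except KeyboardInterrupt:
    pass
finally:
    coord.request_stop()
    coord.join(threads)
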
Example #6
File: model.py  Project: twidddj/vqvae
class VQVAE:
    def __init__(self,
                 batch_size=None,
                 sample_size=None,
                 q_factor=1,
                 n_stack=2,
                 max_dilation=10,
                 K=512,
                 D=128,
                 lr=0.001,
                 use_gc=False,
                 gc_cardinality=None,
                 is_training=True,
                 global_step=None,
                 scope='params',
                 residual_channels=256,
                 dilation_channels=512,
                 skip_channels=256,
                 use_biases=False,
                 upsampling_method='deconv',
                 encoding_channels=[2, 4, 8, 16, 32, 1]):

        assert sample_size is not None
        assert q_factor == 1 or (q_factor % 2) == 0

        self.filter_width = 2
        self.dilations = [
            2**i for _ in range(n_stack) for i in range(max_dilation)
        ]
        self.receptive_field = (self.filter_width - 1) * sum(
            self.dilations) + 1
        self.receptive_field += self.filter_width - 1

        self.q_factor = q_factor
        self.quantization_channels = 256 * q_factor

        self.K = K
        self.D = D
        self.use_gc = use_gc
        self.gc_cardinality = gc_cardinality
        self.use_biases = use_biases

        # encoding spec
        self.encode_level = 6
        self.encoding_channels = encoding_channels

        # model spec
        self.upsampling_method = upsampling_method
        self.is_training = is_training
        self.train_op = None
        self.batch_size = batch_size
        self.sample_size = sample_size
        self.reduced_timestep = None
        self.initialized = False
        if batch_size is not None and sample_size is not None:
            self.reduced_timestep = int(
                np.ceil(self.sample_size / 2**self.encode_level))
            self.initialized = True

        # etc
        self.drop_rate = 0.5
        self.global_step = global_step
        self.lr = lr

        with tf.variable_scope(scope) as params:
            self.enc_var, self.enc_scope = self.create_encoder_variables()
            with tf.variable_scope('decoder') as dec_param_scope:

                self.deconv_var = self.create_deconv_variables()
                self.wavenet = WaveNetModel(
                    batch_size=batch_size,
                    dilations=self.dilations,
                    filter_width=self.filter_width,
                    residual_channels=residual_channels,
                    dilation_channels=dilation_channels,
                    quantization_channels=self.quantization_channels,
                    skip_channels=skip_channels,
                    global_condition_channels=gc_cardinality,
                    global_condition_cardinality=gc_cardinality,
                    use_biases=use_biases)

                self.dec_scope = dec_param_scope

            with tf.variable_scope('embed'):
                init = tf.truncated_normal_initializer(stddev=0.01)
                self.embeds = tf.get_variable('embedding', [self.K, self.D],
                                              dtype=tf.float32,
                                              initializer=init)

        self.param_scope = params
        self.saver = None
        self.set_saver()

    def create_deconv_variables(self):
        var = None
        if self.upsampling_method.startswith('deconv'):
            var = list()

            tokens = self.upsampling_method.split('-')
            n_step = tokens[0].split('deconv')[1]

            out_channel = int(tokens[1]) if len(tokens) > 1 else 1

            if not n_step:
                n_step = 1
            else:
                n_step = int(n_step)

            assert n_step < 4

            height, width = self.reduced_timestep, self.D
            upscale_factor = 2**self.encode_level

            if n_step == 1:
                upscale_per_step = upscale_factor
            elif n_step == 2:
                upscale_per_step = int(np.sqrt(upscale_factor))
            elif n_step == 3:
                upscale_per_step = int(np.cbrt(upscale_factor))

            h = height
            in_channel = 1
            for step in range(n_step):
                with tf.variable_scope('deconv_layer_{}'.format(step)):
                    layer = dict()

                    h *= upscale_per_step

                    kernel_size = 2 * upscale_per_step - upscale_per_step % 2
                    layer['filter'] = get_bilinear_filter(
                        [kernel_size, 1, out_channel, in_channel],
                        upscale_per_step,
                        name='deconv_layer_filter')
                    layer['strides'] = [1, upscale_per_step, 1, 1]
                    layer['shape'] = [self.batch_size, h, width, out_channel]
                    if self.use_biases:
                        layer['bias'] = create_bias_variable(
                            'deconv_bias', [out_channel])
                    var.append(layer)

                    in_channel = out_channel
                    out_channel = out_channel * 2
        return var

    def initialize(self, input_batch, sample_size=40960):
        # TODO
        self.batch_size = tf.shape(input_batch)[0]
        self.sample_size = sample_size
        self.reduced_timestep = int(
            np.ceil(self.sample_size / 2**self.encode_level))
        self.initialized = True

    def set_saver(self):
        if self.saver is None:
            save_vars = {
                ('train/' + '/'.join(var.name.split('/')[1:])).split(':')[0]:
                var
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                             self.param_scope.name)
            }
            self.saver = tf.train.Saver(var_list=save_vars, max_to_keep=10)

    def _gc_embedding(self):
        return create_embedding_table(
            'gc_embedding', [self.gc_cardinality, self.gc_cardinality])

    def create_encoder_variables(self):
        with tf.variable_scope('enc') as enc_param_scope:
            var = dict()

            input_channel = 1
            output_channel = self.encoding_channels

            var['enc_conv_stack'] = list()
            for i in range(self.encode_level):
                with tf.variable_scope('encoder_conv_{}'.format(i)):
                    current = dict()
                    if i < self.q_factor:
                        current['filter'] = create_variable(
                            'filter', [4, 4, input_channel, output_channel[i]])
                    else:
                        current['filter'] = create_variable(
                            'filter', [4, 1, input_channel, output_channel[i]])
                    if self.use_biases:
                        current['bias'] = create_bias_variable(
                            'bias', [output_channel[i]])
                    input_channel = output_channel[i]
                    var['enc_conv_stack'].append(current)
        return var, enc_param_scope

    def encode(self, encoded_input_batch):
        encoded_input_batch = tf.expand_dims(encoded_input_batch, -1)

        out = encoded_input_batch

        for i, layer in enumerate(self.enc_var['enc_conv_stack']):
            kernel = layer['filter']
            if i < self.q_factor:
                out = tf.nn.conv2d(out, kernel, [1, 2, 2, 1], padding='SAME')
            else:
                out = tf.nn.conv2d(out, kernel, [1, 2, 1, 1], padding='SAME')

            if self.use_biases:
                out = tf.nn.bias_add(out, layer['bias'])

            if i < (self.encode_level - 1):
                out = tf.nn.elu(out)

        if self.encoding_channels[-1] > 1:
            z_e = tf.reduce_sum(out, -1)
        else:
            z_e = tf.squeeze(out, axis=-1, name='encode_squeeze')

        z_e = tf.nn.tanh(z_e)

        return z_e

    def upsampling(self, z_q):
        dec_input = tf.expand_dims(z_q, -1)
        initial = tf.image.resize_nearest_neighbor(dec_input,
                                                   [self.sample_size, self.D])
        initial = tf.squeeze(initial, axis=-1, name='dec_input_squeeze')

        if self.deconv_var is not None:
            for i, layer in enumerate(self.deconv_var):
                dec_input = tf.nn.conv2d_transpose(dec_input,
                                                   layer['filter'],
                                                   layer['shape'],
                                                   layer['strides'],
                                                   padding='SAME',
                                                   data_format='NHWC',
                                                   name=None)

                if self.use_biases:
                    dec_input = tf.nn.bias_add(dec_input, layer['bias'])

                if i < len(self.deconv_var) - 1:
                    dec_input = tf.layers.batch_normalization(
                        dec_input, training=self.is_training)
                    dec_input = tf.nn.tanh(dec_input)

            dec_input = tf.reduce_sum(dec_input, -1)
            dec_input = tf.add(dec_input, initial)
        else:
            dec_input = initial

        return dec_input

    def vq(self, z_e):
        # Nearest-neighbour codebook lookup: tile the K embeddings against
        # every timestep, take L2 distances, and gather the closest entry.
        _e = tf.reshape(self.embeds, [1, self.K, self.D])
        _e = tf.tile(_e, [self.batch_size, self.reduced_timestep, 1])

        _t = tf.tile(z_e, [1, 1, self.K])
        _t = tf.reshape(
            _t, [self.batch_size, self.reduced_timestep * self.K, self.D])

        dist = tf.norm(_t - _e, axis=-1)
        dist = tf.reshape(dist, [self.batch_size, -1, self.K])
        k = tf.argmin(dist, axis=-1)
        z_q = tf.gather(self.embeds, k)

        return z_q

    def get_condition(self, input_batch, gc=None):
        with tf.variable_scope('forward'):
            encoded_input_batch, gc = self.preprocess(input_batch, gc=gc)
            self.encoded_input_batch = encoded_input_batch
            self.gc = gc

            # encoding
            z_e = self.encode(encoded_input_batch)

            # VQ-embedding
            z_q = self.vq(z_e)

            # decoding
            lc = self.upsampling(z_q)
        return lc, gc

    def create_model(self, padded_input, gc=None):
        with tf.variable_scope('forward'):

            padded_encoded_input, gc = self.preprocess(padded_input, gc=gc)
            self.gc = gc

            # Cut off the last sample of network input to preserve causality.
            wavenet_input_width = tf.shape(padded_encoded_input)[1] - 1
            wavenet_input = tf.slice(padded_encoded_input, [0, 0, 0],
                                     [-1, wavenet_input_width, -1])

            encoded_input = tf.slice(padded_encoded_input,
                                     [0, self.receptive_field, 0],
                                     [-1, -1, -1],
                                     name="remove_pad")

            self.encoded_input = encoded_input

            # encoding
            self.z_e = self.encode(encoded_input)

            # VQ-embedding
            self.z_q = self.vq(self.z_e)

            # decoding
            lc = self.upsampling(self.z_q)
            self.lc = lc

            paddings = tf.constant([[0, 0], [self.receptive_field - 1, 0],
                                    [0, 0]])
            lc = tf.pad(lc, paddings, "CONSTANT")

            output = self.wavenet._create_network(wavenet_input, lc, gc)

        return output

    def generate_waveform(self,
                          sess,
                          n_samples,
                          lc,
                          gc,
                          seed=None,
                          use_randomness=True):
        sample_placeholder = tf.placeholder(tf.int32)
        lc_placeholder = tf.placeholder(tf.float32)
        gc_placeholder = tf.placeholder(tf.float32)
        next_sample_probs = self.wavenet.predict_proba_incremental(
            sample_placeholder, lc_placeholder, gc_placeholder)
        sess.run(self.wavenet.init_ops)

        operations = [next_sample_probs]
        operations.extend(self.wavenet.push_ops)

        waveform = [128] * (self.receptive_field - 2)
        waveform = np.tile(waveform, (self.batch_size, 1))
        if seed is None:
            seed = []
            for i in range(self.batch_size):
                _seed = np.random.randint(
                    self.quantization_channels) if use_randomness else 128
                seed.append([_seed])

        waveform = np.hstack([waveform, seed])

        for i in range(waveform.shape[1] - 1):
            sample = waveform[:, i]
            lc_sample = np.zeros((self.batch_size, 128))
            sess.run(operations,
                     feed_dict={
                         sample_placeholder: sample,
                         lc_placeholder: lc_sample,
                         gc_placeholder: gc
                     })

        softmax_result = []
        for i in range(n_samples):
            if i > 0 and i % 10000 == 0:
                print("Generating {} of {}.".format(i, n_samples))
                sys.stdout.flush()

            sample = waveform[:, -1]
            lc_sample = lc[:, i, :].reshape(self.batch_size, -1)
            results = sess.run(operations,
                               feed_dict={
                                   sample_placeholder: sample,
                                   lc_placeholder: lc_sample,
                                   gc_placeholder: gc
                               })

            softmax_result.append(np.expand_dims(results[0], 1))
            if use_randomness:
                sample = []
                for k in range(self.batch_size):
                    _sample = np.random.choice(np.arange(
                        self.quantization_channels),
                                               p=results[0][k, :])
                    sample.append([_sample])
            else:
                sample = np.argmax(results[0], axis=1).reshape(-1, 1)

            waveform = np.hstack([waveform, sample])

        waveform = waveform[:, self.receptive_field:]
        softmax_result = np.hstack(softmax_result)
        return waveform, softmax_result

    def _one_hot_encode(self, input_batch):
        with tf.name_scope('one_hot_encode'):
            encoded = tf.one_hot(input_batch, depth=self.quantization_channels)
            encoded = tf.reshape(
                encoded, [self.batch_size, -1, self.quantization_channels])

        return encoded

    def preprocess(self, input_batch, gc=None):
        if not self.initialized:
            self.initialize(input_batch)

        encoded = mu_law(input_batch,
                         quantization_channels=self.quantization_channels)
        encoded = self._one_hot_encode(encoded)

        # gc-embedding
        if self.use_gc and gc is not None:
            gc_embedding_table = self._gc_embedding()
            gc = tf.nn.embedding_lookup(gc_embedding_table, gc)
            gc = tf.reshape(gc, [self.batch_size, 1, self.gc_cardinality],
                            name="gc_embbedding_resize")

        return encoded, gc

    def loss_recon(self, mu_law_output, encoded_target, beta=0.25):
        encoded_output = self._one_hot_encode(mu_law_output)

        output = encoded_output
        target = encoded_target

        target = tf.slice(target, [0, 1, 0], [-1, -1, -1],
                          name="loss_recon_slice_target")
        recon = tf.nn.softmax_cross_entropy_with_logits(logits=output,
                                                        labels=target)
        recon = tf.reduce_mean(recon)

        return recon

    def loss(self, output, beta=0.25):
        recon = tf.nn.softmax_cross_entropy_with_logits(
            logits=output, labels=self.encoded_input)
        recon = tf.reduce_mean(recon)

        z_q = self.z_q
        z_e = self.z_e

        vq = tf.reduce_mean(tf.norm(tf.stop_gradient(z_e) - z_q, axis=-1)**2)
        commit = tf.reduce_mean(
            tf.norm(z_e - tf.stop_gradient(z_q), axis=-1)**2)

        loss = (recon + vq + beta * commit)

        if self.is_training:
            with tf.variable_scope('backward'):
                # Decoder Grads
                decoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, self.dec_scope.name)
                decoder_grads = list(
                    zip(tf.gradients(loss, decoder_vars), decoder_vars))

                # Encoder grads: straight-through estimator. The
                # reconstruction gradient is taken at z_q and routed back
                # through z_e, with the commitment term added directly.
                encoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, self.enc_scope.name)
                grad_z = tf.gradients(recon, z_q)
                encoder_grads = [(tf.gradients(z_e, _var, grad_z)[0] +
                                  beta * tf.gradients(commit, _var)[0], _var)
                                 for _var in encoder_vars]

                # Embedding Grads
                embed_grads = list(
                    zip(tf.gradients(vq, self.embeds), [self.embeds]))

                optimizer = tf.train.AdamOptimizer(self.lr)
                self.train_op = optimizer.apply_gradients(
                    decoder_grads + encoder_grads + embed_grads,
                    global_step=self.global_step)

        return loss, recon

    def load(self, sess, model):
        self.saver.restore(sess, model)

    def save(self, sess, logdir, step):
        model_name = 'model.ckpt'
        checkpoint_path = os.path.join(logdir, model_name)
        print('Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()

        if not os.path.exists(logdir):
            os.makedirs(logdir)

        self.saver.save(sess, checkpoint_path, global_step=step)
        print(' Done.')
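
A hypothetical construction and training sketch for this VQVAE class; the batch size, sample size, padding scheme, and step count are assumptions, not taken from the project:

# Hypothetical driver; input_batch stands in for a real audio pipeline.
global_step = tf.Variable(0, trainable=False)
model = VQVAE(batch_size=2, sample_size=40960, global_step=global_step)

# Audio padded on the left by the receptive field, values in [-1, 1].
input_batch = tf.placeholder(
    tf.float32, [2, 40960 + model.receptive_field])
output = model.create_model(input_batch)
loss, recon = model.loss(output)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.uniform(
        -1.0, 1.0, size=(2, 40960 + model.receptive_field))  # dummy audio
    for step in range(10000):  # step count is illustrative
        loss_value, _ = sess.run([loss, model.train_op],
                                 feed_dict={input_batch: batch})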