Example #1
 def __preprocess_forward_pass(self, x):
     self.x_quantized = mu_law(x)
     x_scaled = tf.cast(self.x_quantized, tf.float32) / 128.0
     # Note: no expand_dims here, presumably because the batch size defaults to 1.
     self.input_mean = tf.reduce_mean(x_scaled)
     return x_scaled
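The mu_law helper imported above is not shown on this page. As a point of reference, here is a minimal NumPy sketch (an assumption, not the repository's code) of the companding-plus-quantization these examples rely on; the division by 128.0 after the call only makes sense if the helper returns integer levels in roughly [-128, 128]:

import numpy as np

def mu_law(x, mu=255):
    # Hypothetical stand-in for the imported helper: mu-law compand a
    # waveform in [-1, 1], then quantize to integer levels in [-128, 128].
    compressed = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor(compressed * 128).astype(np.int64)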
Example #2
 def __getitem__(self, _):
     """Return a (clean, augmented) pair of tensors for a random slice of a track."""
     data_and_aug = None
     while data_and_aug is None:
         data = self.get_random_slice()
         if self.augmentation:
             data_and_aug = [data, self.augmentation(data)]
         else:
             data_and_aug = [data, data]
         data_and_aug = [mu_law(x / 2 ** 15) for x in data_and_aug]
     
     return torch.tensor(data_and_aug[0]), torch.tensor(data_and_aug[1])
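A dataset with this __getitem__ contract plugs directly into a PyTorch DataLoader. A self-contained toy sketch (ToySliceDataset is hypothetical, used only to illustrate the (clean, augmented) pair shape):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ToySliceDataset(Dataset):
    # Hypothetical stand-in mimicking the pair contract above.
    def __len__(self):
        return 64

    def __getitem__(self, _):
        # Fake int16 audio, scaled to [-1, 1] like the x / 2 ** 15 above.
        data = np.random.randint(-2 ** 15, 2 ** 15, size=8000)
        x = torch.tensor(data / 2 ** 15)
        return x, x

loader = DataLoader(ToySliceDataset(), batch_size=4)
clean, augmented = next(iter(loader))  # each has shape [4, 8000]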
Example #3
    def __getitem__(self, _):
        ret = None
        while ret is None:
            try:
                ret = self.try_random_slice()
                if self.augmentation:
                    ret = [ret, self.augmentation(ret)]
                else:
                    ret = [ret, ret]

                if self.dataset_name == 'wav':
                    ret = [mu_law(x / 2**15) for x in ret]
            except Exception as e:
                logger.info('Exception %s in dataset __getitem__, path %s', e,
                            self.path)
                logger.debug('Exception in H5Dataset', exc_info=True)

        return torch.tensor(ret[0]), torch.tensor(ret[1])
Example #4
    def __getitem__(self, _):
        ret = None
        while ret is None:
            try:
                ret, midi = self.try_random_slice()
                if self.augmentation:
                    ret = [ret, self.augmentation(ret)]
                else:
                    ret = [ret, ret]

                if self.dataset_name == 'wav':
                    ret = [mu_law(x / 2 ** 15) for x in ret]
            except Exception as e:
                logger.info('Exception %s in dataset __getitem__, path %s', e, self.path)
                logger.debug('Exception in H5Dataset', exc_info=True)

        # print("getitem shapes: ", ret[0].shape, ret[1].shape, midi.shape)
        # if(midi is None):
        #     return torch.tensor(ret[0]), torch.tensor(ret[1]), None
        # return torch.tensor(ret[0]), torch.tensor(ret[1]), torch.LongTensor(midi)
        return ret[0], ret[1], midi
Example #5
def main(args):
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder,
                                       args.batch_size,
                                       wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder,
                                         args.rate * (args.split_size // 20),
                                         args.batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        file_paths = [f for f in file_paths if '_' not in f.name]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            data = utils.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2**15))
            if data.shape[-1] % args.rate != 0:
                data = data[:-(data.shape[-1] % args.rate)]
            assert data.shape[-1] % args.rate == 0
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(),
                       filepath.parent / f'{filepath.stem}_{decoder_ix}.wav',
                       rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) /
                       str(args.update) / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
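The save helper above relies on utils.inv_mu_law to undo the companding applied at load time. A minimal sketch of such an inverse (an assumption mirroring the forward transform's 128-level scaling, not the repository's exact code):

import numpy as np

def inv_mu_law(y, mu=255):
    # Hypothetical inverse of mu_law: map quantized levels in [-128, 128]
    # back to waveform samples in [-1, 1] by expanding the compressed range.
    y = y.astype(np.float64) / 128.0
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu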
Example #6
    def build(self, inputs):
        """Build the graph for this configuration.

    Args:
      inputs: A dict of inputs. For training, should contain 'wav'.

    Returns:
      A dict of outputs that includes the 'predictions',
      'init_ops', the 'push_ops', and the 'quantized_input'.
    """
        num_stages = 10
        num_layers = 30
        filter_length = 3
        width = 512
        skip_width = 256
        num_z = 16

        # Encode the source with 8-bit Mu-Law.
        x = inputs['wav']
        batch_size = self.batch_size
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
        x_scaled = tf.expand_dims(x_scaled, 2)

        encoding = tf.placeholder(name='encoding',
                                  shape=[batch_size, num_z],
                                  dtype=tf.float32)
        en = tf.expand_dims(encoding, 1)

        init_ops, push_ops = [], []

        ###
        # The WaveNet Decoder.
        ###
        l = x_scaled
        l, inits, pushs = utils.causal_linear(x=l,
                                              n_inputs=1,
                                              n_outputs=width,
                                              name='startconv',
                                              rate=1,
                                              batch_size=batch_size,
                                              filter_length=filter_length)

        for init in inits:
            init_ops.append(init)
        for push in pushs:
            push_ops.append(push)

        # Set up skip connections.
        s = utils.linear(l, width, skip_width, name='skip_start')

        # Residual blocks with skip connections.
        for i in range(num_layers):
            dilation = 2**(i % num_stages)

            # dilated masked cnn
            d, inits, pushs = utils.causal_linear(x=l,
                                                  n_inputs=width,
                                                  n_outputs=width * 2,
                                                  name='dilatedconv_%d' %
                                                  (i + 1),
                                                  rate=dilation,
                                                  batch_size=batch_size,
                                                  filter_length=filter_length)

            for init in inits:
                init_ops.append(init)
            for push in pushs:
                push_ops.append(push)

            # local conditioning
            d += utils.linear(en,
                              num_z,
                              width * 2,
                              name='cond_map_%d' % (i + 1))

            # gated cnn
            assert d.get_shape().as_list()[2] % 2 == 0
            m = d.get_shape().as_list()[2] // 2
            d = tf.sigmoid(d[:, :, :m]) * tf.tanh(d[:, :, m:])

            # residuals
            l += utils.linear(d, width, width, name='res_%d' % (i + 1))

            # skips
            s += utils.linear(d, width, skip_width, name='skip_%d' % (i + 1))

        s = tf.nn.relu(s)
        s = (utils.linear(s, skip_width, skip_width, name='out1') +
             utils.linear(en, num_z, skip_width, name='cond_map_out1'))
        s = tf.nn.relu(s)

        ###
        # Compute the logits and get the loss.
        ###
        logits = utils.linear(s, skip_width, 256, name='logits')
        logits = tf.reshape(logits, [-1, 256])
        probs = tf.nn.softmax(logits, name='softmax')

        return {
            'init_ops': init_ops,
            'push_ops': push_ops,
            'predictions': probs,
            'encoding': encoding,
            'quantized_input': x_quantized,
        }
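The returned init_ops and push_ops exist to drive incremental generation: the queues are primed once, then one sample is fed per step while push_ops advances the causal state. A hedged sketch of such a driver (the function and its sess, wav_ph, and z arguments are assumptions for illustration, not part of this configuration):

import numpy as np

def sample_autoregressively(sess, wav_ph, outputs, z, num_samples, batch_size=1):
    # Hypothetical sampling driver for the graph above.
    sess.run(outputs['init_ops'])  # prime the per-layer causal queues
    audio = np.zeros([batch_size, 1], dtype=np.float32)
    generated = []
    for _ in range(num_samples):
        probs, _ = sess.run(
            [outputs['predictions'], outputs['push_ops']],
            feed_dict={wav_ph: audio, outputs['encoding']: z})
        # probs has shape [batch_size, 256]; sample one quantization level each.
        levels = np.array([np.random.choice(256, p=p) for p in probs]) - 128
        audio = (levels[:, None] / 128.0).astype(np.float32)
        generated.append(levels)
    return np.stack(generated, axis=1)  # quantized samples, [batch, num_samples]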
Example #7
    def build(self, inputs, is_training):
        """Build the graph for this configuration.

    Args:
      inputs: A dict of inputs. For training, should contain 'wav'.
      is_training: Whether we are training or not. Not used in this config.

    Returns:
      A dict of outputs that includes the 'predictions', 'loss', the 'encoding',
      the 'quantized_input', and whatever metrics we want to track for eval.
    """
        del is_training
        num_stages = 10
        num_layers = 30
        filter_length = 3
        width = 512
        skip_width = 256
        ae_num_stages = 10
        ae_num_layers = 30
        ae_filter_length = 3
        ae_width = 128

        # Encode the source with 8-bit Mu-Law.
        x = inputs['wav']
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
        x_scaled = tf.expand_dims(x_scaled, 2)

        ###
        # The Non-Causal Temporal Encoder.
        ###
        en = masked.conv1d(x_scaled,
                           causal=False,
                           num_filters=ae_width,
                           filter_length=ae_filter_length,
                           name='ae_startconv')

        for num_layer in range(ae_num_layers):
            dilation = 2**(num_layer % ae_num_stages)
            d = tf.nn.relu(en)
            d = masked.conv1d(d,
                              causal=False,
                              num_filters=ae_width,
                              filter_length=ae_filter_length,
                              dilation=dilation,
                              name='ae_dilatedconv_%d' % (num_layer + 1))
            d = tf.nn.relu(d)
            en += masked.conv1d(d,
                                num_filters=ae_width,
                                filter_length=1,
                                name='ae_res_%d' % (num_layer + 1))

        en = masked.conv1d(en,
                           num_filters=self.ae_bottleneck_width,
                           filter_length=1,
                           name='ae_bottleneck')
        en = masked.pool1d(en, self.ae_hop_length, name='ae_pool', mode='avg')
        encoding = en

        ###
        # The WaveNet Decoder.
        ###
        l = masked.shift_right(x_scaled)
        l = masked.conv1d(l,
                          num_filters=width,
                          filter_length=filter_length,
                          name='startconv')

        # Set up skip connections.
        s = masked.conv1d(l,
                          num_filters=skip_width,
                          filter_length=1,
                          name='skip_start')

        # Residual blocks with skip connections.
        for i in range(num_layers):
            dilation = 2**(i % num_stages)
            d = masked.conv1d(l,
                              num_filters=2 * width,
                              filter_length=filter_length,
                              dilation=dilation,
                              name='dilatedconv_%d' % (i + 1))
            d = self._condition(
                d,
                masked.conv1d(en,
                              num_filters=2 * width,
                              filter_length=1,
                              name='cond_map_%d' % (i + 1)))

            assert d.get_shape().as_list()[2] % 2 == 0
            m = d.get_shape().as_list()[2] // 2
            d_sigmoid = tf.sigmoid(d[:, :, :m])
            d_tanh = tf.tanh(d[:, :, m:])
            d = d_sigmoid * d_tanh

            l += masked.conv1d(d,
                               num_filters=width,
                               filter_length=1,
                               name='res_%d' % (i + 1))
            s += masked.conv1d(d,
                               num_filters=skip_width,
                               filter_length=1,
                               name='skip_%d' % (i + 1))

        s = tf.nn.relu(s)
        s = masked.conv1d(s,
                          num_filters=skip_width,
                          filter_length=1,
                          name='out1')
        s = self._condition(
            s,
            masked.conv1d(en,
                          num_filters=skip_width,
                          filter_length=1,
                          name='cond_map_out1'))
        s = tf.nn.relu(s)

        ###
        # Compute the logits and get the loss.
        ###
        logits = masked.conv1d(s,
                               num_filters=256,
                               filter_length=1,
                               name='logits')
        logits = tf.reshape(logits, [-1, 256])
        probs = tf.nn.softmax(logits, name='softmax')
        x_indices = tf.cast(tf.reshape(x_quantized, [-1]), tf.int32) + 128
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=x_indices, name='nll'),
                              0,
                              name='loss')

        return {
            'predictions': probs,
            'loss': loss,
            'eval': {
                'nll': loss
            },
            'quantized_input': x_quantized,
            'encoding': encoding,
        }
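The self._condition helper is not shown in this example. In the magenta reference implementation it adds the coarse encoding to the fine-grained decoder activations by repeating each encoding frame over the samples it conditions; a sketch along those lines:

import tensorflow as tf

def _condition(x, encoding):
    # Add a coarse [mb, enc_length, channels] encoding to a fine
    # [mb, length, channels] activation by broadcasting each encoding
    # frame across the length // enc_length samples it covers.
    mb, length, channels = x.get_shape().as_list()
    enc_mb, enc_length, enc_channels = encoding.get_shape().as_list()
    assert enc_mb == mb and enc_channels == channels
    encoding = tf.reshape(encoding, [mb, enc_length, 1, channels])
    x = tf.reshape(x, [mb, enc_length, -1, channels])
    x += encoding
    x = tf.reshape(x, [mb, length, channels])
    x.set_shape([mb, length, channels])
    return x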
Example #8
  def build(self, inputs, K, D, beta, global_step):
    """Build the graph for this configuration.

    Args:
      inputs: A dict of inputs. For training, should contain 'wav'.
      is_training: Whether we are training or not. Not used in this config.

    Returns:
      A dict of outputs that includes the 'predictions', 'loss', the 'encoding',
      the 'quantized_input', and whatever metrics we want to track for eval.
    """
    # Decoder
    num_stages = 10
    num_layers = 10
    filter_length = 3
    width = 512
    skip_width = 256
    # Encoder
    ae_num_stages = 10
    ae_num_layers = 10
    ae_filter_length = 3
    ae_width = 128
    self.ae_bottleneck_width = D
    self.beta = beta
    lr = 2e-4

    with tf.variable_scope('forward'):
        # Encode the source with 8-bit Mu-Law.
        # Unlike the configs above, inputs is the wav tensor itself, not a dict.
        x = inputs
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / 128.0

        with tf.variable_scope('embed') :
            embeds = tf.get_variable('embed', [K, D],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))

        ###
        # The Non-Causal Temporal Encoder.
        ###

        with tf.variable_scope('enc') as enc_param_scope:
            en = masked.conv1d(
                x_scaled,
                causal=False,
                num_filters=ae_width,
                filter_length=ae_filter_length,
                name='ae_startconv')

            for num_layer in range(ae_num_layers):
              # Dilation is disabled here; the usual schedule would be
              # 2**(num_layer % ae_num_stages).
              dilation = 1
              d = tf.nn.relu(en)
              d = masked.conv1d(
                  d,
                  causal=False,
                  num_filters=ae_width,
                  filter_length=ae_filter_length,
                  dilation=dilation,
                  name='ae_dilatedconv_%d' % (num_layer + 1))
              d = tf.nn.relu(d)
              en += masked.conv1d(
                  d,
                  num_filters=ae_width,
                  filter_length=1,
                  name='ae_res_%d' % (num_layer + 1))

            en = masked.conv1d(
                en,
                num_filters=self.ae_bottleneck_width,
                filter_length=1,
                name='ae_bottleneck')
            en = masked.pool1d(en, self.ae_hop_length, name='ae_pool', mode='avg')

        self.enc_param_scope = enc_param_scope

        encoding = en

        z_e = encoding

        _t = tf.tile(tf.expand_dims(z_e, -2), [1, 1, K, 1]) #[batch,latent_h,latent_w,K,D]
        _e = tf.reshape(embeds, [1, 1, K, D])
        _t = tf.norm(_t - _e, axis = -1)
        k = tf.argmin(_t, axis = -1) # -> [latent_h,latent_w]
        self.k = k
        z_q = tf.gather(embeds, k)

        self.z_e = z_e
        self.k = k
        self.z_q = z_q

        en = self.z_q

        ###
        # The WaveNet Decoder.
        ###
        with tf.variable_scope('dec') as dec_param_scope:
            l = masked.shift_right(x_scaled)
            l = masked.conv1d(
                l, num_filters=width, filter_length=filter_length, name='startconv')

            # Set up skip connections.
            s = masked.conv1d(
                l, num_filters=skip_width, filter_length=1, name='skip_start')

            # Residual blocks with skip connections.
            for i in range(num_layers):
              dilation = 2**(i % num_stages)
              d = masked.conv1d(
                  l,
                  num_filters=2 * width,
                  filter_length=filter_length,
                  dilation=dilation,
                  name='dilatedconv_%d' % (i + 1))
              d = self._condition(d,
                                  masked.conv1d(
                                      en,
                                      num_filters=2 * width,
                                      filter_length=1,
                                      name='cond_map_%d' % (i + 1)))

              assert d.get_shape().as_list()[2] % 2 == 0
              m = d.get_shape().as_list()[2] // 2
              d_sigmoid = tf.sigmoid(d[:, :, :m])
              d_tanh = tf.tanh(d[:, :, m:])
              d = d_sigmoid * d_tanh

              l += masked.conv1d(
                  d, num_filters=width, filter_length=1, name='res_%d' % (i + 1))
              s += masked.conv1d(
                  d, num_filters=skip_width, filter_length=1, name='skip_%d' % (i + 1))

            s = tf.nn.relu(s)
            s = masked.conv1d(s, num_filters=skip_width, filter_length=1, name='out1')
            s = self._condition(s,
                                masked.conv1d(
                                    en,
                                    num_filters=skip_width,
                                    filter_length=1,
                                    name='cond_map_out1'))
            s = tf.nn.relu(s)

        self.dec_param_scope = dec_param_scope

        ###
        # Compute the logits and get the loss.
        ###
        logits = masked.conv1d(s, num_filters=256, filter_length=1, name='logits')
        logits = tf.reshape(logits, [-1, 256])
        probs = tf.nn.softmax(logits, name='softmax')
        x_indices = tf.cast(tf.reshape(x_quantized, [-1]), tf.int32) + 128

        self.recon = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=x_indices, name='nll'),
            0,
            name='recon_loss')

        self.vq = tf.reduce_mean(
                    tf.norm(tf.stop_gradient(self.z_e) - z_q,axis=-1)**2,
                    axis=[0,1])
        self.commit = tf.reduce_mean(
            tf.norm(self.z_e - tf.stop_gradient(z_q),axis=-1)**2,
            axis=[0,1])
        self.loss = self.recon + self.vq + beta * self.commit
        #self.loss = self.recon

    with tf.variable_scope('backward'):
      # Decoder grads
      decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.dec_param_scope.name)
      decoder_grads = list(zip(
          tf.clip_by_global_norm(tf.gradients(self.loss,decoder_vars), 5.0)[0],
          decoder_vars))
      #decoder_grads = [(tf.clip_by_value(grad, -1.0, 1.0), var) for grad, var in
      #    decoder_grads if grad is not None]

      # Encoder Grads
      encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.enc_param_scope.name)

      grad_z = tf.gradients(self.recon,self.z_q)

      grads = [tf.gradients(self.z_e,var,grad_z)[0] + self.beta*tf.gradients(self.commit,var)[0] for var in encoder_vars]
      grads, _ = tf.clip_by_global_norm(grads, 5.0)
      encoder_grads = list(zip(grads, encoder_vars))

      # Embedding Grads
      embed_grads = list(zip(
          tf.clip_by_global_norm(tf.gradients(self.vq,embeds), 5.0)[0],
          [embeds]))

      #ema = tf.train.ExponentialMovingAverage(decay=0.9999,
      #        num_updates=global_step)

      #optimizer = tf.train.SyncReplicasOptimizer(
      #        tf.train.AdamOptimizer(lr, epsilon=1e-8),
      #        1,
      #        total_num_replicas=1,
      #        variable_averages=ema,
      #        variables_to_average=tf.trainable_variables())
      optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)

      self.train_op= optimizer.apply_gradients(
              decoder_grads + encoder_grads + embed_grads,
              global_step=global_step)

      #self.train_op = optimizer.minimize(self.loss, global_step=global_step,
      #        name="train", colocate_gradients_with_ops=True)

    return {
        'predictions': probs,
        'loss': self.loss,
        'train_op': self.train_op,
        'eval': {
            'nll': self.loss
        },
        'quantized_input': x_quantized,
        'encoding': encoding,
    }
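The hand-wired encoder gradients in the backward scope (tf.gradients(self.loss, ...) routed through grad_z plus the commitment term) implement the straight-through estimator manually. A common, more compact formulation of the same trick, shown only as a sketch:

import tensorflow as tf

def straight_through(z_e, z_q):
    # Forward pass behaves like z_q; the backward pass treats the
    # quantization step as the identity, so gradients reach z_e.
    return z_e + tf.stop_gradient(z_q - z_e)

Feeding straight_through(z_e, z_q) to the decoder would let a single optimizer.minimize(self.loss) call train the encoder and decoder together, leaving only the codebook to the VQ term.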
Example #9
import sys

import numpy as np
import librosa
import chainer

from models import VAE
from utils import mu_law
from utils import Preprocess
import opt

model = VAE(opt.d, opt.k, opt.n_loop, opt.n_layer, opt.n_filter, opt.mu,
            opt.n_channel1, opt.n_channel2, opt.n_channel3, opt.beta, True)
chainer.serializers.load_npz(sys.argv[1], model, 'updater/model:main/')

x = np.expand_dims(
    Preprocess(opt.data_format, opt.sr, opt.mu, opt.top_db, opt.sr * 3,
               False)(sys.argv[2])[0], 0)
output = model.generate(x)
wave = mu_law(opt.mu).itransform(output)
np.save('result.npy', wave)
librosa.output.write_wav('result.wav', wave, opt.sr)
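Unlike the earlier examples, mu_law here is a class whose itransform expands the model output back into a waveform. A hypothetical sketch of its shape (the actual helper lives in this repository's utils module):

import numpy as np

class mu_law:
    # Hypothetical sketch: a mu-law codec parameterized by the level count.
    def __init__(self, mu=256):
        self.mu = mu

    def transform(self, x):
        y = np.sign(x) * np.log1p(self.mu * np.abs(x)) / np.log1p(self.mu)
        return ((y + 1) / 2 * self.mu).astype(np.int32)

    def itransform(self, y):
        y = 2 * y.astype(np.float32) / self.mu - 1
        return np.sign(y) * ((1 + self.mu) ** np.abs(y) - 1) / self.mu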
Example #10
    def __init__(self,
                 lr,
                 global_step,
                 beta,
                 x,
                 K,
                 D,
                 arch_fn,
                 sess,
                 param_scope,
                 is_training=False):
        with tf.variable_scope(param_scope):
            enc_spec, enc_param_scope, dec_spec, dec_param_scope = arch_fn(D)
            with tf.variable_scope('embed'):
                embeds = tf.get_variable(
                    'embed', [K, D],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))

        with tf.variable_scope('forward') as forward_scope:
            # Encoder Pass
            x_quantized = mu_law(x)
            x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
            # Note: no expand_dims here, presumably because the batch size defaults to 1.

            self.input_mean = tf.reduce_mean(x_scaled)

            _t = x_scaled
            for block in enc_spec:
                _t = block(_t)
            z_e = _t

            self.z_e_mean = tf.reduce_mean(z_e)

            # Middle Area (Compression or Discretize)
            # TODO: Gross... use broadcast instead!

            _t = tf.tile(tf.expand_dims(z_e, -2),
                         [1, 1, K, 1])  #[batch,latent_h,latent_w,K,D]
            _e = tf.reshape(embeds, [1, 1, K, D])
            _t = tf.norm(_t - _e, axis=-1)
            k = tf.argmin(_t, axis=-1)  # -> [latent_h,latent_w]
            self.k = k
            z_q = tf.gather(embeds, k)

            self.z_q_mean = tf.reduce_mean(z_q)

            self.z_e = z_e  # -> [batch,latent_h,latent_w,D]
            self.k = k
            self.z_q = z_q  # -> [batch,latent_h,latent_w,D]

            # Decoder Pass
            # Copy just to be safe...?
            _t = tf.identity(z_q)

            # THINGS TO DO
            # 1. check if x is right dimension, no need to expand dim?
            # 2. check if s is same dim as x
            # 3. add conditional on speaker id (can do after reconstruction)

            num_stages = 10  # Has to do with dilation stages
            num_layers = 1  # Could lower the amount of layers
            filter_length = 3
            width = 512
            skip_width = 256

            with tf.variable_scope('dec') as dec_param_scope:
                # May need to have x be an expanded dim
                l = masked.shift_right(x_scaled)

                self.l0 = tf.reduce_mean(tf.identity(l))

                l, W_mean, b_mean = masked.conv1d_log(
                    l,
                    num_filters=width,
                    filter_length=filter_length,
                    name='startconv_dec')
                self.W_mean = W_mean
                self.b_mean = b_mean

                self.l1 = tf.reduce_mean(tf.identity(l))

                # Skip connection
                s = masked.conv1d(l,
                                  num_filters=skip_width,
                                  filter_length=1,
                                  name='skip_start_dec')

                self.s0 = tf.reduce_mean(tf.identity(s))

                # Residual blocks with skip connection
                for i in range(num_layers):
                    dilation = 2**(i % num_stages)
                    d = masked.conv1d(l,
                                      num_filters=2 * width,
                                      filter_length=filter_length,
                                      dilation=dilation,
                                      name='dilatedconv_%d' % (i + 1))

                    self.d0 = tf.reduce_mean(tf.identity(d))

                    # Condition on z_q
                    d = self._condition(
                        d,
                        masked.conv1d(_t,
                                      num_filters=2 * width,
                                      filter_length=1,
                                      name='cond_map_%d' % (i + 1)))
                    self.d1 = tf.reduce_mean(tf.identity(d))

                    assert d.get_shape().as_list()[2] % 2 == 0

                    m = d.get_shape().as_list()[2] // 2
                    d_sigmoid = tf.sigmoid(d[:, :, :m])
                    d_tanh = tf.tanh(d[:, :, m:])
                    d = d_sigmoid * d_tanh

                    self.d2 = tf.reduce_mean(tf.identity(d))

                    l += masked.conv1d(d,
                                       num_filters=width,
                                       filter_length=1,
                                       name='res_%d' % (i + 1))
                    self.l2 = tf.reduce_mean(tf.identity(l))

                    s += masked.conv1d(d,
                                       num_filters=skip_width,
                                       filter_length=1,
                                       name='skip_%d' % (i + 1))
                    self.s1 = tf.reduce_mean(tf.identity(s))

                s = tf.nn.relu(s)
                s = masked.conv1d(s,
                                  num_filters=skip_width,
                                  filter_length=1,
                                  name='out1')
                self.s2 = tf.reduce_mean(tf.identity(s))

                # Condition on z_q again.
                s = self._condition(
                    s,
                    masked.conv1d(_t,
                                  num_filters=skip_width,
                                  filter_length=1,
                                  name='cond_map_out1'))
                s = tf.nn.relu(s)
                self.s3 = tf.reduce_mean(tf.identity(s))

                # Should this parameter be trainable...?
                logits = masked.conv1d(s,
                                       num_filters=256,
                                       filter_length=1,
                                       name='logits')
                self.logits_mean = tf.reduce_mean(tf.identity(logits))

            # Losses
            # CHECK AXES FOR REDUCE MEAN ON RECON LOSS
            logits = tf.reshape(logits, [-1, 256])

            #probs = tf.nn.softmax(logits, name='softmax')
            x_indices = tf.cast(tf.reshape(x_quantized, [-1]), tf.int32) + 128
            self.indices_mean = tf.reduce_mean(x_indices)

            self.recon = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=logits, labels=x_indices, name='nll'),
                0,
                name='recon_loss')

            # Reconstruction loss for images.
            #self.recon = tf.reduce_mean((self.p_x_z - x) ** 2, axis=[0,1,2,3])

            sg_e = tf.stop_gradient(self.z_e)
            sg_norm = tf.norm(sg_e - z_q, axis=-1)**2

            self.vq = tf.reduce_mean(sg_norm, axis=[0, 1])
            self.commit = tf.reduce_mean(tf.norm(self.z_e -
                                                 tf.stop_gradient(z_q),
                                                 axis=-1)**2,
                                         axis=[0, 1])
            self.loss = self.recon + self.vq + beta * self.commit

            # NLL
            # TODO: is it correct impl?
            # it seems tf.reduce_prod(tf.shape(self.z_q)[1:2]) should be multiplied
            # in front of log(1/K) if we assume uniform prior on z.
            #self.nll = -1.*(tf.reduce_mean(tf.log(self.p_x_z),axis=[1,2]) + tf.log(1/tf.cast(K,tf.float32)))/tf.log(2.)

        if (is_training):
            with tf.variable_scope('backward'):
                # Decoder Grads
                decoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, dec_param_scope.name)
                decoder_grads = list(
                    zip(tf.gradients(self.loss, decoder_vars), decoder_vars))

                # Encoder Grads
                encoder_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, enc_param_scope.name)
                grad_z = tf.gradients(self.recon, z_q)
                encoder_grads = [
                    (tf.gradients(z_e, var, grad_z)[0] +
                     beta * tf.gradients(self.commit, var)[0], var)
                    for var in encoder_vars
                ]
                # Embedding Grads
                embed_grads = list(zip(tf.gradients(self.vq, embeds),
                                       [embeds]))

                optimizer = tf.train.AdamOptimizer(lr)
                self.train_op = optimizer.apply_gradients(
                    decoder_grads + encoder_grads + embed_grads,
                    global_step=global_step)
        else:
            # Another decoder pass that we can play with!
            self.latent = tf.placeholder(tf.int64, [None, 3, 3])
            _t = tf.gather(embeds, self.latent)
            for block in dec_spec:
                _t = block(_t)
            self.gen = _t

        all_save_vars = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            param_scope.name) + tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, dec_param_scope.name)

        save_vars = {
            ('train/' + '/'.join(var.name.split('/')[1:])).split(':')[0]: var
            for var in all_save_vars
        }

        self.saver = tf.train.Saver(var_list=save_vars)
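This example, like #7 and #8, depends on masked.shift_right to keep the decoder causal: the input is delayed one sample so the prediction at time t never sees sample t. A sketch of that helper, following the magenta reference implementation:

import tensorflow as tf

def shift_right(x):
    # Pad one step of silence on the left of the time axis and drop the
    # last step, so position t only sees inputs up to t - 1.
    shape = x.get_shape().as_list()
    x_padded = tf.pad(x, [[0, 0], [1, 0], [0, 0]])
    x_sliced = x_padded[:, :-1, :]
    x_sliced.set_shape(shape)
    return x_sliced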