Example #1
        def process_noise_t_f(match_to, scope_name):
            n_fmaps = match_to.shape[3].value
            # Project the noise to fit the conv
            noise_proj = dense(noiseemb, n_fmaps, 0.0, 0.0, True,
                               scope_name + '_noise_emb')  # [mb, n_fmaps]
            noise_proj = tf.expand_dims(noise_proj, 1)
            noise_proj = tf.expand_dims(noise_proj, 1)  # [mb, 1, 1, n_fmaps]

            clean_proj = dense(cleanemb, n_fmaps, 0.0, 0.0, True,
                               scope_name + '_clean_emb')  # [mb, n_fmaps]
            clean_proj = tf.expand_dims(clean_proj, 1)
            clean_proj = tf.expand_dims(clean_proj, 1)  # [mb, 1, 1, n_fmaps]

            # Get the time and frequency embedding
            ts, fs = match_to.shape[1].value, match_to.shape[2].value
            tout = cont_embed(ts, n_fmaps,
                              scope_name + '_temb')  # [ts, n_fmaps]
            tout = tf.expand_dims(tout, 1)
            tout = tf.expand_dims(tout, 0)  # [1, time, 1, n_fmaps]
            fout = cont_embed(fs, n_fmaps,
                              scope_name + '_femb')  # [fs, n_fmaps]
            fout = tf.expand_dims(fout, 0)
            fout = tf.expand_dims(fout, 0)  # [1, 1, freq, n_fmaps]

            return noise_proj, clean_proj, tout, fout
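
This helper closes over noiseemb, cleanemb, dense, and cont_embed from its enclosing scope. A minimal sketch of the intended call site, mirroring how Example #3 consumes the same four tensors; path1 and the scope string here are illustrative, not from the original:

        # Hypothetical call site: broadcasting expands [mb, 1, 1, C],
        # [1, T, 1, C] and [1, 1, F, C] onto a [mb, T, F, C] feature map.
        noise_proj, clean_proj, tout, fout = process_noise_t_f(path1, 'res1_conv1')
        path1 = path1 + noise_proj + clean_proj + tout + fout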
Example #2
 def cont_embed(n, out_dim, scope_name):
   out = tf.constant(list(range(0, n)), dtype=tf.float32) # [n]
   out = tf.reshape(out, [n, 1])  # [n, 1]
   out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense1') # [n, 50]
   out = batch_norm(istrain, out, scope_name + '_dense1')
   out = tf.nn.relu(out)
   out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense2') # [n, 50]
   out = batch_norm(istrain, out, scope_name + '_dense2')
   out = tf.nn.relu(out)
   out = dense(out, out_dim, 0.0, 0.0, False, scope_name + '_dense3') # [n, out_dim]
   return out
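
dense and batch_norm above are project-specific helpers that are not shown. For orientation, a self-contained TF1 sketch of the same idea built on tf.layers instead; this is an approximation under the assumption that istrain is a boolean training flag, not the original helpers:

import tensorflow as tf  # TF 1.x

def cont_embed_standalone(n, out_dim, istrain, scope_name):
  # Embed the integer positions 0..n-1 through a small MLP.
  idx = tf.reshape(tf.range(n, dtype=tf.float32), [n, 1])            # [n, 1]
  out = tf.layers.dense(idx, 50, name=scope_name + '_dense1')        # [n, 50]
  out = tf.layers.batch_normalization(out, training=istrain,
                                      name=scope_name + '_bn1')
  out = tf.nn.relu(out)
  out = tf.layers.dense(out, 50, name=scope_name + '_dense2')        # [n, 50]
  out = tf.layers.batch_normalization(out, training=istrain,
                                      name=scope_name + '_bn2')
  out = tf.nn.relu(out)
  return tf.layers.dense(out, out_dim, name=scope_name + '_dense3')  # [n, out_dim]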
Example #3
def model(inputs, istrain):
  (target, mixed, mixedph, targetph, pos, posph, neg, negph,
   noiseposcontext, noisenegcontext, location, cleanpath,
   noisepospath, noisenegpath, snr_pos, snr_neg) = inputs
  nfeat = target.shape[2].value

  def noise_resnet_block(inputs, kernel_size, stride, n_fmaps, scope_name):
    '''
    Embedding network
    :param inputs : context input
    :return out   : embedding vector represents the context information
    '''
    # The transformation path
    path1 = conv2d(inputs, kernel_size, [1] + stride + [1], n_fmaps,
                   FLAGS.w_std, FLAGS.b_init, False, 'SAME',
                   scope_name + '_conv1')
    path1 = batch_norm(istrain, path1, scope_name + '_conv1')
    path1 = tf.nn.relu(path1)
    path1 = conv2d(path1, kernel_size, [1, 1, 1, 1], n_fmaps,
                   FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                   scope_name + '_conv2')
    
    # The identity path
    n_input_channels = inputs.shape.as_list()[3]
    if n_input_channels == n_fmaps:
      path2 = inputs
    else:
      path2 = conv2d(inputs, [1, 1], [1] + stride + [1], n_fmaps,
                     FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                     scope_name + '_transform')
    
    # Add and return 
    assert path1.shape.as_list() == path2.shape.as_list()
    out = path1 + path2
    out = batch_norm(istrain, out, scope_name + '_addition')
    out = tf.nn.relu(out)
    return out
    
  def resnet_block(inputs, noiseposemb, noisenegemb, kernel_size, stride, n_fmaps, scope_name):
    '''
    Residual block to process noisy signals, with injection of embedding vectors
    :param inputs      : input feature maps
    :param noiseposemb : positive embedding vector
    :param noisenegemb : negative embedding vector
    :param n_fmaps     : number of channels
    :return out        : output feature maps
    '''
    def cont_embed(n, out_dim, scope_name):
      out = tf.constant(list(range(0, n)), dtype=tf.float32) # [n]
      out = tf.reshape(out, [n, 1])  # [n, 1]
      out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense1') # [n, 50]
      out = batch_norm(istrain, out, scope_name + '_dense1')
      out = tf.nn.relu(out)
      out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense2') # [n, 50]
      out = batch_norm(istrain, out, scope_name + '_dense2')
      out = tf.nn.relu(out)
      out = dense(out, out_dim, 0.0, 0.0, False, scope_name + '_dense3') # [n, out_dim]
      return out
    
    def process_noise_t_f(match_to, scope_name):
      n_fmaps = match_to.shape[3].value
      # Project the noise to fit the conv
      noisepos_proj = dense(noiseposemb, n_fmaps, 0.0, 0.0, True,
                            scope_name + '_noise_pos_emb')  # [mb, n_fmaps]
      noisepos_proj = tf.expand_dims(noisepos_proj, 1)
      noisepos_proj = tf.expand_dims(noisepos_proj, 1)  # [mb, 1, 1, n_fmaps]

      noiseneg_proj = dense(noisenegemb, n_fmaps, 0.0, 0.0, True,
                            scope_name + '_noise_neg_emb')  # [mb, n_fmaps]
      noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)
      noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)  # [mb, 1, 1, n_fmaps]
      
      # Get the time and frequency embedding
      ts, fs = match_to.shape[1].value, match_to.shape[2].value
      tout = cont_embed(ts, n_fmaps, scope_name + '_temb')  # [ts, n_fmaps]
      tout = tf.expand_dims(tout, 1)
      tout = tf.expand_dims(tout, 0) # [1, time, 1, n_fmaps]
      fout = cont_embed(fs, n_fmaps, scope_name + '_femb')  # [fs, n_fmaps]
      fout = tf.expand_dims(fout, 0)
      fout = tf.expand_dims(fout, 0) # [1, 1, freq, n_fmaps]
      
      return noisepos_proj, noiseneg_proj, tout, fout
    
    # The transformation path
    path1 = conv2d(inputs, [kernel_size, kernel_size], [1, stride, stride, 1], 
                   n_fmaps, FLAGS.w_std, FLAGS.b_init, False, 
                   'SAME', scope_name + '_conv1')  # [mb, time, freq, n_fmaps]
    noisepos_proj1, noiseneg_proj1, tout1, fout1 = process_noise_t_f(path1, scope_name + '_conv1')
    path1 = path1 + noisepos_proj1 + noiseneg_proj1 + tout1 + fout1
    path1 = batch_norm(istrain, path1, scope_name + '_conv1')
    path1 = tf.nn.relu(path1)
    path1 = conv2d(path1, [kernel_size, kernel_size], [1, 1, 1, 1], n_fmaps,
                   FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                   scope_name + '_conv2')
    noisepos_proj2, noiseneg_proj2, tout2, fout2 = process_noise_t_f(path1, scope_name + '_conv2')
    path1 = path1 + noisepos_proj2 + noiseneg_proj2 + tout2 + fout2

    # The identity path
    n_input_channels = inputs.shape.as_list()[3]
    if n_input_channels == n_fmaps:
      path2 = inputs
    else:
      path2 = conv2d(inputs, [1, 1], [1, stride, stride, 1], n_fmaps,
                     FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                     scope_name + '_transform')
    
    # Add and return 
    assert path1.shape.as_list() == path2.shape.as_list()
    out = path1 + path2
    out = batch_norm(istrain, out, scope_name + '_addition')
    out = tf.nn.relu(out)
    return out


  # The positive noise embedding
  with tf.variable_scope('embedding'):
    nout = noiseposcontext  # [mb, noise frames, 201]
    nout = tf.expand_dims(nout, 3)
    nout = noise_resnet_block(nout, [8, 4], [3, 2], 64, 'noise_resblock1_1')   # [mb, frames / 3, 201 / 2, 64]
    nout = noise_resnet_block(nout, [8, 4], [3, 2], 128, 'noise_resblock2_1')  # [mb, frames / 9, 201 / 4, 128]
    nout = noise_resnet_block(nout, [4, 4], [1, 1], 256, 'noise_resblock3_1')  # [mb, frames / 9, 201 / 4, 256]
    nout = noise_resnet_block(nout, [4, 4], [1, 2], 512, 'noise_resblock4_1')  # [mb, frames / 9, 201 / 8, 512]
    nout = tf.nn.avg_pool(nout, [1, nout.shape[1].value, nout.shape[2].value, 1],
                          [1, 1, 1, 1], 'VALID')  # [mb, 1, 1, 512]
    assert nout.shape.as_list()[1:3] == [1, 1]
    noiseposemb = nout[:, 0, 0, :]    # [mb, 512]

  # The negative noise embedding
  with tf.variable_scope('embedding', reuse=True):
    nout = noisenegcontext  # [mb, noise frames, 201]
    nout = tf.expand_dims(nout, 3)
    nout = noise_resnet_block(nout, [8, 4], [3, 2], 64, 'noise_resblock1_1')   # [mb, frames / 3, 201 / 2, 64]
    nout = noise_resnet_block(nout, [8, 4], [3, 2], 128, 'noise_resblock2_1')  # [mb, frames / 9, 201 / 4, 128]
    nout = noise_resnet_block(nout, [4, 4], [1, 1], 256, 'noise_resblock3_1')  # [mb, frames / 9, 201 / 4, 256]
    nout = noise_resnet_block(nout, [4, 4], [1, 2], 512, 'noise_resblock4_1')  # [mb, frames / 9, 201 / 8, 512]
    nout = tf.nn.avg_pool(nout, [1, nout.shape[1].value, nout.shape[2].value, 1],
                          [1, 1, 1, 1], 'VALID')  # [mb, 1, 1, 512]
    assert nout.shape.as_list()[1:3] == [1, 1]
    noisenegemb = nout[:, 0, 0, :]    # [mb, 512]


  # Processing the mixed signal
  out = mixed # [mb, context frames, 201]
  out = tf.expand_dims(out, 3)
  out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 64, 'resblock1_1')
  out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 64, 'resblock1_2')
  out = resnet_block(out, noiseposemb, noisenegemb, 4, 2, 128, 'resblock2_1')
  out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 128, 'resblock2_2')
  out = resnet_block(out, noiseposemb, noisenegemb, 3, 2, 256, 'resblock3_1')
  out = resnet_block(out, noiseposemb, noisenegemb, 3, 1, 256, 'resblock3_2')
  out = resnet_block(out, noiseposemb, noisenegemb, 3, 2, 512, 'resblock4_1')
  out = resnet_block(out, noiseposemb, noisenegemb, 3, 1, 512, 'resblock4_2') # [mb, context frames / 8, 201 / 8, 512]

  # final layers
  out = conv2d(out, [out.shape[1].value, 1], [1, 1, 1, 1],
              512, FLAGS.w_std, FLAGS.b_init, False,
              'VALID', 'last_conv')                       # [mb, 1, 201 / 8, 512]
  out = batch_norm(istrain, out, 'last_conv')
  out = tf.nn.relu(out)
  out = flatten(out)                                      # [mb,  (201 / 8) * 512]
  out = dense(out, nfeat, 0.0, 0.0, True, 'last_dense')   # [mb, 201]
  mixed_central = mixed[:, FLAGS.window_frames // 2, :]   # [mb, 201]
  pos_central = pos[:, FLAGS.window_frames // 2, :]       # [mb, 201]
  neg_central = neg[:, FLAGS.window_frames // 2, :]       # [mb, 201]
  denoised = mixed_central + out                          # [mb, 201]
  
  # Loss
  se = tf.square(denoised - target[:, 0, :])              # [mb, 201]
  imp_factor = np.linspace(2, 1, nfeat, dtype=np.float32).reshape((1, nfeat))
  example_loss = tf.reduce_mean(se * tf.constant(imp_factor), axis=1)
  loss = tf.reduce_mean(example_loss)
  
  monitors = {'loss': loss}
  outputs = {'loss': example_loss, 'mixed': mixed_central, 'denoised': denoised,
             'target': target[:, 0, :], 'mixedph': mixedph[:, 0, :],
             'targetph': targetph[:, 0, :], 'pos': pos_central, 'neg': neg_central,
             'posph': posph[:, 0, :], 'negph': negph[:, 0, :],
             'location': location, 'cleanpath': cleanpath,
             'noisepospath': noisepospath, 'noisenegpath': noisenegpath,
             'snr_pos': snr_pos, 'snr_neg': snr_neg}
  return loss, monitors, outputs
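
Examples #1 through #3 rely on unshown project helpers. Below is a sketch of definitions consistent with the call sites above; the argument order is read off the calls themselves, while the initializer behavior (glorot when w_std == 0.0) and the '_bn' scope suffix are assumptions, not the original code:

import tensorflow as tf  # TF 1.x

def dense(x, n_out, w_std, b_init, use_bias, scope):
  # Hypothetical reconstruction: dense(x, n_out, w_std, b_init, use_bias, scope).
  with tf.variable_scope(scope):
    init = (tf.glorot_uniform_initializer() if w_std == 0.0
            else tf.truncated_normal_initializer(stddev=w_std))
    w = tf.get_variable('w', [x.shape.as_list()[-1], n_out], initializer=init)
    out = tf.matmul(x, w)
    if use_bias:
      b = tf.get_variable('b', [n_out], initializer=tf.constant_initializer(b_init))
      out = out + b
    return out

def conv2d(x, ksize, strides, n_fmaps, w_std, b_init, use_bias, padding, scope):
  # Hypothetical reconstruction; ksize is [kh, kw], strides is a length-4 list.
  with tf.variable_scope(scope):
    init = (tf.glorot_uniform_initializer() if w_std == 0.0
            else tf.truncated_normal_initializer(stddev=w_std))
    w = tf.get_variable('w', list(ksize) + [x.shape.as_list()[3], n_fmaps],
                        initializer=init)
    out = tf.nn.conv2d(x, w, strides, padding)
    if use_bias:
      b = tf.get_variable('b', [n_fmaps], initializer=tf.constant_initializer(b_init))
      out = tf.nn.bias_add(out, b)
    return out

def batch_norm(istrain, x, scope):
  # '_bn' suffix (assumed) keeps BN variables out of the conv/dense scope.
  return tf.layers.batch_normalization(x, training=istrain, name=scope + '_bn')

def flatten(x):
  return tf.layers.flatten(x)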
Example #4
    def __init__(self, data, train=False, save=False, load=False):
        epochs = 2000
        learning_rate = .001
        weight_decay = .0005
        batch_size = 100
        early_stop = False
        num_train = 55000

        f = [5, 5]
        k = [20, 50, 800, 500]

        X = tf.placeholder(tf.float32, [None, 28, 28, 1])
        y = tf.placeholder(tf.float32, [None, 10])

        lambdas = [10, .5, .1, .1, .1]
        # lambdas = [1 / num_train for l in lambdas]
        # lambdas = [1., 1., 1., 1., 1.]

        # conv1
        conv1 = blocks.L0Conv2d('conv1', [f[0], f[0], 1, k[0]],
                                weight_decay=weight_decay,
                                lambd=lambdas[0])

        # conv2
        conv2 = blocks.L0Conv2d('conv2', [f[1], f[1], k[0], k[1]],
                                weight_decay=weight_decay,
                                lambd=lambdas[1])

        # fc1, after 2 maxpools
        fc1 = blocks.L0Dense('fc1', [7 * 7 * k[1], k[2]],
                             weight_decay=weight_decay,
                             lambd=lambdas[2])

        # fc2
        fc2 = blocks.L0Dense('fc2', [k[2], k[3]],
                             weight_decay=weight_decay,
                             lambd=lambdas[3])

        # output layer
        w_out = blocks.weight('w_out', [k[3], 10])
        b_out = blocks.bias('b_out', [10])

        layers = (conv1, conv2, fc1, fc2)

        global_step = tf.train.get_or_create_global_step()

        # Convolutional layers have feature map sparsity
        # FC layers have neuron sparsity

        # during training, the authors disable the bias as that kills any sparsity
        if train:
            # The goal here for convolutional layers is output feature map sparsity
            w1 = conv1.sample_weights()
            X_ = blocks.conv(X, w1, 1, None)
            X_ = blocks.relu(X_)
            X_ = blocks.pool(X_, 2, 2)

            w2 = conv2.sample_weights()
            X_ = blocks.conv(X_, w2, 1, None)
            X_ = blocks.relu(X_)
            X_ = blocks.pool(X_, 2, 2)

            # for fully connected layers we instead prune inputs in order to reduce
            # MAC operations at train time, thus the paper measures input neurons
            w3 = fc1.sample_weights()
            X_ = blocks.dense(X_, w3, None)

            w4 = fc2.sample_weights()
            X_ = blocks.dense(X_, w4, None)

            # count the number of neurons in the pruned architecture
            neurons = []
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(w1, axis=[0, 1, 2]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(w2, axis=[0, 1, 2]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(w3, axis=[1]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(w4, axis=[1]),
                                 dtype=tf.float32))

        else:
            # at test time use deterministic weights
            X_ = blocks.conv(X, conv1.weights, 1, conv1.bias)
            z1 = conv1.sample_z(tf.shape(X_)[0])
            X_ = X_ * z1
            X_ = blocks.relu(X_)
            X_ = blocks.pool(X_, 2, 2)

            X_ = blocks.conv(X_, conv2.weights, 1, conv2.bias)
            z2 = conv2.sample_z(tf.shape(X_)[0])
            X_ = X_ * z2
            X_ = blocks.relu(X_)
            X_ = blocks.pool(X_, 2, 2)

            # sample gates sized to the full test set (10000 MNIST images)
            z3 = fc1.sample_z(10000)
            X_ = tf.layers.flatten(X_) * z3
            X_ = blocks.dense(X_, fc1.weights, fc1.bias)

            z4 = fc2.sample_z(10000)
            X_ = X_ * z4
            X_ = blocks.dense(X_, fc2.weights, fc2.bias)

            # count the number of neurons in the pruned architecture
            neurons = []
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(z1, axis=[0, 1, 2]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(z2, axis=[0, 1, 2]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(z3, axis=[0]),
                                 dtype=tf.float32))
            neurons.append(
                tf.count_nonzero(tf.reduce_sum(z4, axis=[0]),
                                 dtype=tf.float32))

        logits = blocks.dense(X_, w_out, b_out, activation=False)

        pred = tf.nn.softmax(logits)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
        expected_l0 = [l.count_l0() for l in layers]
        reg = tf.reduce_sum(
            [-(1 / num_train) * l.regularization() for l in layers])
        loss = loss + reg

        correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        optim = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, global_step=global_step)

        constrain = [l.constrain_parameters() for l in layers]

        saver = tf.train.Saver()
        checkpoint = 'checkpoints/model.ckpt'
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            if load:
                try:
                    saver.restore(sess,
                                  'checkpoints/model.ckpt.{}'.format(load))
                except Exception:
                    # fall back to freshly initialized variables if the
                    # checkpoint cannot be restored
                    pass

            if not train:
                a_test, l_test, n_test, g_test = sess.run(
                    [accuracy, loss, neurons, global_step],
                    feed_dict={
                        X: data.test.images,
                        y: data.test.labels
                    })
                print(
                    'Deterministic pruned architecture after {} global steps: {}'
                    .format(g_test, '-'.join([str(int(n)) for n in n_test])))
                print('Test accuracy: {}'.format(a_test))
                print('Test loss: {}'.format(l_test))
                return

            best = 0
            current_epoch = 0
            step_in_epoch = 0
            a_total, l_total = 0, 0
            a_val, l_val = 0, 0
            with tqdm(total=epochs * num_train // batch_size) as t:
                t.update(0)
                while True:
                    # print(len([n.name for n in tf.get_default_graph().as_graph_def().node]))
                    data_train, labels_train = data.train.next_batch(
                        batch_size)
                    a, l, o, s, n, expect, _ = sess.run([
                        accuracy, loss, optim, global_step, neurons,
                        expected_l0, constrain
                    ],
                                                        feed_dict={
                                                            X: data_train,
                                                            y: labels_train
                                                        })
                    epochs_completed = data.train.epochs_completed
                    total_epochs = s * batch_size // num_train
                    # grab the next batch of data
                    t.update(s - t.n)
                    a_total += a
                    l_total += l
                    step_in_epoch += 1

                    t.set_postfix(epoch=total_epochs,
                                  neurons=n,
                                  t_acc=a_total / step_in_epoch,
                                  t_loss=l_total / step_in_epoch,
                                  v_acc=a_val)

                    # check validation loss every complete epoch
                    if epochs_completed > current_epoch:
                        a_val, l_val = sess.run([accuracy, loss],
                                                feed_dict={
                                                    X: data.validation.images,
                                                    y: data.validation.labels
                                                })
                        if save:
                            # track the best validation accuracy seen so far
                            if a_val >= best:
                                saver.save(sess, 'checkpoints/model.ckpt.best')
                                best = a_val
                            if epochs_completed % 10 == 0:
                                saver.save(
                                    sess, 'checkpoints/model.ckpt.{}'.format(
                                        epochs_completed))
                            saver.save(sess, checkpoint)
                        #  saver.save(sess, 'model.{}.ckpt'.format(current_epoch))
                        t.set_postfix(epoch=total_epochs,
                                      neurons=n,
                                      t_acc=a_total / step_in_epoch,
                                      t_loss=l_total / step_in_epoch,
                                      v_acc=a_val)

                        a_total = 0
                        l_total = 0
                        step_in_epoch = 0
                        current_epoch = epochs_completed

                        if total_epochs >= epochs:
                            break
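
The blocks.L0Conv2d / blocks.L0Dense layers are not shown; this example implements L0 regularization via hard-concrete gates (Louizos et al., "Learning Sparse Neural Networks through L0 Regularization"). A sketch of the gate math that sample_weights / sample_z / count_l0 presumably wrap, using the paper's standard constants; these are assumed internals, not the repository's actual code:

import tensorflow as tf  # TF 1.x

# gamma < 0 and zeta > 1 stretch the concrete distribution so the
# clipped gate puts probability mass exactly at 0 and 1.
gamma, zeta, beta = -0.1, 1.1, 2.0 / 3.0

def sample_gate(log_alpha, shape):
    # Training: stochastic gate via the reparameterization trick.
    u = tf.random_uniform(shape, minval=1e-6, maxval=1.0 - 1e-6)
    s = tf.sigmoid((tf.log(u) - tf.log(1.0 - u) + log_alpha) / beta)
    return tf.clip_by_value(s * (zeta - gamma) + gamma, 0.0, 1.0)

def gate_nonzero_prob(log_alpha):
    # P(gate != 0); summing this over a layer gives its expected L0 norm,
    # i.e. what count_l0() reports above.
    return tf.sigmoid(log_alpha - beta * tf.log(-gamma / zeta))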