def process_noise_t_f(match_to, scope_name):
    '''
    Build the broadcastable conditioning terms for a 4-D feature map.

    Projects ``noiseemb`` and ``cleanemb`` to the channel count of
    ``match_to`` and builds learned embeddings for its time and frequency
    positions. Each result is expanded to rank 4 so it can be added to
    ``match_to`` by broadcasting.

    NOTE(review): ``noiseemb`` and ``cleanemb`` are free names here (not
    parameters); they must be supplied by an enclosing/module scope -- confirm.

    :param match_to : feature map; static axes 1, 2, 3 are read as
        time, frequency and channels
    :param scope_name : prefix for the variable scopes of the sub-layers
    :return : (noise [mb, 1, 1, C], clean [mb, 1, 1, C],
        time [1, T, 1, C], freq [1, 1, F, C])
    '''
    channels = match_to.shape[3].value

    # Project each embedding onto the conv channel axis: [mb, C] -> [mb, 1, 1, C].
    noise_term = dense(noiseemb, channels, 0.0, 0.0, True, scope_name + '_noise_emb')
    noise_term = tf.expand_dims(tf.expand_dims(noise_term, 1), 1)
    clean_term = dense(cleanemb, channels, 0.0, 0.0, True, scope_name + '_clean_emb')
    clean_term = tf.expand_dims(tf.expand_dims(clean_term, 1), 1)

    # Learned position embeddings along the time (axis 1) and frequency (axis 2) axes.
    n_time = match_to.shape[1].value
    n_freq = match_to.shape[2].value
    time_term = cont_embed(n_time, channels, scope_name + '_temb')  # [T, C]
    time_term = tf.expand_dims(tf.expand_dims(time_term, 1), 0)  # [1, T, 1, C]
    freq_term = cont_embed(n_freq, channels, scope_name + '_femb')  # [F, C]
    freq_term = tf.expand_dims(tf.expand_dims(freq_term, 0), 0)  # [1, 1, F, C]

    return noise_term, clean_term, time_term, freq_term
def cont_embed(n, out_dim, scope_name):
    '''
    Learned embedding of the integer positions 0 .. n-1.

    Each position index, as a float32 scalar, goes through a small MLP:
    two 50-unit dense layers (each followed by batch_norm and ReLU), then a
    linear dense layer to ``out_dim`` units.

    NOTE(review): ``istrain`` is a free name here (not a parameter); it must
    be supplied by an enclosing/module scope -- confirm.
    NOTE(review): the batch_norm scopes repeat ``scope_name`` twice
    (``scope_name + scope_name + '_dense1'``) while the dense scopes do not.
    Left unchanged because renaming would presumably rename the batch-norm
    variables and break existing checkpoints.

    :param n : number of positions to embed
    :param out_dim : width of each embedding
    :param scope_name : prefix for the sub-layer scopes
    :return : tensor of shape [n, out_dim]
    '''
    positions = tf.constant(list(range(0, n)), dtype=tf.float32)  # [n]
    hidden = tf.reshape(positions, [n, 1])  # [n, 1]
    for layer_suffix in ('_dense1', '_dense2'):
        hidden = dense(hidden, 50, FLAGS.w_std, 0.0, False, scope_name + layer_suffix)  # [n, 50]
        hidden = batch_norm(istrain, hidden, scope_name + scope_name + layer_suffix)
        hidden = tf.nn.relu(hidden)
    return dense(hidden, out_dim, 0.0, 0.0, False, scope_name + '_dense3')  # [n, out_dim]
def model(inputs, istrain):
    '''
    Build the embedding-conditioned denoising network, its loss and outputs.

    The positive and negative noise-context inputs are each encoded into an
    embedding vector by one residual encoder. The second pass reuses the
    first pass's variables (variable scope 'embedding', reuse=True). The mixed
    input then goes through residual blocks. After each of their convs, the
    feature map is shifted by projections of both embeddings plus learned
    time/frequency embeddings. The network output is added to the central
    frame of ``mixed`` to form ``denoised``.

    :param inputs : 16-tuple of tensors, unpacked in the order shown below
    :param istrain : flag forwarded to every batch_norm call
        (presumably train vs. inference mode)
    :return loss : scalar mean of the per-example weighted squared error
    :return monitors : {'loss': loss}
    :return outputs : dict of per-example tensors; several inputs are passed
        through unchanged
    '''
    target, mixed, mixedph, targetph, pos, posph, neg, negph, noiseposcontext, noisenegcontext, location, cleanpath, noisepospath, noisenegpath, snr_pos, snr_neg = inputs
    nfeat = target.shape[2].value

    def noise_resnet_block(inputs, kernel_size, stride, n_fmaps, scope_name):
        '''
        Residual block of the noise-context encoder (no conditioning).

        :param inputs : 4-D feature map
        :param kernel_size : [kh, kw] conv kernel
        :param stride : [sh, sw] stride of the first conv / shortcut
        :param n_fmaps : number of output channels
        :return out : relu(batch_norm(transform path + shortcut path))
        '''
        # The transformation path
        path1 = conv2d(inputs, kernel_size, [1] + stride + [1], n_fmaps, FLAGS.w_std, FLAGS.b_init, False, 'SAME', scope_name + '_conv1')
        path1 = batch_norm(istrain, path1, scope_name + '_conv1')
        path1 = tf.nn.relu(path1)
        path1 = conv2d(path1, kernel_size, [1, 1, 1, 1], n_fmaps, FLAGS.w_std, FLAGS.b_init, True, 'SAME', scope_name + '_conv2')
        # The identity path. A plain shortcut only fits when neither the
        # channel count nor the spatial size changes; otherwise use a strided
        # 1x1 projection so the shapes match.
        n_input_channels = inputs.shape.as_list()[3]
        if n_input_channels == n_fmaps and all(s == 1 for s in stride):
            path2 = inputs
        else:
            path2 = conv2d(inputs, [1, 1], [1] + stride + [1], n_fmaps, FLAGS.w_std, FLAGS.b_init, True, 'SAME', scope_name + '_transform')
        # Add and return
        assert path1.shape.as_list() == path2.shape.as_list()
        out = path1 + path2
        out = batch_norm(istrain, out, scope_name + '_addition')
        out = tf.nn.relu(out)
        return out

    def resnet_block(inputs, noiseposemb, noisenegemb, kernel_size, stride, n_fmaps, scope_name):
        '''
        Residual block to process noisy signals, with injection of embedding vectors.

        After each of the two convs, the feature map is shifted by dense
        projections of both noise embeddings and by learned time/frequency
        position embeddings, all broadcast over the map.

        :param inputs : input feature maps
        :param noiseposemb : positive embedding vector
        :param noisenegemb : negative embedding vector
        :param kernel_size : square kernel size (int)
        :param stride : stride of the first conv / shortcut (int)
        :param n_fmaps : number of channels
        :return out : output feature maps
        '''
        def cont_embed(n, out_dim, scope_name):
            # Learned embedding of positions 0..n-1 via a 3-layer MLP -> [n, out_dim].
            # NOTE(review): batch_norm scopes repeat scope_name twice; kept so
            # variable names (and checkpoints) stay unchanged.
            out = tf.constant(list(range(0, n)), dtype=tf.float32)  # [n]
            out = tf.reshape(out, [n, 1])  # [n, 1]
            out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense1')  # [n, 50]
            out = batch_norm(istrain, out, scope_name + scope_name + '_dense1')
            out = tf.nn.relu(out)
            out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense2')  # [n, 50]
            out = batch_norm(istrain, out, scope_name + scope_name + '_dense2')
            out = tf.nn.relu(out)
            out = dense(out, out_dim, 0.0, 0.0, False, scope_name + '_dense3')  # [n, out_dim]
            return out

        def process_noise_t_f(match_to, scope_name):
            # Build the four additive conditioning terms for `match_to`.
            n_fmaps = match_to.shape[3].value
            # Project the noise to fit the conv
            noisepos_proj = dense(noiseposemb, n_fmaps, 0.0, 0.0, True, scope_name + '_noise_pos_emb')  # [mb, n_fmaps]
            noisepos_proj = tf.expand_dims(noisepos_proj, 1)
            noisepos_proj = tf.expand_dims(noisepos_proj, 1)  # [mb, 1, 1, n_fmaps]
            noiseneg_proj = dense(noisenegemb, n_fmaps, 0.0, 0.0, True, scope_name + '_noise_neg_emb')  # [mb, n_fmaps]
            noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)
            noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)  # [mb, 1, 1, n_fmaps]
            # Get the time and frequency embedding
            ts, fs = match_to.shape[1].value, match_to.shape[2].value
            tout = cont_embed(ts, n_fmaps, scope_name + '_temb')  # [ts, n_fmaps]
            tout = tf.expand_dims(tout, 1)
            tout = tf.expand_dims(tout, 0)  # [1, time, 1, n_fmaps]
            fout = cont_embed(fs, n_fmaps, scope_name + '_femb')  # [fs, n_fmaps]
            fout = tf.expand_dims(fout, 0)
            fout = tf.expand_dims(fout, 0)  # [1, 1, freq, n_fmaps]
            return noisepos_proj, noiseneg_proj, tout, fout

        # The transformation path
        path1 = conv2d(inputs, [kernel_size, kernel_size], [1, stride, stride, 1], n_fmaps, FLAGS.w_std, FLAGS.b_init, False, 'SAME', scope_name + '_conv1')  # [mb, time, freq, n_fmaps]
        noisepos_proj1, noiseneg_proj1, tout1, fout1 = process_noise_t_f(path1, scope_name + '_conv1')
        path1 = path1 + noisepos_proj1 + noiseneg_proj1 + tout1 + fout1
        path1 = batch_norm(istrain, path1, scope_name + '_conv1')
        path1 = tf.nn.relu(path1)
        path1 = conv2d(path1, [kernel_size, kernel_size], [1, 1, 1, 1], n_fmaps, FLAGS.w_std, FLAGS.b_init, True, 'SAME', scope_name + '_conv2')
        noisepos_proj2, noiseneg_proj2, tout2, fout2 = process_noise_t_f(path1, scope_name + '_conv2')
        path1 = path1 + noisepos_proj2 + noiseneg_proj2 + tout2 + fout2
        # The identity path (plain shortcut only when channels and spatial
        # size are both unchanged; otherwise a strided 1x1 projection).
        n_input_channels = inputs.shape.as_list()[3]
        if n_input_channels == n_fmaps and stride == 1:
            path2 = inputs
        else:
            path2 = conv2d(inputs, [1, 1], [1, stride, stride, 1], n_fmaps, FLAGS.w_std, FLAGS.b_init, True, 'SAME', scope_name + '_transform')
        # Add and return
        assert path1.shape.as_list() == path2.shape.as_list()
        out = path1 + path2
        out = batch_norm(istrain, out, scope_name + '_addition')
        out = tf.nn.relu(out)
        return out

    def noise_embedding(context):
        '''
        Encode a noise-context input into one embedding vector per example.

        Called below under tf.variable_scope('embedding'); the second call
        uses reuse=True, so both contexts share the same encoder weights.

        :param context : noise context, assumed [mb, noise frames, freq] -- TODO confirm
        :return : [mb, 512] embedding (average over the remaining time x freq extent)
        '''
        nout = tf.expand_dims(context, 3)  # [mb, T, F, 1]
        # Shape comments assume SAME padding gives ceil(size / stride) per axis
        # (true if conv2d wraps tf.nn.conv2d -- confirm).
        nout = noise_resnet_block(nout, [8, 4], [3, 2], 64, 'noise_resblock1_1')  # [mb, T/3, F/2, 64]
        nout = noise_resnet_block(nout, [8, 4], [3, 2], 128, 'noise_resblock2_1')  # [mb, T/9, F/4, 128]
        nout = noise_resnet_block(nout, [4, 4], [1, 1], 256, 'noise_resblock3_1')  # [mb, T/9, F/4, 256]
        nout = noise_resnet_block(nout, [4, 4], [1, 2], 512, 'noise_resblock4_1')  # [mb, T/9, F/8, 512]
        nout = tf.nn.avg_pool(nout, [1, nout.shape[1].value, nout.shape[2].value, 1], [1, 1, 1, 1], 'VALID')  # [mb, 1, 1, 512]
        assert nout.shape.as_list()[1:3] == [1, 1]
        return nout[:, 0, 0, :]  # [mb, 512]

    # The positive noise embedding
    with tf.variable_scope('embedding'):
        noiseposemb = noise_embedding(noiseposcontext)
    # The negative noise embedding (same encoder variables, reuse=True)
    with tf.variable_scope('embedding', reuse=True):
        noisenegemb = noise_embedding(noisenegcontext)

    # Processing the mixed signal
    out = mixed  # [mb, context frames, freq]
    out = tf.expand_dims(out, 3)
    out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 64, 'resblock1_1')
    out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 64, 'resblock1_2')
    out = resnet_block(out, noiseposemb, noisenegemb, 4, 2, 128, 'resblock2_1')
    out = resnet_block(out, noiseposemb, noisenegemb, 4, 1, 128, 'resblock2_2')
    out = resnet_block(out, noiseposemb, noisenegemb, 3, 2, 256, 'resblock3_1')
    out = resnet_block(out, noiseposemb, noisenegemb, 3, 1, 256, 'resblock3_2')
    out = resnet_block(out, noiseposemb, noisenegemb, 3, 2, 512, 'resblock4_1')
    out = resnet_block(out, noiseposemb, noisenegemb, 3, 1, 512, 'resblock4_2')  # [mb, context frames / 8, freq / 8, 512]

    # final layers: collapse the whole time axis with one VALID conv
    out = conv2d(out, [out.shape[1].value, 1], [1, 1, 1, 1], 512, FLAGS.w_std, FLAGS.b_init, False, 'VALID', 'last_conv')  # [mb, 1, freq / 8, 512]
    out = batch_norm(istrain, out, 'last_conv')
    out = tf.nn.relu(out)
    out = flatten(out)  # [mb, (freq / 8) * 512]
    out = dense(out, nfeat, 0.0, 0.0, True, 'last_dense')  # [mb, nfeat]

    # Central frame of each context window (index FLAGS.window_frames // 2).
    mixed_central = mixed[:, FLAGS.window_frames // 2, :]  # [mb, nfeat]
    pos_central = pos[:, FLAGS.window_frames // 2, :]  # [mb, nfeat]
    neg_central = neg[:, FLAGS.window_frames // 2, :]  # [mb, nfeat]
    # The network predicts a correction added to the mixed central frame.
    denoised = mixed_central + out  # [mb, nfeat]

    # Loss: squared error weighted linearly from 2 (first bin) down to 1 (last bin).
    se = tf.square(denoised - target[:, 0, :])  # [mb, nfeat]
    imp_factor = np.linspace(2, 1, nfeat, dtype=np.float32).reshape((1, nfeat))
    example_loss = tf.reduce_mean(se * tf.constant(imp_factor), axis=1)
    loss = tf.reduce_mean(example_loss)

    monitors = {'loss': loss}
    outputs = {'loss': example_loss, 'mixed': mixed_central, 'denoised': denoised, 'target': target[:, 0, :], 'mixedph': mixedph[:, 0, :], 'targetph': targetph[:, 0, :], 'pos': pos_central, 'neg': neg_central, 'posph': posph[:, 0, :], 'negph': negph[:, 0, :], 'location': location, 'cleanpath': cleanpath, 'noisepospath': noisepospath, 'noisenegpath': noisenegpath, 'snr_pos': snr_pos, 'snr_neg': snr_neg}
    return loss, monitors, outputs
def __init__(self, data, train=False, save=False, load=False):
    """Build an L0-regularized conv net, then train it or evaluate it.

    The network has two L0Conv2d layers, two L0Dense layers and a dense
    output layer, for 28x28x1 inputs and 10 one-hot classes. With
    ``train=True`` the layers use sampled weights (bias disabled) and the
    model trains with Adam. Validation accuracy is measured once per
    completed epoch. Otherwise the test split is evaluated with the layers'
    deterministic weights times their ``sample_z`` masks, and the result
    is printed.

    :param data: dataset exposing ``train.next_batch``,
        ``train.epochs_completed`` and ``validation``/``test`` with
        ``images``/``labels`` (MNIST-style API, assumed from usage).
    :param train: train when True; evaluate on the test split when False.
    :param save: when training, write checkpoints under ``checkpoints/``.
    :param load: if truthy, suffix of the checkpoint to restore
        (``checkpoints/model.ckpt.<load>``) before running.
    """
    epochs = 2000
    learning_rate = .001
    weight_decay = .0005
    batch_size = 100
    num_train = 55000
    f = [5, 5]
    k = [20, 50, 800, 500]
    X = tf.placeholder(tf.float32, [None, 28, 28, 1])
    y = tf.placeholder(tf.float32, [None, 10])
    lambdas = [10, .5, .1, .1, .1]
    # lambdas = [1 / num_train for l in lambdas]
    # lambdas = [1., 1., 1., 1., 1.]
    # conv1
    conv1 = blocks.L0Conv2d('conv1', [f[0], f[0], 1, k[0]], weight_decay=weight_decay, lambd=lambdas[0])
    # conv2
    conv2 = blocks.L0Conv2d('conv2', [f[1], f[1], k[0], k[1]], weight_decay=weight_decay, lambd=lambdas[1])
    # fc1, after 2 maxpools
    fc1 = blocks.L0Dense('fc1', [7 * 7 * k[1], k[2]], weight_decay=weight_decay, lambd=lambdas[2])
    # fc2
    fc2 = blocks.L0Dense('fc2', [k[2], k[3]], weight_decay=weight_decay, lambd=lambdas[3])
    # output layer
    w_out = blocks.weight('w_out', [k[3], 10])
    b_out = blocks.bias('b_out', [10])
    layers = (conv1, conv2, fc1, fc2)
    global_step = tf.train.get_or_create_global_step()

    # Convolutional layers have feature map sparsity
    # FC layers have neuron sparsity
    # during training, the authors disable the bias as that kills any sparsity
    if train:
        # The goal here for convolutional layers is output feature map sparsity
        w1 = conv1.sample_weights()
        X_ = blocks.conv(X, w1, 1, None)
        X_ = blocks.relu(X_)
        X_ = blocks.pool(X_, 2, 2)
        w2 = conv2.sample_weights()
        X_ = blocks.conv(X_, w2, 1, None)
        X_ = blocks.relu(X_)
        X_ = blocks.pool(X_, 2, 2)
        # for fully connected layers we instead prune inputs in order to reduce
        # MAC operations at train time, thus the paper measures input neurons
        # NOTE(review): unlike the test branch, X_ is not explicitly flattened
        # here; presumably blocks.dense handles that -- confirm.
        w3 = fc1.sample_weights()
        X_ = blocks.dense(X_, w3, None)
        w4 = fc2.sample_weights()
        X_ = blocks.dense(X_, w4, None)
        # count the number of neurons in the pruned architecture
        neurons = []
        neurons.append(tf.count_nonzero(tf.reduce_sum(w1, axis=[0, 1, 2]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(w2, axis=[0, 1, 2]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(w3, axis=[1]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(w4, axis=[1]), dtype=tf.float32))
    else:
        # at test time use deterministic weights
        X_ = blocks.conv(X, conv1.weights, 1, conv1.bias)
        z1 = conv1.sample_z(tf.shape(X_)[0])
        X_ = X_ * z1
        X_ = blocks.relu(X_)
        X_ = blocks.pool(X_, 2, 2)
        X_ = blocks.conv(X_, conv2.weights, 1, conv2.bias)
        z2 = conv2.sample_z(tf.shape(X_)[0])
        X_ = X_ * z2
        X_ = blocks.relu(X_)
        X_ = blocks.pool(X_, 2, 2)
        # NOTE(review): 10000 is hard-coded and presumably equals the size of
        # data.test.images (the only batch fed in this branch) -- confirm.
        z3 = fc1.sample_z(10000)
        X_ = tf.layers.flatten(X_) * z3
        X_ = blocks.dense(X_, fc1.weights, fc1.bias)
        z4 = fc2.sample_z(10000)
        X_ = X_ * z4
        X_ = blocks.dense(X_, fc2.weights, fc2.bias)
        # count the number of neurons in the pruned architecture
        neurons = []
        neurons.append(tf.count_nonzero(tf.reduce_sum(z1, axis=[0, 1, 2]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(z2, axis=[0, 1, 2]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(z3, axis=[0]), dtype=tf.float32))
        neurons.append(tf.count_nonzero(tf.reduce_sum(z4, axis=[0]), dtype=tf.float32))

    logits = blocks.dense(X_, w_out, b_out, activation=False)
    pred = tf.nn.softmax(logits)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
    expected_l0 = [l.count_l0() for l in layers]
    reg = tf.reduce_sum([-(1 / num_train) * l.regularization() for l in layers])
    loss = loss + reg
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    optim = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
    constrain = [l.constrain_parameters() for l in layers]
    saver = tf.train.Saver()
    checkpoint = 'checkpoints/model.ckpt'

    # NOTE(review): `config` is a free name here; it must come from an
    # enclosing/module scope -- confirm.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if load:
            restore_path = 'checkpoints/model.ckpt.{}'.format(load)
            try:
                saver.restore(sess, restore_path)
            except (ValueError, tf.errors.OpError) as err:
                # Best effort: continue from fresh initialization, but say so
                # instead of silently swallowing the failure.
                print('Could not restore {}: {}'.format(restore_path, err))

        if not train:
            a_test, l_test, n_test, g_test = sess.run(
                [accuracy, loss, neurons, global_step],
                feed_dict={X: data.test.images, y: data.test.labels})
            print('Deterministic pruned architecture after {} global steps: {}'
                  .format(g_test, '-'.join([str(int(n)) for n in n_test])))
            print('Test accuracy: {}'.format(a_test))
            print('Test loss: {}'.format(l_test))
            return

        best = 0
        current_epoch = 0
        step_in_epoch = 0
        a_total, l_total = 0, 0
        a_val, l_val = 0, 0
        with tqdm(total=epochs * num_train // batch_size) as t:
            t.update(0)
            while True:
                # grab the next batch of data
                data_train, labels_train = data.train.next_batch(batch_size)
                a, l, _, s, n, _, _ = sess.run(
                    [accuracy, loss, optim, global_step, neurons, expected_l0, constrain],
                    feed_dict={X: data_train, y: labels_train})
                epochs_completed = data.train.epochs_completed
                total_epochs = s * batch_size // num_train
                t.update(s - t.n)
                a_total += a
                l_total += l
                step_in_epoch += 1
                t.set_postfix(epoch=total_epochs, neurons=n,
                              t_acc=a_total / step_in_epoch,
                              t_loss=l_total / step_in_epoch, v_acc=a_val)
                # check validation loss every complete epoch
                if epochs_completed > current_epoch:
                    a_val, l_val = sess.run(
                        [accuracy, loss],
                        feed_dict={X: data.validation.images, y: data.validation.labels})
                    if save:
                        # `best` tracks validation accuracy, so compare the
                        # validation accuracy (not the last training batch's).
                        if a_val >= best:
                            saver.save(sess, 'checkpoints/model.ckpt.best')
                            best = a_val
                        if epochs_completed % 10 == 0:
                            saver.save(sess, 'checkpoints/model.ckpt.{}'.format(epochs_completed))
                        saver.save(sess, checkpoint)
                    t.set_postfix(epoch=total_epochs, neurons=n,
                                  t_acc=a_total / step_in_epoch,
                                  t_loss=l_total / step_in_epoch, v_acc=a_val)
                    a_total = 0
                    l_total = 0
                    step_in_epoch = 0
                    current_epoch = epochs_completed
                if total_epochs >= epochs:
                    break