def noise_resnet_block(inputs, kernel_size, stride, n_fmaps, scope_name):
    """Residual conv block used by the noise-embedding tower.

    Main path: strided conv -> batch_norm -> ReLU -> conv. Shortcut: the
    input itself when it already has the main path's shape, otherwise a
    strided 1x1 conv projection. The two are summed, then batch_norm -> ReLU.

    NOTE(review): ``istrain`` is not a parameter here; it is looked up in the
    enclosing module scope, which is not visible in this view -- confirm it
    exists. ``model`` builds its own copy of this block that closes over its
    ``istrain`` argument instead.

    :param inputs: 4-D feature map; its channel count is read from axis 3.
    :param kernel_size: kernel passed straight to ``conv2d`` (e.g. ``[8, 4]``).
    :param stride: 2-element list of strides for axes 1 and 2; applied by the
        first conv and by the shortcut projection.
    :param n_fmaps: number of output channels.
    :param scope_name: prefix for every scope name created here.
    :return: feature map with ``n_fmaps`` channels.
    """
    strides = [1] + stride + [1]

    # Transformation path.
    path1 = conv2d(inputs, kernel_size, strides, n_fmaps, FLAGS.w_std,
                   FLAGS.b_init, False, 'SAME', scope_name + '_conv1')
    path1 = batch_norm(istrain, path1, scope_name + '_conv1')
    path1 = tf.nn.relu(path1)
    path1 = conv2d(path1, kernel_size, [1, 1, 1, 1], n_fmaps, FLAGS.w_std,
                   FLAGS.b_init, True, 'SAME', scope_name + '_conv2')

    # Shortcut path. Use the raw input only when it already matches path1's
    # (non-batch) shape. Checking the channel count alone would pick the
    # identity for a strided block whose spatial size shrank, and the shape
    # assert below would then fail. In every case where the channel-only test
    # worked, this test makes the same choice.
    if inputs.shape.as_list()[1:] == path1.shape.as_list()[1:]:
        path2 = inputs
    else:
        path2 = conv2d(inputs, [1, 1], strides, n_fmaps, FLAGS.w_std,
                       FLAGS.b_init, True, 'SAME', scope_name + '_transform')

    # Add and return.
    assert path1.shape.as_list() == path2.shape.as_list(), (
        path1.shape, path2.shape)
    out = path1 + path2
    out = batch_norm(istrain, out, scope_name + '_addition')
    out = tf.nn.relu(out)
    return out
def cont_embed(n, out_dim, scope_name):
    """Learned embedding of the integer positions ``0 .. n-1``.

    Each position is fed as a single float feature through two
    dense(50) -> batch_norm -> ReLU layers and a final linear dense layer.

    NOTE(review): reads ``istrain`` from module scope (not visible here --
    confirm it is defined).

    :param n: number of positions to embed.
    :param out_dim: width of each embedding.
    :param scope_name: prefix for the scope names created here.
    :return: tensor of shape ``[n, out_dim]``.
    """
    positions = tf.constant(list(range(n)), dtype=tf.float32)  # [n]
    hidden = tf.reshape(positions, [n, 1])  # [n, 1]

    for layer_suffix in ('_dense1', '_dense2'):
        hidden = dense(hidden, 50, FLAGS.w_std, 0.0, False,
                       scope_name + layer_suffix)  # [n, 50]
        # The doubled scope prefix looks accidental but is kept on purpose:
        # changing it would presumably rename the variables batch_norm
        # creates and break loading of existing checkpoints.
        hidden = batch_norm(istrain, hidden,
                            scope_name + scope_name + layer_suffix)
        hidden = tf.nn.relu(hidden)

    return dense(hidden, out_dim, 0.0, 0.0, False,
                 scope_name + '_dense3')  # [n, out_dim]
def model(inputs, istrain):
    """Build the noise-conditioned denoising network and its loss.

    A residual "embedding" tower, shared between the positive and negative
    noise contexts via ``variable_scope(..., reuse=True)``, turns each context
    into one vector. Those two vectors, plus learned position embeddings along
    axes 1 and 2, are added inside every residual block of the main tower,
    which processes ``mixed``. The main tower predicts a residual that is added
    to the central frame of ``mixed`` to give ``denoised``.

    :param inputs: 16-tuple unpacked in this order: target, mixed, mixedph,
        targetph, pos, posph, neg, negph, noiseposcontext, noisenegcontext,
        location, cleanpath, noisepospath, noisenegpath, snr_pos, snr_neg.
    :param istrain: training flag forwarded to every ``batch_norm`` call.
    :return: ``(loss, monitors, outputs)``. ``loss`` is the scalar mean of
        the per-example weighted squared error. ``monitors`` is
        ``{'loss': loss}``. ``outputs`` is a dict of per-example tensors plus
        the metadata inputs passed through unchanged.
    """
    (target, mixed, mixedph, targetph, pos, posph, neg, negph,
     noiseposcontext, noisenegcontext, location, cleanpath,
     noisepospath, noisenegpath, snr_pos, snr_neg) = inputs
    nfeat = target.shape[2].value

    def residual_merge(block_in, path1, strides, n_fmaps, scope_name):
        """Add a shortcut of ``block_in`` to ``path1``, then batch_norm + ReLU.

        Identity shortcut only when ``block_in`` already has path1's
        non-batch shape. A channel-count-only test would wrongly pick the
        identity for a strided block. Otherwise use a strided 1x1 conv.
        """
        if block_in.shape.as_list()[1:] == path1.shape.as_list()[1:]:
            path2 = block_in
        else:
            path2 = conv2d(block_in, [1, 1], strides, n_fmaps, FLAGS.w_std,
                           FLAGS.b_init, True, 'SAME',
                           scope_name + '_transform')
        assert path1.shape.as_list() == path2.shape.as_list(), (
            path1.shape, path2.shape)
        out = path1 + path2
        out = batch_norm(istrain, out, scope_name + '_addition')
        return tf.nn.relu(out)

    def noise_resnet_block(block_in, kernel_size, stride, n_fmaps, scope_name):
        """Plain residual block of the noise-embedding tower.

        :param kernel_size: kernel passed to ``conv2d`` (e.g. ``[8, 4]``).
        :param stride: 2-element list of strides for axes 1 and 2.
        """
        strides = [1] + stride + [1]
        path1 = conv2d(block_in, kernel_size, strides, n_fmaps, FLAGS.w_std,
                       FLAGS.b_init, False, 'SAME', scope_name + '_conv1')
        path1 = batch_norm(istrain, path1, scope_name + '_conv1')
        path1 = tf.nn.relu(path1)
        path1 = conv2d(path1, kernel_size, [1, 1, 1, 1], n_fmaps,
                       FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                       scope_name + '_conv2')
        return residual_merge(block_in, path1, strides, n_fmaps, scope_name)

    def cont_embed(n, out_dim, scope_name):
        """Learned embedding of positions ``0 .. n-1``; returns ``[n, out_dim]``."""
        out = tf.constant(list(range(n)), dtype=tf.float32)  # [n]
        out = tf.reshape(out, [n, 1])  # [n, 1]
        for suffix in ('_dense1', '_dense2'):
            out = dense(out, 50, FLAGS.w_std, 0.0, False,
                        scope_name + suffix)  # [n, 50]
            # Doubled prefix kept on purpose: changing it would presumably
            # rename batch_norm's variables and break existing checkpoints.
            out = batch_norm(istrain, out, scope_name + scope_name + suffix)
            out = tf.nn.relu(out)
        return dense(out, out_dim, 0.0, 0.0, False,
                     scope_name + '_dense3')  # [n, out_dim]

    def resnet_block(block_in, noiseposemb, noisenegemb, kernel_size, stride,
                     n_fmaps, scope_name):
        """Residual block of the main tower with conditioning injected.

        After each of its two convs, the block adds projections of both noise
        embeddings and learned position embeddings along axes 1 and 2.

        :param kernel_size: int; a square ``[k, k]`` kernel is used.
        :param stride: int; applied to axes 1 and 2 by the first conv.
        """
        def process_noise_t_f(match_to, scope_name):
            """Return the four conditioning tensors, broadcastable to ``match_to``."""
            n_ch = match_to.shape[3].value
            # Project the noise embeddings to the conv's channel count.
            pos_proj = dense(noiseposemb, n_ch, 0.0, 0.0, True,
                             scope_name + '_noise_pos_emb')  # [mb, n_ch]
            pos_proj = tf.expand_dims(tf.expand_dims(pos_proj, 1), 1)  # [mb, 1, 1, n_ch]
            neg_proj = dense(noisenegemb, n_ch, 0.0, 0.0, True,
                             scope_name + '_noise_neg_emb')  # [mb, n_ch]
            neg_proj = tf.expand_dims(tf.expand_dims(neg_proj, 1), 1)  # [mb, 1, 1, n_ch]
            # Position embeddings along axis 1 and axis 2.
            ts, fs = match_to.shape[1].value, match_to.shape[2].value
            t_emb = cont_embed(ts, n_ch, scope_name + '_temb')  # [ts, n_ch]
            t_emb = tf.expand_dims(tf.expand_dims(t_emb, 1), 0)  # [1, ts, 1, n_ch]
            f_emb = cont_embed(fs, n_ch, scope_name + '_femb')  # [fs, n_ch]
            f_emb = tf.expand_dims(tf.expand_dims(f_emb, 0), 0)  # [1, 1, fs, n_ch]
            return pos_proj, neg_proj, t_emb, f_emb

        kernel = [kernel_size, kernel_size]
        strides = [1, stride, stride, 1]

        path1 = conv2d(block_in, kernel, strides, n_fmaps, FLAGS.w_std,
                       FLAGS.b_init, False, 'SAME', scope_name + '_conv1')
        p1, n1, t1, f1 = process_noise_t_f(path1, scope_name + '_conv1')
        path1 = path1 + p1 + n1 + t1 + f1
        path1 = batch_norm(istrain, path1, scope_name + '_conv1')
        path1 = tf.nn.relu(path1)
        path1 = conv2d(path1, kernel, [1, 1, 1, 1], n_fmaps, FLAGS.w_std,
                       FLAGS.b_init, True, 'SAME', scope_name + '_conv2')
        p2, n2, t2, f2 = process_noise_t_f(path1, scope_name + '_conv2')
        path1 = path1 + p2 + n2 + t2 + f2

        return residual_merge(block_in, path1, strides, n_fmaps, scope_name)

    def embed_noise(context):
        """Run the noise tower on ``context`` and average-pool to ``[mb, 512]``.

        Shape comments assume conv2d's 'SAME' padding rounds sizes up like
        tf.nn.conv2d -- T = input frames, F = input features.
        """
        nout = tf.expand_dims(context, 3)  # [mb, T, F, 1]
        nout = noise_resnet_block(nout, [8, 4], [3, 2], 64, 'noise_resblock1_1')   # [mb, T/3, F/2, 64]
        nout = noise_resnet_block(nout, [8, 4], [3, 2], 128, 'noise_resblock2_1')  # [mb, T/9, F/4, 128]
        nout = noise_resnet_block(nout, [4, 4], [1, 1], 256, 'noise_resblock3_1')  # [mb, T/9, F/4, 256]
        nout = noise_resnet_block(nout, [4, 4], [1, 2], 512, 'noise_resblock4_1')  # [mb, T/9, F/8, 512]
        # Average over the whole remaining 2-D map.
        nout = tf.nn.avg_pool(nout, [1, nout.shape[1].value, nout.shape[2].value, 1],
                              [1, 1, 1, 1], 'VALID')  # [mb, 1, 1, 512]
        assert nout.shape.as_list()[1:3] == [1, 1]
        return nout[:, 0, 0, :]  # [mb, 512]

    # Positive and negative contexts share one set of tower weights: the
    # second pass reopens the same 'embedding' scope with reuse=True.
    with tf.variable_scope('embedding'):
        noiseposemb = embed_noise(noiseposcontext)
    with tf.variable_scope('embedding', reuse=True):
        noisenegemb = embed_noise(noisenegcontext)

    # Main tower over the mixed context window: (kernel, stride, channels, scope).
    trunk = (
        (4, 1, 64, 'resblock1_1'),
        (4, 1, 64, 'resblock1_2'),
        (4, 2, 128, 'resblock2_1'),
        (4, 1, 128, 'resblock2_2'),
        (3, 2, 256, 'resblock3_1'),
        (3, 1, 256, 'resblock3_2'),
        (3, 2, 512, 'resblock4_1'),
        (3, 1, 512, 'resblock4_2'),
    )
    out = tf.expand_dims(mixed, 3)  # [mb, context frames, nfeat, 1]
    for kernel_size, stride, n_fmaps, block_scope in trunk:
        out = resnet_block(out, noiseposemb, noisenegemb, kernel_size, stride,
                           n_fmaps, block_scope)
    # out: [mb, context frames / 8, nfeat / 8, 512] (rounded up)

    # Final layers: collapse axis 1 with a full-height VALID conv, then dense.
    out = conv2d(out, [out.shape[1].value, 1], [1, 1, 1, 1], 512, FLAGS.w_std,
                 FLAGS.b_init, False, 'VALID', 'last_conv')  # [mb, 1, nfeat / 8, 512]
    out = batch_norm(istrain, out, 'last_conv')
    out = tf.nn.relu(out)
    out = flatten(out)  # [mb, (nfeat / 8) * 512]
    out = dense(out, nfeat, 0.0, 0.0, True, 'last_dense')  # [mb, nfeat]

    # The network predicts a residual on the central frame of the window.
    center = FLAGS.window_frames // 2
    mixed_central = mixed[:, center, :]  # [mb, nfeat]
    pos_central = pos[:, center, :]  # [mb, nfeat]
    neg_central = neg[:, center, :]  # [mb, nfeat]
    denoised = mixed_central + out  # [mb, nfeat]

    # Loss: squared error against the first target frame, weighted linearly
    # from 2.0 (feature 0) down to 1.0 (last feature).
    se = tf.square(denoised - target[:, 0, :])  # [mb, nfeat]
    imp_factor = np.linspace(2, 1, nfeat, dtype=np.float32).reshape((1, nfeat))
    example_loss = tf.reduce_mean(se * tf.constant(imp_factor), axis=1)
    loss = tf.reduce_mean(example_loss)

    monitors = {'loss': loss}
    outputs = {'loss': example_loss, 'mixed': mixed_central,
               'denoised': denoised, 'target': target[:, 0, :],
               'mixedph': mixedph[:, 0, :], 'targetph': targetph[:, 0, :],
               'pos': pos_central, 'neg': neg_central,
               'posph': posph[:, 0, :], 'negph': negph[:, 0, :],
               'location': location, 'cleanpath': cleanpath,
               'noisepospath': noisepospath, 'noisenegpath': noisenegpath,
               'snr_pos': snr_pos, 'snr_neg': snr_neg}
    return loss, monitors, outputs
def resnet_block(inputs, noiseposemb, noisenegemb, kernel_size, stride, n_fmaps, scope_name):
    """Residual block that injects noise and position embeddings.

    After each of its two convs, the block adds four broadcast terms:
    dense projections of ``noiseposemb`` and ``noisenegemb`` (one per channel),
    and learned position embeddings along axis 1 and axis 2. The shortcut is
    the input itself when it already matches the main path's shape, otherwise
    a strided 1x1 conv. The sum goes through batch_norm -> ReLU.

    NOTE(review): ``istrain`` is read from module scope, which is not visible
    in this view -- confirm it is defined. ``model`` builds its own copy of
    this block that closes over its ``istrain`` argument.

    :param inputs: 4-D feature map; its channel count is read from axis 3.
    :param noiseposemb: positive noise embedding, 2-D ``[mb, dim]``.
    :param noisenegemb: negative noise embedding, 2-D ``[mb, dim]``.
    :param kernel_size: int; a square ``[k, k]`` kernel is used.
    :param stride: int; applied to axes 1 and 2 by the first conv and the
        shortcut projection.
    :param n_fmaps: number of output channels.
    :param scope_name: prefix for every scope name created here.
    :return: output feature map with ``n_fmaps`` channels.
    """
    def cont_embed(n, out_dim, scope_name):
        """Learned embedding of positions ``0 .. n-1``; returns ``[n, out_dim]``."""
        out = tf.constant(list(range(0, n)), dtype=tf.float32)  # [n]
        out = tf.reshape(out, [n, 1])  # [n, 1]
        out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense1')  # [n, 50]
        # Doubled prefix kept on purpose: changing it would presumably rename
        # batch_norm's variables and break existing checkpoints.
        out = batch_norm(istrain, out, scope_name + scope_name + '_dense1')
        out = tf.nn.relu(out)
        out = dense(out, 50, FLAGS.w_std, 0.0, False, scope_name + '_dense2')  # [n, 50]
        out = batch_norm(istrain, out, scope_name + scope_name + '_dense2')
        out = tf.nn.relu(out)
        out = dense(out, out_dim, 0.0, 0.0, False, scope_name + '_dense3')  # [n, out_dim]
        return out

    def process_noise_t_f(match_to, scope_name):
        """Return the four conditioning tensors, broadcastable to ``match_to``."""
        n_fmaps = match_to.shape[3].value
        # Project the noise embeddings to fit the conv.
        noisepos_proj = dense(noiseposemb, n_fmaps, 0.0, 0.0, True,
                              scope_name + '_noise_pos_emb')  # [mb, n_fmaps]
        noisepos_proj = tf.expand_dims(noisepos_proj, 1)
        noisepos_proj = tf.expand_dims(noisepos_proj, 1)  # [mb, 1, 1, n_fmaps]
        noiseneg_proj = dense(noisenegemb, n_fmaps, 0.0, 0.0, True,
                              scope_name + '_noise_neg_emb')  # [mb, n_fmaps]
        noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)
        noiseneg_proj = tf.expand_dims(noiseneg_proj, 1)  # [mb, 1, 1, n_fmaps]
        # Position embeddings along axis 1 and axis 2.
        ts, fs = match_to.shape[1].value, match_to.shape[2].value
        tout = cont_embed(ts, n_fmaps, scope_name + '_temb')  # [ts, n_fmaps]
        tout = tf.expand_dims(tout, 1)
        tout = tf.expand_dims(tout, 0)  # [1, ts, 1, n_fmaps]
        fout = cont_embed(fs, n_fmaps, scope_name + '_femb')  # [fs, n_fmaps]
        fout = tf.expand_dims(fout, 0)
        fout = tf.expand_dims(fout, 0)  # [1, 1, fs, n_fmaps]
        return noisepos_proj, noiseneg_proj, tout, fout

    strides = [1, stride, stride, 1]

    # Transformation path.
    path1 = conv2d(inputs, [kernel_size, kernel_size], strides, n_fmaps,
                   FLAGS.w_std, FLAGS.b_init, False, 'SAME',
                   scope_name + '_conv1')  # [mb, axis1, axis2, n_fmaps]
    noisepos_proj1, noiseneg_proj1, tout1, fout1 = process_noise_t_f(
        path1, scope_name + '_conv1')
    path1 = path1 + noisepos_proj1 + noiseneg_proj1 + tout1 + fout1
    path1 = batch_norm(istrain, path1, scope_name + '_conv1')
    path1 = tf.nn.relu(path1)
    path1 = conv2d(path1, [kernel_size, kernel_size], [1, 1, 1, 1], n_fmaps,
                   FLAGS.w_std, FLAGS.b_init, True, 'SAME',
                   scope_name + '_conv2')
    noisepos_proj2, noiseneg_proj2, tout2, fout2 = process_noise_t_f(
        path1, scope_name + '_conv2')
    path1 = path1 + noisepos_proj2 + noiseneg_proj2 + tout2 + fout2

    # Shortcut path. Use the raw input only when it already matches path1's
    # (non-batch) shape. A channel-count-only test would pick the identity
    # for a stride > 1 block, and the shape assert below would then fail.
    if inputs.shape.as_list()[1:] == path1.shape.as_list()[1:]:
        path2 = inputs
    else:
        path2 = conv2d(inputs, [1, 1], strides, n_fmaps, FLAGS.w_std,
                       FLAGS.b_init, True, 'SAME', scope_name + '_transform')

    # Add and return.
    assert path1.shape.as_list() == path2.shape.as_list(), (
        path1.shape, path2.shape)
    out = path1 + path2
    out = batch_norm(istrain, out, scope_name + '_addition')
    out = tf.nn.relu(out)
    return out