def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` for the waveform U-Net enhancement model.

    Builds the graph for TRAIN and EVAL modes only. The optimised loss is the
    mean absolute error between the U-Net logits and the partitioned target.
    SNR and SDR are computed only for summaries.

    Relies on the names ``partition_size``, ``init_lr`` and ``epochs``. These
    are not defined in this function; presumably they are module-level
    settings elsewhere in the file (TODO confirm).

    Args:
        features: dict holding the ``'combined'`` (model input) and ``'y'``
            (target) tensors. A trailing axis is added to each before
            partitioning. Presumably they are 1-D waveforms (TODO confirm).
        labels: unused; the target is read from ``features['y']``.
        mode: a ``tf.estimator.ModeKeys`` value; only TRAIN and EVAL are
            supported.
        params: unused.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for ``mode``.

    Raises:
        ValueError: if ``mode`` is neither TRAIN nor EVAL.
    """
    # Fail fast with a clear message instead of falling through both
    # branches below and hitting an unbound ``estimator_spec``.
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise ValueError(
            'model_fn only supports TRAIN and EVAL modes, got {!r}'.format(mode)
        )

    x = tf.expand_dims(features['combined'], -1)
    y = tf.expand_dims(features['y'], -1)
    partitioned_x = tf_featurization.pad_and_partition(x, partition_size)
    partitioned_y = tf_featurization.pad_and_partition(y, partition_size)
    model = unet.Model(partitioned_x, channels_interval=36)

    # enhancement.loss.snr returns a pair; only its second element (SNR) is
    # logged here.
    _, snr = enhancement.loss.snr(model.logits, partitioned_y)
    sdr = enhancement.loss.sdr(model.logits, partitioned_y)
    loss = tf.losses.absolute_difference(
        labels=partitioned_y, predictions=model.logits
    )

    # Named identity op, presumably so logging hooks can fetch the loss by
    # name (TODO confirm against the training script).
    tf.identity(loss, 'total_loss')
    tf.summary.scalar('total_loss', loss)
    tf.summary.scalar('snr', snr)
    tf.summary.scalar('sdr', sdr)

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    # Linear decay (power=1.0) from init_lr down to 1e-6 over `epochs`
    # decay steps, then held constant (cycle=False). Note that `epochs` is
    # used as a global-step count here.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        epochs,
        end_learning_rate=1e-6,
        power=1.0,
        cycle=False,
    )
    tf.summary.scalar('learning_rate', learning_rate)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op
        )

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL, loss=loss
    )
def __init__(
    self,
    inputs,
    sources=4,
    audio_channels=1,
    channels=64,
    depth=6,
    rewrite=True,
    use_glu=True,
    rescale=0.1,
    kernel_size=8,
    stride=4,
    growth=2.0,
    lstm_layers=2,
    context=3,
    partition_length=44100 * 2,
    norm_after_partition=False,
    output_shape_same_as_input=False,
    logging=False,
    kernel_initializer=ConvScaling,
):
    """Build a Demucs-style encoder / BLSTM / decoder graph over `inputs`.

    Pipeline:
    - `inputs` is cut into fixed-length partitions with `pad_and_partition`.
    - Each partition is padded to `self.valid_length(...)`.
    - The result goes through `depth` strided Conv1D encoder stages, an
      optional BLSTM, and mirrored Conv1DTranspose decoder stages with
      centre-trimmed skip connections.
    - The output is stored in `self.logits`, reshaped to `(-1, sources)`.

    Args:
        inputs: input waveform tensor; presumably `(samples, audio_channels)`
            (TODO confirm).
        sources: number of separated sources.
        audio_channels: channels per source; the final decoder stage emits
            `sources * audio_channels` channels.
        channels: width of the first encoder stage; multiplied by `growth`
            after every stage.
        depth: number of encoder/decoder stages.
        rewrite: add an extra conv per stage (1x1 in the encoder,
            `context`-wide in the decoder) using `activation`.
        use_glu: use `glu` (with doubled rewrite width) instead of ReLU for
            the rewrite convs.
        rescale: not used by this constructor.
        kernel_size: kernel of the encoder Conv1D / decoder Conv1DTranspose.
        stride: stride of the encoder Conv1D / decoder Conv1DTranspose.
        growth: channel multiplier between consecutive stages.
        lstm_layers: number of BLSTM layers; a falsy value disables the BLSTM.
        context: kernel size of the decoder rewrite conv.
        partition_length: length each partition of `inputs` is cut to. Must
            be truthy, because the graph is only defined on partitions.
        norm_after_partition: standardise partitions across axis 0 before
            encoding. This normalisation is not undone on the output.
        output_shape_same_as_input: trim `self.logits` rows to
            `tf.shape(inputs)[0]`.
        logging: print intermediate tensors. This parameter shadows the
            `logging` module inside this method.
        kernel_initializer: initializer passed to every conv layer.

    Raises:
        ValueError: if `partition_length` is falsy. Previously `padded` was
            left unassigned and the graph failed later.
    """
    if not partition_length:
        raise ValueError(
            'partition_length must be set; the model is only built on '
            'partitioned inputs, got {!r}'.format(partition_length)
        )

    self.audio_channels = audio_channels
    self.sources = sources
    self.kernel_size = kernel_size
    self.context = context
    self.stride = stride
    self.depth = depth
    self.channels = channels
    self.partition_length = partition_length

    if use_glu:
        # Rewrite convs are built twice as wide when glu is used; presumably
        # glu halves the channel axis (TODO confirm).
        activation = glu
        ch_scale = 2
    else:
        activation = tf.nn.relu
        ch_scale = 1

    in_channels = audio_channels
    self.encoder, self.decoder = [], []
    for index in range(depth):
        encoder = tf.keras.Sequential()
        encoder.add(
            tf.keras.layers.Conv1D(
                channels,
                kernel_size,
                stride,
                activation=tf.nn.relu,
                kernel_initializer=kernel_initializer,
            ))
        if rewrite:
            encoder.add(
                tf.keras.layers.Conv1D(
                    ch_scale * channels,
                    1,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                ))
        self.encoder.append(encoder)

        # The outermost decoder (index 0) maps back to the separated
        # sources; inner decoders map back to the previous stage's width.
        if index > 0:
            out_channels = in_channels
        else:
            out_channels = sources * audio_channels

        decoder = tf.keras.Sequential()
        if rewrite:
            decoder.add(
                tf.keras.layers.Conv1D(
                    ch_scale * channels,
                    context,
                    activation=activation,
                    kernel_initializer=kernel_initializer,
                ))
        # No activation on the final (outermost) output layer.
        if index > 0:
            a = tf.nn.relu
        else:
            a = None
        decoder.add(
            Conv1DTranspose(
                out_channels,
                kernel_size,
                stride,
                activation=a,
                kernel_initializer=kernel_initializer,
            ))
        # Decoders are stored innermost-first so they run in reverse order.
        self.decoder.insert(0, decoder)
        in_channels = channels
        channels = int(growth * channels)
    # After the loop `in_channels` is the deepest encoder's width.
    channels = in_channels

    partitioned = pad_and_partition(inputs, self.partition_length)
    if norm_after_partition:
        mean = tf.reduce_mean(partitioned, axis=0)
        std = tf.math.reduce_std(partitioned, axis=0)
        partitioned = (partitioned - mean) / std
    # `self.valid_length` is defined elsewhere. Presumably it returns a length
    # the strided conv stack can process exactly (TODO confirm). The extra
    # samples are split evenly on both sides.
    valid_length = self.valid_length(partitioned.shape.as_list()[1])
    delta = valid_length - self.partition_length
    padded = tf.pad(
        partitioned,
        [[0, 0], [delta // 2, delta - delta // 2], [0, 0]],
        'CONSTANT',
    )

    if lstm_layers:
        self.lstm = BLSTM(channels, lstm_layers)
    else:
        self.lstm = None

    x = padded
    saved = [x]
    for encode in self.encoder:
        if logging:
            print(x)
        x = encode(x)
        saved.append(x)
        if logging:
            print('x', x)
    if self.lstm:
        x = self.lstm(x)
    for decode in self.decoder:
        if logging:
            print(x)
        # Align the stored encoder output to x's length before adding it.
        skip = center_trim(saved.pop(-1), x)
        x = x + skip
        x = decode(x)
        if logging:
            print('x', x)

    self.logits = x
    # NOTE(review): the last decoder emits sources * audio_channels channels;
    # reshaping to `sources` columns is only a clean per-sample layout when
    # audio_channels == 1 (the default).
    self.logits = tf.reshape(self.logits, (-1, self.sources))
    if output_shape_same_as_input:
        self.logits = self.logits[:tf.shape(inputs)[0]]
def __init__(self, X, Y, frame_length=4096, frame_step=1024):
    """Build magnitude and phase U-Nets over the STFT of `X`, trained on `Y`.

    Pipeline:
    - Both signals go through a Hann-windowed STFT.
    - The magnitude and phase of each are split into 512-frame partitions.
    - `X`'s magnitude and phase each feed a separate `unet.Model` (variable
      scopes `model_mag` and `model_angle`).
    - `self.cost` is the sum of the mean L1 errors against `Y`'s partitioned
      magnitude (`self.mag_l1`) and phase (`self.angle_l1`).
    - The predicted magnitude and phase are recombined in polar form and
      inverted to a waveform, exposed as `self.inverse_stft_X`.

    Args:
        X: input signal; presumably a 1-D waveform (TODO confirm).
        Y: target signal with the same layout as `X`.
        frame_length: STFT window length.
        frame_step: STFT hop size.
    """
    partition_size = 512

    def forward_window_fn(frame_length, dtype):
        return hann_window(frame_length, periodic=True, dtype=dtype)

    def get_stft(signal):
        return tf.signal.stft(
            signal,
            frame_length,
            frame_step,
            window_fn=forward_window_fn,
            pad_end=True,
        )

    stft_X = get_stft(X)
    stft_Y = get_stft(Y)
    mag_X = tf.abs(stft_X)
    mag_Y = tf.abs(stft_Y)
    # Phase in radians. This matches the polar reconstruction
    # mag * exp(i * angle) at the end of this method.
    angle_X = tf.math.angle(stft_X)
    angle_Y = tf.math.angle(stft_Y)

    partitioned_mag_X = tf_featurization.pad_and_partition(
        mag_X, partition_size)
    partitioned_angle_X = tf_featurization.pad_and_partition(
        angle_X, partition_size)

    params = {'conv_n_filters': [32 * (2**i) for i in range(6)]}

    with tf.variable_scope('model_mag'):
        # Drop the last frequency bin before the U-Net; a zero bin is padded
        # back on afterwards so the shape matches the target again.
        mix_mag = tf.expand_dims(partitioned_mag_X, 3)[:, :, :-1, :]
        mix_mag_logits = unet.Model(
            mix_mag,
            output_mask_logit=True,
            dropout=0.0,
            training=True,
            params=params,
        ).logits
        mix_mag_logits = tf.squeeze(mix_mag_logits, 3)
        mix_mag_logits = tf.pad(
            mix_mag_logits, [(0, 0), (0, 0), (0, 1)], mode='CONSTANT')
        # Magnitudes are non-negative.
        mix_mag_logits = tf.nn.relu(mix_mag_logits)

    with tf.variable_scope('model_angle'):
        mix_angle = tf.expand_dims(partitioned_angle_X, 3)[:, :, :-1, :]
        mix_angle_logits = unet.Model(
            mix_angle,
            output_mask_logit=True,
            dropout=0.0,
            training=True,
            params=params,
        ).logits
        mix_angle_logits = tf.squeeze(mix_angle_logits, 3)
        mix_angle_logits = tf.pad(
            mix_angle_logits, [(0, 0), (0, 0), (0, 1)], mode='CONSTANT')

    partitioned_mag_Y = tf_featurization.pad_and_partition(
        mag_Y, partition_size)
    partitioned_angle_Y = tf_featurization.pad_and_partition(
        angle_Y, partition_size)
    self.mag_l1 = tf.reduce_mean(tf.abs(partitioned_mag_Y - mix_mag_logits))
    self.angle_l1 = tf.reduce_mean(
        tf.abs(partitioned_angle_Y - mix_angle_logits))
    self.cost = self.mag_l1 + self.angle_l1

    def get_original_shape(D, stft):
        # Merge the partition axis into the frame axis, then trim the
        # padding frames so the frame count matches `stft`.
        instrument_mask = D
        old_shape = tf.shape(instrument_mask)
        new_shape = tf.concat(
            [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
        instrument_mask = tf.reshape(instrument_mask, new_shape)
        instrument_mask = instrument_mask[:tf.shape(stft)[0]]
        return instrument_mask

    _mag = get_original_shape(tf.expand_dims(mix_mag_logits, -1), stft_X)
    _angle = get_original_shape(tf.expand_dims(mix_angle_logits, -1), stft_X)
    stft = tf.multiply(
        tf.complex(_mag, 0.0), tf.exp(tf.complex(0.0, _angle)))
    # Per the TF docs, pair the forward window with inverse_stft_window_fn so
    # overlap-add reconstruction is correctly normalised. Reusing the forward
    # Hann window scales the output by the window's squared overlap sum.
    inverse_stft_X = inverse_stft(
        stft[:, :, 0],
        frame_length,
        frame_step,
        window_fn=tf.signal.inverse_stft_window_fn(
            frame_step, forward_window_fn=forward_window_fn),
    )
    # Keep a handle on the reconstructed waveform for callers.
    self.inverse_stft_X = inverse_stft_X
def __init__(
    self,
    inputs,
    y=None,
    chin=1,
    chout=1,
    hidden=48,
    depth=5,
    use_glu=True,
    kernel_size=8,
    stride=4,
    causal=True,
    resample=4,
    growth=2,
    max_hidden=10000,
    normalize=True,
    rescale=0.1,
    floor=1e-3,
    lstm_layers=2,
    partition_length=44100 * 2,
    norm_after_partition=False,
    logging=False,
    kernel_initializer=ConvScaling,
):
    """Build a Demucs-denoiser-style encoder / BLSTM / decoder graph.

    Pipeline:
    - `inputs` is optionally divided by its (floored) standard deviation.
    - It is cut into `partition_length` partitions and right-padded to
      `self.valid_length(partition_length)`.
    - It is upsampled `resample` times and run through `depth` encoder
      stages, a BLSTM, and mirrored decoder stages with skip connections.
    - It is downsampled back, flattened to `(-1, chout)`, trimmed to the
      length of `y` (or `inputs`), and rescaled by the stored std.
    - The result is stored in `self.logits`.

    Args:
        inputs: input waveform tensor; presumably `(samples, chin)`
            (TODO confirm).
        y: optional target; when given, logits are trimmed to its length.
        chin: number of input channels.
        chout: number of output channels.
        hidden: width of the first encoder stage.
        depth: number of encoder/decoder stages.
        use_glu: use `glu` (with doubled width) instead of ReLU for the 1x1
            convs.
        kernel_size: kernel of the encoder Conv1D / decoder Conv1DTranspose.
        stride: stride of the encoder Conv1D / decoder Conv1DTranspose.
        causal: when True the BLSTM is built with `bi=False`.
        resample: up/down-sampling factor; must be 1, 2 or 4.
        growth: channel multiplier between stages.
        max_hidden: upper bound on the stage width.
        normalize: divide `inputs` by `floor + std` and multiply it back at
            the end.
        rescale: not used by this constructor.
        floor: added to the std before dividing.
        lstm_layers: number of BLSTM layers.
        partition_length: length of each partition of `inputs`.
        norm_after_partition: standardise partitions across axis 0. This is
            not undone on the output.
        logging: print intermediate tensors. This parameter shadows the
            `logging` module inside this method.
        kernel_initializer: initializer passed to every conv layer.

    Raises:
        ValueError: if `resample` is not 1, 2 or 4.
    """
    # Number of upsample2/downsample2 applications for each factor. Any
    # other value was previously ignored silently.
    resample_stages = {1: 0, 2: 1, 4: 2}
    if resample not in resample_stages:
        raise ValueError(
            'resample should be 1, 2 or 4, got {!r}'.format(resample)
        )

    self.depth = depth
    self.kernel_size = kernel_size
    self.stride = stride
    self.causal = causal
    self.floor = floor
    self.resample = resample
    self.normalize = normalize
    self.chin = chin
    self.chout = chout
    self.hidden = hidden
    self.partition_length = partition_length

    if use_glu:
        # 1x1 convs are built twice as wide when glu is used; presumably glu
        # halves the channel axis (TODO confirm).
        activation = glu
        ch_scale = 2
    else:
        activation = tf.nn.relu
        ch_scale = 1

    self.encoder, self.decoder = [], []
    for index in range(depth):
        encoder = tf.keras.Sequential()
        encoder.add(
            tf.keras.layers.Conv1D(
                hidden,
                kernel_size,
                stride,
                activation=tf.nn.relu,
                kernel_initializer=kernel_initializer,
            ))
        encoder.add(
            tf.keras.layers.Conv1D(
                ch_scale * hidden,
                1,
                activation=activation,
                kernel_initializer=kernel_initializer,
            ))
        self.encoder.append(encoder)

        decoder = tf.keras.Sequential()
        decoder.add(
            tf.keras.layers.Conv1D(
                ch_scale * hidden,
                1,
                activation=activation,
                kernel_initializer=kernel_initializer,
            ))
        # No activation on the final (outermost) output layer.
        if index > 0:
            a = tf.nn.relu
        else:
            a = None
        decoder.add(
            Conv1DTranspose(
                chout,
                kernel_size,
                stride,
                activation=a,
                kernel_initializer=kernel_initializer,
            ))
        # Decoders are stored innermost-first so they run in reverse order.
        self.decoder.insert(0, decoder)
        chout = hidden
        chin = hidden
        hidden = min(int(growth * hidden), max_hidden)

    # Pass lstm_layers through; it was previously accepted but ignored.
    # The second positional argument is the layer count, as in the other
    # BLSTM call in this file.
    self.lstm = BLSTM(chin, lstm_layers, bi=not causal)

    if self.normalize:
        mono = tf.reduce_mean(inputs, axis=1, keepdims=True)
        self.std = tf.math.reduce_std(mono, axis=0, keepdims=True)
        inputs = inputs / (self.floor + self.std)
    else:
        self.std = 1.0

    partitioned = pad_and_partition(inputs, self.partition_length)
    if norm_after_partition:
        mean = tf.reduce_mean(partitioned, axis=0)
        std = tf.math.reduce_std(partitioned, axis=0)
        partitioned = (partitioned - mean) / std
    # `self.valid_length` is defined elsewhere; the extra samples go on the
    # right only.
    valid_length = self.valid_length(self.partition_length)
    delta = valid_length - self.partition_length
    padded = tf.pad(partitioned, [[0, 0], [0, delta], [0, 0]], 'CONSTANT')

    x = padded
    if logging:
        print(x)
    for _ in range(resample_stages[self.resample]):
        x = upsample2(x)
    if logging:
        print(x)

    skips = []
    for encode in self.encoder:
        if logging:
            print(x)
        x = encode(x)
        skips.append(x)
        if logging:
            print('x', x)
    x = self.lstm(x)
    for decode in self.decoder:
        if logging:
            print(x)
        skip = skips.pop(-1)
        # Trim the stored encoder output to x's time length before adding.
        x = x + skip[:, :tf.shape(x)[1]]
        x = decode(x)
    for _ in range(resample_stages[self.resample]):
        x = downsample2(x)
    if logging:
        print('x', x)

    self.logits = x
    self.logits = tf.reshape(self.logits, (-1, self.chout))
    reference = y if y is not None else inputs
    self.logits = self.logits[:tf.shape(reference)[0]]
    # Undo the input normalisation (self.std is 1.0 when normalize=False).
    self.logits = self.std * self.logits