def bottleneck_ops(self, x_enc, use_audio=True):
    """Fuse per-modality encoder features into one bottleneck tensor.

    Modalities are visited in the order AUDIO, VIDEO, FLOW. AUDIO is skipped
    when ``use_audio`` is False, and any key missing from ``x_enc`` is
    skipped. For each remaining modality:

    * VIDEO/FLOW are first reduced with a 128-unit ReLU fully-connected layer.
    * The tensor is reshaped to rank 3. AUDIO keeps axis 1 and merges axes
      2 and 3. VIDEO/FLOW collapse everything except the batch axis into a
      single step (axis 1 of size 1).
    * A ReLU fully-connected layer is applied (1024 units for AUDIO, 512
      otherwise).
    * VIDEO/FLOW are tiled along axis 1 by the audio encoder's axis-1 length.

    The per-modality results are concatenated along axis 2.

    Args:
        x_enc: Mapping from modality key to encoder output.
            ``x_enc[AUDIO]`` is indexed with ``[-1]`` (presumably a list of
            per-layer audio encoder outputs; confirm against the caller).
            It must be present whenever ``x_enc`` is non-empty, because its
            last entry sets the tiling length even when ``use_audio`` is
            False. VIDEO/FLOW entries are used directly as tensors.
        use_audio: If False, audio features are not added to the bottleneck.

    Returns:
        The concatenated bottleneck tensor. Returns None if ``x_enc`` is
        empty or no modality contributed any features.
    """
    if not x_enc:
        return None

    def _log(label, t):
        print(' * {:15s} | {:20s} | {:10s}'.format(
            label, str(t.get_shape()), str(t.dtype)))

    # Read unconditionally: the VIDEO/FLOW tiling needs it even without audio.
    audio_sz = x_enc[AUDIO][-1].get_shape().as_list()

    bottleneck = []
    for k in [AUDIO, VIDEO, FLOW]:
        if k == AUDIO and not use_audio:
            continue
        if k not in x_enc:
            continue

        x = x_enc[k][-1] if k == AUDIO else x_enc[k]
        _log(k + '-feats', x)

        if k != AUDIO:
            # Channel reduction for the non-audio streams before flattening.
            name = k + '-fc-red'
            x = tfw.fully_connected(x, 128, activation_fn=tf.nn.relu, name=name)
            _log(name, x)

        # Flatten to rank 3 (see docstring for the per-modality layout).
        sz = x.get_shape().as_list()
        if k == AUDIO:
            out_shape = (sz[0], sz[1], sz[2] * sz[3])
        else:
            out_shape = (sz[0], 1, sz[1] * sz[2] * sz[3])
        x = tf.reshape(x, out_shape)
        _log(k + '-reshape', x)

        name = k + '-fc'
        n_units = 1024 if k == AUDIO else 512
        x = tfw.fully_connected(x, n_units, activation_fn=tf.nn.relu, name=name)
        _log(name, x)

        if k in [VIDEO, FLOW]:
            # Repeat along axis 1 so it lines up with the audio features.
            x = tf.tile(x, (1, audio_sz[1], 1))
            _log(k + ' tile', x)

        bottleneck.append(x)

    if not bottleneck:
        # Nothing to fuse (e.g. only AUDIO given with use_audio=False).
        # tf.concat would raise on an empty list; mirror the empty-input path.
        return None

    bottleneck = tf.concat(bottleneck, 2)
    _log('Concat', bottleneck)
    return bottleneck
def localization_ops(self, x):
    """Predict localization weights and biases, repeated to ``self.snd_dur`` steps.

    Processing steps:

    * ``x`` passes through one ReLU fully-connected layer per entry of
      ``self.params.loc_fc_units``.
    * A final linear layer (small-stddev init, ``weight_decay=0``) outputs
      ``NOUT * NIN * (sep_num_tracks + 1)`` values. These are reshaped to
      ``(BS, NF, NOUT, NIN, sep_num_tracks + 1)``, with
      ``NOUT = (ambi_order + 1)**2 - ambi_order**2`` and
      ``NIN = ambi_order**2``.
    * Each of the NF steps is repeated ``snd_dur // NF`` times in place, so
      axis 1 becomes ``self.snd_dur``.
    * The last axis is split: its first ``sep_num_tracks`` entries are the
      weights and its final entry is the bias.

    Args:
        x: Input features. Presumably rank 3 ``(BS, NF, C)``, since the
            output of the final layer is reshaped using its first two axes;
            TODO confirm.

    Returns:
        ``(weights, biases)`` with shapes
        ``(BS, snd_dur, NOUT, NIN, sep_num_tracks)`` and
        ``(BS, snd_dur, NOUT, NIN)``.

    Raises:
        ValueError: If ``self.snd_dur`` is not a multiple of NF.
    """
    def _log(label, t):
        print(' * {:15s} | {:20s} | {:10s}'.format(
            label, str(t.get_shape()), str(t.dtype)))

    num_out = (self.ambi_order + 1)**2 - self.ambi_order**2
    num_in = self.ambi_order**2
    n_tracks = self.params.sep_num_tracks

    # Localization: hidden ReLU layers.
    for i, u in enumerate(self.params.loc_fc_units):
        name = 'fc{}'.format(i + 1)
        x = tfw.fully_connected(x, u, activation_fn=tf.nn.relu, name=name)
        _log(name, x)

    # Compute localization weights (linear output layer).
    name = 'fc{}'.format(len(self.params.loc_fc_units) + 1)
    x = tfw.fully_connected(
        x, num_out * num_in * (n_tracks + 1), activation_fn=None,
        weights_initializer=tf.truncated_normal_initializer(stddev=0.001),
        weight_decay=0, name=name)
    # BS x NF x NOUT x NIN x (NTRACKS + 1)
    sz = x.get_shape().as_list()
    x = tf.reshape(x, (sz[0], sz[1], num_out, num_in, n_tracks + 1))
    _log(name, x)

    # Repeat every NF step `reps` times so that axis 1 spans snd_dur.
    # Integer division is required: tf.tile needs integer multiples, and
    # `/` yields a float on Python 3.
    sz = x.get_shape().as_list()
    if self.snd_dur % sz[1] != 0:
        raise ValueError(
            'snd_dur ({}) must be a multiple of the localization output '
            'length ({}).'.format(self.snd_dur, sz[1]))
    reps = self.snd_dur // sz[1]
    x = tf.tile(tf.expand_dims(x, 2), (1, 1, reps, 1, 1, 1))
    x = tf.reshape(x, (sz[0], self.snd_dur, sz[2], sz[3], sz[4]))
    _log('Tile', x)

    weights = x[:, :, :, :, :-1]
    _log('weights', weights)
    biases = x[:, :, :, :, -1]
    _log('biases', biases)
    return weights, biases
def inference_ops(self, x, is_training=True, for_imagenet=True,
                  spatial_squeeze=True, truncate_at=None, reuse=False):
    """Build a residual network with 3/8/36/3 blocks in scales 2-5.

    The block counts match the ResNet-152 layout. Layer names are
    'conv1', 'pool1', 'res2a'-'res2c', 'res3a', 'res3b1'-'res3b7', 'res4a',
    'res4b1'-'res4b35', 'res5a'-'res5c', 'pool5' and (optionally) 'fc1000'.

    Args:
        x: Rank-3 or rank-4 input tensor. A rank-3 input gets a leading
            axis via ``tf.expand_dims(x, 0)``.
        is_training: Forwarded to the conv, residual and fc layers.
        for_imagenet: If True, append a 1000-unit linear 'fc1000' layer;
            its output is stored under ``ends['logits']``.
        spatial_squeeze: If True and axes 1 and 2 of the pool5 output are
            both 1, squeeze them away.
        truncate_at: Optional layer name; building stops right after it.
        reuse: Forwarded to every layer that creates variables.

    Returns:
        ``(x, ends)``, where ``x`` is the last built tensor and ``ends`` is an
        ``OrderedDict`` of layer name -> output tensor, in build order.
    """
    inp_dims = x.get_shape()
    assert inp_dims.ndims in (3, 4)
    if inp_dims.ndims == 3:
        x = tf.expand_dims(x, 0)

    ends = OrderedDict()

    # Scale 1
    name = 'conv1'
    ends[name] = x = tfw.conv_2d(x, 64, 7, 2, use_bias=False,
                                 use_batch_norm=True,
                                 activation_fn=tf.nn.relu,
                                 is_training=is_training, reuse=reuse,
                                 name=name)
    if truncate_at == name:
        return x, ends
    name = 'pool1'
    ends[name] = x = tfw.max_pool2d(x, 3, 2, padding='SAME', name=name)
    if truncate_at == name:
        return x, ends

    # Scales 2-5: (name, four positional widths forwarded to self.block,
    # downsample). The fourth width is only given on the first block of each
    # scale (None otherwise), exactly as in the original per-layer calls.
    specs = [('res2' + c, 64, 64, 256, 256 if i == 0 else None, False)
             for i, c in enumerate('abc')]
    specs.append(('res3a', 128, 128, 512, 512, True))
    specs += [('res3b%d' % (i + 1), 128, 128, 512, None, False)
              for i in range(7)]
    specs.append(('res4a', 256, 256, 1024, 1024, True))
    specs += [('res4b%d' % (i + 1), 256, 256, 1024, None, False)
              for i in range(35)]
    specs += [('res5' + c, 512, 512, 2048, 2048 if i == 0 else None, i == 0)
              for i, c in enumerate('abc')]

    for name, w1, w2, w3, w4, downsample in specs:
        ends[name] = x = self.block(x, is_training, w1, w2, w3, w4,
                                    downsample=downsample, reuse=reuse,
                                    name=name)
        if truncate_at == name:
            return x, ends

    name = 'pool5'
    x = tfw.avg_pool2d(x, 7, 1, 'VALID', name=name)
    if spatial_squeeze:
        pool_sz = x.get_shape().as_list()
        if pool_sz[1] == pool_sz[2] == 1:
            # `axis` replaces the deprecated `squeeze_dims` (removed in TF2).
            x = tf.squeeze(x, axis=[1, 2])
    ends[name] = x
    if truncate_at == name:
        return x, ends

    # Logits
    if for_imagenet:
        name = 'fc1000'
        ends['logits'] = x = tfw.fully_connected(
            x, 1000, activation_fn=None, is_training=is_training,
            reuse=reuse, name=name)
    return x, ends
def inference_ops(self, x, is_training=True, spatial_squeeze=True,
                  truncate_at=None, reuse=False):
    """Build a small residual network: conv1 + two residual blocks per stage conv2-conv5.

    The stages are:

    * conv1: a 7x7 stride-2 conv (batch norm, ReLU) followed by a 3x3
      stride-2 max-pool.
    * conv2_1, conv2_2: plain residual blocks.
    * conv3-conv5: ``_residual_block_first`` (with that stage's filter
      count and stride), then a plain residual block.
    * Logits: global mean over axes 1 and 2, then a 1000-unit
      ``tfw.fully_connected`` inside variable scope 'logits'. Its activation
      is tfw's default (not set here; confirm it is linear).

    Args:
        x: Rank-4 input tensor (assumed NHWC; confirm against caller).
        is_training: Forwarded to the conv and residual layers. conv1 also
            uses it as ``trainable``.
        spatial_squeeze: Accepted for interface compatibility; unused.
        truncate_at: Optional stop point. 'conv1' returns the pooled conv1
            output; 'convN_M' stops after that block.
        reuse: Forwarded to every layer that creates variables.

    Returns:
        ``(x, ends)``. ``ends`` is an ``OrderedDict`` with keys 'conv' (conv1
        output *before* max-pooling), 'conv2_1' ... 'conv5_2', and 'fc'.
    """
    # Per-stage settings indexed by stage - 1. kernels[1:] and strides[1]
    # are not used by this builder. (A wider variant would use
    # filters = [128, 128, 256, 512, 1024].)
    filters = [64, 64, 128, 256, 512]
    kernels = [7, 3, 3, 3, 3]
    strides = [2, 0, 2, 2, 2]

    def _log(label, t):
        print(' * {:15s} | {:20s} | {:10s}'.format(
            label, str(t.get_shape()), str(t.dtype)))

    ends = OrderedDict()

    # conv1
    with tf.variable_scope('conv1'):
        name = 'conv'
        ends[name] = x = tfw.conv_2d(x, filters[0], kernels[0], strides[0],
                                     use_bias=False, use_batch_norm=True,
                                     activation_fn=tf.nn.relu,
                                     is_training=is_training,
                                     trainable=is_training, reuse=reuse,
                                     name=name)
        x = tf.nn.max_pool(x, [1, 3, 3, 1], [1, 2, 2, 1], 'SAME')
        _log(name, x)
    if truncate_at == 'conv1':
        return x, ends

    # conv2_x .. conv5_x
    for stage in range(2, 6):
        for idx in (1, 2):
            name = 'conv%d_%d' % (stage, idx)
            if idx == 1 and stage > 2:
                x = self._residual_block_first(x, filters[stage - 1],
                                               strides[stage - 1],
                                               is_training, reuse,
                                               name=name)
            else:
                x = self._residual_block(x, is_training, reuse, name=name)
            ends[name] = x
            _log(name, x)
            if truncate_at == name:
                return x, ends

    # Logit
    with tf.variable_scope('logits'):
        x = tf.reduce_mean(x, [1, 2])
        name = 'fc'
        # Pass `reuse` like every other layer here (and like the sibling
        # network's fc layer); previously it was silently ignored.
        ends[name] = x = tfw.fully_connected(x, 1000, reuse=reuse, name=name)
        _log(name, x)
    return x, ends
def separation_ops(self, mono, stft, audio_enc, feats, scope='separation'):
    """Produce the output audio channels for the configured separation mode.

    Modes (``self.separation``):

    * NO_SEPARATION: crop ``self.snd_dur`` entries of ``mono`` along axis 2,
      starting at ``self.snd_contx // 2``, and insert a new axis 1.
    * FREQ_MASK:
      - Project ``feats`` with a ReLU fc layer and tile it along axis 2 to
        match ``audio_enc[-1]``, then concatenate on axis 3.
      - Decode with transposed convolutions, using skip concatenations from
        ``audio_enc``.
      - Crop the decoder output and ``stft`` to the same frame window.
      - Turn the decoder output into sigmoid masks and multiply them with
        ``stft``.
      - Invert with ``myutils.istft`` and crop ``self.snd_dur`` samples.

    The result is also stored in ``self.ends[scope + '/all_channels']``.

    Raises:
        ValueError: If ``self.separation`` is not a known mode.
    """
    def _log(label, t):
        print(' * {:15s} | {:20s} | {:10s}'.format(
            label, str(t.get_shape()), str(t.dtype)))

    if self.separation == NO_SEPARATION:
        # Integer division: a float slice bound fails on Python 3.
        ss = self.snd_contx // 2
        x_sep = mono[:, :, ss:ss + self.snd_dur]
        # BS x 1 x NF
        x_sep = tf.expand_dims(x_sep, axis=1)
        _log('Crop Audio', x_sep)

    elif self.separation == FREQ_MASK:
        n_filters = [32, 64, 128, 256, 512]
        filter_size = [(7, 16), (3, 7), (3, 5), (3, 5), (3, 5)]
        stride = [(4, 8), (2, 4), (2, 2), (1, 1), (1, 1)]

        name = 'fc-feats'
        feats = tfw.fully_connected(feats, n_filters[-1],
                                    activation_fn=tf.nn.relu, name=name)
        _log(name, feats)

        # Tile the features along axis 2 to match the last encoder output.
        sz = feats.get_shape().as_list()
        enc_sz = audio_enc[-1].get_shape().as_list()
        feats = tf.tile(tf.expand_dims(feats, 2), (1, 1, enc_sz[2], 1))
        feats = tf.reshape(feats, (sz[0], sz[1], enc_sz[2], sz[2]))
        _log('Tile', feats)

        x = tf.concat([audio_enc[-1], feats], axis=3)
        _log('Concat', x)

        # Up-convolution. Layers run from the highest index down to 0.
        # Layer 0 outputs sep_num_tracks * n_chann_in channels and gets
        # neither ReLU nor a skip concat. zip() truncates to its shortest
        # input, including audio_enc[:-1].
        n_chann_in = mono.get_shape().as_list()[1]
        out_channels = [self.params.sep_num_tracks * n_chann_in] + n_filters[:-1]
        # list() is required: reversed() rejects a zip iterator on Python 3.
        layers = list(zip(range(len(n_filters)), out_channels, filter_size,
                          stride, audio_enc[:-1]))
        for layer, nf, fs, st, skip_in in reversed(layers):
            name = 'deconv{}'.format(layer + 1)
            x = tfw.deconv_2d(x, nf, fs, stride=st, padding='VALID',
                              activation_fn=None, name=name)
            _log(name, x)
            if layer == 0:
                break
            x = tf.concat((tf.nn.relu(x), skip_in), 3)
            _log('Concat', x)

        # Crop. The 4. / wind_size factor presumably converts samples to
        # STFT frames with hop wind_size / 4 (cf. myutils.istft(..., 4));
        # TODO confirm.
        ss = np.floor(
            (self.snd_contx / 2. - self.wind_size) * (4. / self.wind_size))
        tt = np.ceil(
            (self.snd_contx / 2. + self.snd_dur + self.wind_size) *
            (4. / self.wind_size))
        # NOTE(review): magic constant; original comment said
        # "Encoder Dim=1". Meaning unclear, confirm.
        inp_dim = 95.
        skip = (self.snd_contx / 2.) * (4. / self.wind_size)
        skip = int(skip - (inp_dim - 1) / 2.)
        stft = stft[:, :, int(ss):int(tt)]
        _log('Crop STFT', stft)
        x = x[:, int(ss - skip):int(tt - skip), :]
        _log('Crop deconv1', x)

        x = tf.transpose(x, (0, 3, 1, 2))
        _log('Permute', x)
        x_sz = x.get_shape().as_list()
        x = tf.reshape(x, (x_sz[0], n_chann_in, -1, x_sz[2], x_sz[3]))
        _log('Reshape', x)

        # Apply Mask
        f_mask = tf.cast(tf.sigmoid(x), dtype=tf.complex64)
        _log('Sigmoid', f_mask)
        stft_sep = tf.expand_dims(stft, 2) * f_mask
        _log('Prod', stft_sep)

        # IFFT
        x_sep = myutils.istft(stft_sep, 4)
        _log('ISTFT', x_sep)
        ss = self.snd_contx / 2.
        skip = np.floor(
            (self.snd_contx / 2. - self.wind_size) *
            (4. / self.wind_size)) * (self.wind_size / 4.)
        skip += 3. * self.wind_size / 4.  # ISTFT ignores 3/4 of a window
        start = int(ss - skip)
        x_sep = x_sep[:, :, :, start:start + self.snd_dur]
        _log('Crop', x_sep)

    else:
        raise ValueError('Unknown separation mode.')

    self.ends[scope + '/' + 'all_channels'] = x_sep
    return x_sep