Example #1
    def bottleneck_ops(self, x_enc, use_audio=True):
        if len(x_enc) == 0:
            return None
        bottleneck = []
        audio_sz = x_enc[AUDIO][-1].get_shape().as_list()
        for k in [AUDIO, VIDEO, FLOW]:
            if k == AUDIO and not use_audio:
                continue
            if k in x_enc:
                x = x_enc[k][-1] if k == AUDIO else x_enc[k]
                print(' * {:15s} | {:20s} | {:10s}'.format(
                    k + '-feats', str(x.get_shape()), str(x.dtype)))

                if k != AUDIO:
                    name = k + '-fc-red'
                    x = tfw.fully_connected(x,
                                            128,
                                            activation_fn=tf.nn.relu,
                                            name=name)
                    print(' * {:15s} | {:20s} | {:10s}'.format(
                        name, str(x.get_shape()), str(x.dtype)))

                sz = x.get_shape().as_list()
                if k == AUDIO:
                    out_shape = (sz[0], sz[1], sz[2] * sz[3])
                else:
                    out_shape = (sz[0], 1, sz[1] * sz[2] * sz[3])
                x = tf.reshape(x, out_shape)
                print(' * {:15s} | {:20s} | {:10s}'.format(
                    k + '-reshape', str(x.get_shape()), str(x.dtype)))

                name = k + '-fc'
                n_units = 1024 if k == AUDIO else 512
                x = tfw.fully_connected(x,
                                        n_units,
                                        activation_fn=tf.nn.relu,
                                        name=name)
                print(' * {:15s} | {:20s} | {:10s}'.format(
                    name, str(x.get_shape()), str(x.dtype)))

                if k in [VIDEO, FLOW]:
                    x = tf.tile(x, (1, audio_sz[1], 1))
                    print(' * {:15s} | {:20s} | {:10s}'.format(
                        k + ' tile', str(x.get_shape()), str(x.dtype)))

                bottleneck.append(x)

        bottleneck = tf.concat(bottleneck, 2)
        print(' * {:15s} | {:20s} | {:10s}'.format('Concat',
                                                   str(bottleneck.get_shape()),
                                                   str(bottleneck.dtype)))

        return bottleneck
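Note: the pattern above projects each visual stream to a single clip-level vector and tiles it along the audio time axis before concatenation. A minimal, self-contained sketch of that tile-and-concat step follows; the shapes and the plain TensorFlow calls are assumptions for illustration, not the repository's tfw wrapper.

import tensorflow.compat.v1 as tf

# Assumed toy shapes: 4 audio frames with 1024-d features and one
# 512-d visual feature vector per clip (batch size 8).
audio_feats = tf.zeros((8, 4, 1024))   # BS x T x C_audio
video_feats = tf.zeros((8, 1, 512))    # BS x 1 x C_video

# Tile the clip-level visual feature across the audio time axis, then
# concatenate along channels -- the same pattern bottleneck_ops uses.
t = audio_feats.get_shape().as_list()[1]
video_tiled = tf.tile(video_feats, (1, t, 1))               # BS x T x C_video
bottleneck = tf.concat([audio_feats, video_tiled], axis=2)  # BS x T x 1536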
Example #2
    def localization_ops(self, x):
        num_out = (self.ambi_order + 1)**2 - self.ambi_order**2
        num_in = self.ambi_order**2

        # Localization
        for i, u in enumerate(self.params.loc_fc_units):
            name = 'fc{}'.format(i + 1)
            x = tfw.fully_connected(x, u, activation_fn=tf.nn.relu, name=name)
            print(' * {:15s} | {:20s} | {:10s}'.format(name,
                                                       str(x.get_shape()),
                                                       str(x.dtype)))

        # Compute localization weights
        name = 'fc{}'.format(len(self.params.loc_fc_units) + 1)
        x = tfw.fully_connected(
            x,
            num_out * num_in * (self.params.sep_num_tracks + 1),
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(stddev=0.001),
            weight_decay=0,
            name=name)  # BS x NF x NIN x NOUT

        sz = x.get_shape().as_list()
        x = tf.reshape(
            x, (sz[0], sz[1], num_out, num_in, self.params.sep_num_tracks + 1))
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))

        sz = x.get_shape().as_list()
        x = tf.tile(tf.expand_dims(x, 2),
                    (1, 1, self.snd_dur // sz[1], 1, 1, 1))
        x = tf.reshape(x, (sz[0], self.snd_dur, sz[2], sz[3], sz[4]))
        print(' * {:15s} | {:20s} | {:10s}'.format('Tile', str(x.get_shape()),
                                                   str(x.dtype)))

        weights = x[:, :, :, :, :-1]
        print(' * {:15s} | {:20s} | {:10s}'.format('weights',
                                                   str(weights.get_shape()),
                                                   str(weights.dtype)))
        biases = x[:, :, :, :, -1]
        print(' * {:15s} | {:20s} | {:10s}'.format('biases',
                                                   str(biases.get_shape()),
                                                   str(biases.dtype)))
        return weights, biases
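Note: localization_ops predicts weights and biases at the feature frame rate and then upsamples them to the audio sample rate by repeating each frame (expand_dims, tile, reshape). Below is a small standalone sketch of that nearest-neighbor upsampling; all shapes are assumed toy values, not the model's real dimensions.

import tensorflow.compat.v1 as tf

# Assumed toy shapes: predictions at 16 feature frames, upsampled to 64
# audio samples; the last axis holds the per-track values plus one bias.
preds = tf.zeros((2, 16, 4, 1, 3))   # BS x NF x NOUT x NIN x (tracks+1)
snd_dur = 64
sz = preds.get_shape().as_list()

# Repeat each frame snd_dur // NF times, mirroring the
# expand_dims/tile/reshape sequence in localization_ops.
up = tf.tile(tf.expand_dims(preds, 2), (1, 1, snd_dur // sz[1], 1, 1, 1))
up = tf.reshape(up, (sz[0], snd_dur, sz[2], sz[3], sz[4]))

weights, biases = up[..., :-1], up[..., -1]   # split off the bias channel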
Example #3
    def inference_ops(self,
                      x,
                      is_training=True,
                      for_imagenet=True,
                      spatial_squeeze=True,
                      truncate_at=None,
                      reuse=False):
        inp_dims = x.get_shape()
        assert inp_dims.ndims in (3, 4)
        if inp_dims.ndims == 3: x = tf.expand_dims(x, 0)

        ends = OrderedDict()

        # Scale 1
        name = 'conv1'
        ends[name] = x = tfw.conv_2d(x,
                                     64,
                                     7,
                                     2,
                                     use_bias=False,
                                     use_batch_norm=True,
                                     activation_fn=tf.nn.relu,
                                     is_training=is_training,
                                     reuse=reuse,
                                     name=name)
        if truncate_at == name: return x, ends

        name = 'pool1'
        ends[name] = x = tfw.max_pool2d(x, 3, 2, padding='SAME', name=name)
        if truncate_at == name: return x, ends

        # Scale 2
        for i, c in enumerate(string.ascii_lowercase[:3]):
            name = 'res2' + c
            ends[name] = x = self.block(x,
                                        is_training,
                                        64,
                                        64,
                                        256,
                                        None if i > 0 else 256,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 3
        name = 'res3a'
        ends[name] = x = self.block(x,
                                    is_training,
                                    128,
                                    128,
                                    512,
                                    512,
                                    downsample=True,
                                    reuse=reuse,
                                    name=name)
        if truncate_at == name: return x, ends

        for i in range(7):
            name = 'res3b%d' % (i + 1)
            ends[name] = x = self.block(x,
                                        is_training,
                                        128,
                                        128,
                                        512,
                                        None,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 4
        name = 'res4a'
        ends[name] = x = self.block(x,
                                    is_training,
                                    256,
                                    256,
                                    1024,
                                    1024,
                                    downsample=True,
                                    reuse=reuse,
                                    name=name)
        if truncate_at == name: return x, ends

        for i in range(35):
            name = 'res4b%d' % (i + 1)
            ends[name] = x = self.block(x,
                                        is_training,
                                        256,
                                        256,
                                        1024,
                                        None,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 5
        for i, c in enumerate(string.ascii_lowercase[:3]):
            name = 'res5' + c
            ends[name] = x = self.block(x,
                                        is_training,
                                        512,
                                        512,
                                        2048,
                                        2048 if i == 0 else None,
                                        downsample=True if i == 0 else False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        name = 'pool5'
        x = tfw.avg_pool2d(x, 7, 1, 'VALID', name=name)
        sz = x.get_shape().as_list()
        if spatial_squeeze and sz[1] == sz[2] == 1:
            x = tf.squeeze(x, axis=[1, 2])
        ends[name] = x
        if truncate_at == name: return x, ends

        # Logits
        if for_imagenet:
            name = 'fc1000'
            ends['logits'] = x = tfw.fully_connected(x,
                                                     1000,
                                                     activation_fn=None,
                                                     is_training=is_training,
                                                     reuse=reuse,
                                                     name=name)

        return x, ends
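Note: the block counts here (3, 8, 36, 3) correspond to a ResNet-152 backbone, and truncate_at lets a caller stop at any named end-point and read intermediate activations from the returned ends dict. A hypothetical usage sketch, assuming net is an instance of the class above and images is a BS x 224 x 224 x 3 tensor:

# Hypothetical usage; `net` and `images` are assumed names.
feats, ends = net.inference_ops(images, is_training=False,
                                truncate_at='res4b35')
# feats is the res4b35 activation; earlier end-points stay accessible:
pool1 = ends['pool1']
res3a = ends['res3a']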
Example #4
    def inference_ops(self,
                      x,
                      is_training=True,
                      spatial_squeeze=True,
                      truncate_at=None,
                      reuse=False):
        # filters = [128, 128, 256, 512, 1024]
        filters = [64, 64, 128, 256, 512]
        kernels = [7, 3, 3, 3, 3]
        strides = [2, 0, 2, 2, 2]  # strides[1] is unused; conv2_x keeps the post-pool resolution

        # conv1
        ends = OrderedDict()
        with tf.variable_scope('conv1'):
            name = 'conv'
            ends[name] = x = tfw.conv_2d(x,
                                         filters[0],
                                         kernels[0],
                                         strides[0],
                                         use_bias=False,
                                         use_batch_norm=True,
                                         activation_fn=tf.nn.relu,
                                         is_training=is_training,
                                         trainable=is_training,
                                         reuse=reuse,
                                         name=name)
            x = tf.nn.max_pool(x, [1, 3, 3, 1], [1, 2, 2, 1], 'SAME')
            print(' * {:15s} | {:20s} | {:10s}'.format(name,
                                                       str(x.get_shape()),
                                                       str(x.dtype)))
            if truncate_at is not None and truncate_at == 'conv1':
                return x, ends

        # conv2_x
        name = 'conv2_1'
        ends[name] = x = self._residual_block(x, is_training, reuse, name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        name = 'conv2_2'
        ends[name] = x = self._residual_block(x, is_training, reuse, name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        # conv3_x
        name = 'conv3_1'
        ends[name] = x = self._residual_block_first(x,
                                                    filters[2],
                                                    strides[2],
                                                    is_training,
                                                    reuse,
                                                    name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        name = 'conv3_2'
        ends[name] = x = self._residual_block(x, is_training, reuse, name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        # conv4_x
        name = 'conv4_1'
        ends[name] = x = self._residual_block_first(x,
                                                    filters[3],
                                                    strides[3],
                                                    is_training,
                                                    reuse,
                                                    name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        name = 'conv4_2'
        ends[name] = x = self._residual_block(x, is_training, reuse, name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        # conv5_x
        name = 'conv5_1'
        ends[name] = x = self._residual_block_first(x,
                                                    filters[4],
                                                    strides[4],
                                                    is_training,
                                                    reuse,
                                                    name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        name = 'conv5_2'
        ends[name] = x = self._residual_block(x, is_training, reuse, name=name)
        print(' * {:15s} | {:20s} | {:10s}'.format(name, str(x.get_shape()),
                                                   str(x.dtype)))
        if truncate_at is not None and truncate_at == name:
            return x, ends

        # Logit
        with tf.variable_scope('logits'):
            x = tf.reduce_mean(x, [1, 2])
            name = 'fc'
            ends[name] = x = tfw.fully_connected(x, 1000, name=name)
            print(' * {:15s} | {:20s} | {:10s}'.format(name,
                                                       str(x.get_shape()),
                                                       str(x.dtype)))
        return x, ends
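Note: with two residual blocks per scale this follows a ResNet-18-style layout. The helpers _residual_block and _residual_block_first are not shown; for orientation only, a generic basic block with a strided projection shortcut might look like the sketch below (plain tf.layers calls, explicitly not the repository's implementation).

import tensorflow.compat.v1 as tf

def basic_block_first(x, filters, stride, is_training, name):
    # Generic ResNet basic block with a projection shortcut -- an
    # illustration only, not the repository's _residual_block_first.
    with tf.variable_scope(name):
        shortcut = tf.layers.conv2d(x, filters, 1, stride, 'same', use_bias=False)
        h = tf.layers.conv2d(x, filters, 3, stride, 'same', use_bias=False)
        h = tf.layers.batch_normalization(h, training=is_training)
        h = tf.nn.relu(h)
        h = tf.layers.conv2d(h, filters, 3, 1, 'same', use_bias=False)
        h = tf.layers.batch_normalization(h, training=is_training)
        return tf.nn.relu(h + shortcut)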
Example #5
    def separation_ops(self, mono, stft, audio_enc, feats, scope='separation'):
        if self.separation == NO_SEPARATION:
            ss = self.snd_contx // 2
            x_sep = mono[:, :, ss:ss + self.snd_dur]  # BS x 1 x NF
            x_sep = tf.expand_dims(x_sep, axis=1)
            self.ends[scope + '/' + 'all_channels'] = x_sep
            print(' * {:15s} | {:20s} | {:10s}'.format('Crop Audio',
                                                       str(x_sep.get_shape()),
                                                       str(x_sep.dtype)))
            return x_sep

        elif self.separation == FREQ_MASK:
            n_filters = [32, 64, 128, 256, 512]
            filter_size = [(7, 16), (3, 7), (3, 5), (3, 5), (3, 5)]
            stride = [(4, 8), (2, 4), (2, 2), (1, 1), (1, 1)]

            name = 'fc-feats'
            feats = tfw.fully_connected(feats,
                                        n_filters[-1],
                                        activation_fn=tf.nn.relu,
                                        name=name)
            print(' * {:15s} | {:20s} | {:10s}'.format(name,
                                                       str(feats.get_shape()),
                                                       str(feats.dtype)))

            sz = feats.get_shape().as_list()
            enc_sz = audio_enc[-1].get_shape().as_list()
            feats = tf.tile(tf.expand_dims(feats, 2), (1, 1, enc_sz[2], 1))
            feats = tf.reshape(feats, (sz[0], sz[1], enc_sz[2], sz[2]))
            print(' * {:15s} | {:20s} | {:10s}'.format('Tile',
                                                       str(feats.get_shape()),
                                                       str(feats.dtype)))

            x = tf.concat([audio_enc[-1], feats], axis=3)
            print(' * {:15s} | {:20s} | {:10s}'.format('Concat',
                                                       str(x.get_shape()),
                                                       str(x.dtype)))

            # Up-convolution
            n_chann_in = mono.get_shape().as_list()[1]
            deconv_filters = [self.params.sep_num_tracks * n_chann_in] + n_filters[:-1]
            for l, nf, fs, st, l_in in reversed(
                    list(zip(range(len(n_filters)), deconv_filters,
                             filter_size, stride, audio_enc[:-1]))):
                name = 'deconv{}'.format(l + 1)
                x = tfw.deconv_2d(x,
                                  nf,
                                  fs,
                                  stride=st,
                                  padding='VALID',
                                  activation_fn=None,
                                  name=name)
                print(' * {:15s} | {:20s} | {:10s}'.format(
                    name, str(x.get_shape()), str(x.dtype)))

                if l == 0:
                    break

                x = tf.concat((tf.nn.relu(x), l_in), 3)
                print(' * {:15s} | {:20s} | {:10s}'.format(
                    'Concat', str(x.get_shape()), str(x.dtype)))

            # Crop
            ss = np.floor(
                (self.snd_contx / 2. - self.wind_size) * (4. / self.wind_size))
            tt = np.ceil(
                (self.snd_contx / 2. + self.snd_dur + self.wind_size) *
                (4. / self.wind_size))
            inp_dim = 95.  # Encoder Dim=1
            skip = (self.snd_contx / 2.) * (4. / self.wind_size)
            skip = int(skip - (inp_dim - 1) / 2.)

            stft = stft[:, :, int(ss):int(tt)]
            print(' * {:15s} | {:20s} | {:10s}'.format('Crop STFT',
                                                       str(stft.get_shape()),
                                                       str(stft.dtype)))

            x = x[:, int(ss - skip):int(tt - skip), :]
            print(' * {:15s} | {:20s} | {:10s}'.format('Crop deconv1',
                                                       str(x.get_shape()),
                                                       str(x.dtype)))

            x = tf.transpose(x, (0, 3, 1, 2))
            print(' * {:15s} | {:20s} | {:10s}'.format('Permute',
                                                       str(x.get_shape()),
                                                       str(x.dtype)))

            x_sz = x.get_shape().as_list()
            x = tf.reshape(x, (x_sz[0], n_chann_in, -1, x_sz[2], x_sz[3]))
            print(' * {:15s} | {:20s} | {:10s}'.format('Reshape',
                                                       str(x.get_shape()),
                                                       str(x.dtype)))

            # Apply Mask
            f_mask = tf.cast(tf.sigmoid(x), dtype=tf.complex64)
            print(' * {:15s} | {:20s} | {:10s}'.format('Sigmoid',
                                                       str(f_mask.get_shape()),
                                                       str(f_mask.dtype)))

            stft_sep = tf.expand_dims(stft, 2) * f_mask
            print(' * {:15s} | {:20s} | {:10s}'.format(
                'Prod', str(stft_sep.get_shape()), str(stft_sep.dtype)))

            # IFFT
            x_sep = myutils.istft(stft_sep, 4)
            print(' * {:15s} | {:20s} | {:10s}'.format('ISTFT',
                                                       str(x_sep.get_shape()),
                                                       str(x_sep.dtype)))

            ss = self.snd_contx / 2.
            skip = np.floor((self.snd_contx / 2. - self.wind_size) *
                            (4. / self.wind_size)) * (self.wind_size / 4.)
            skip += 3. * self.wind_size / 4.  # ISTFT ignores 3/4 of a window
            x_sep = x_sep[:, :, :,
                          int(ss - skip):int(ss - skip) + self.snd_dur]
            print(' * {:15s} | {:20s} | {:10s}'.format('Crop',
                                                       str(x_sep.get_shape()),
                                                       str(x_sep.dtype)))

        else:
            raise ValueError('Unknown separation mode.')

        self.ends[scope + '/' + 'all_channels'] = x_sep
        return x_sep
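Note: in FREQ_MASK mode the decoder output is squashed with a sigmoid, cast to complex, multiplied with the mixture STFT, and inverted back to waveforms. Below is a minimal standalone sketch of that masking step using tf.signal; the window/hop values and mask shape are assumptions, and the repository's myutils.istft is replaced here by tf.signal.inverse_stft.

import tensorflow.compat.v1 as tf

# Assumed setup: 1024-sample window with hop = window / 4, matching the
# (4. / self.wind_size) factors used in the crop arithmetic above.
wind_size, hop = 1024, 256
mixture = tf.zeros((2, 16384))                                # BS x samples
stft = tf.signal.stft(mixture, wind_size, hop)                # BS x 61 frames x 513 bins

# Hypothetical mask logits predicted by the decoder, one set per track.
mask_logits = tf.zeros((2, 4, 61, 513))                       # BS x tracks x frames x bins

# Sigmoid mask applied to the complex mixture STFT, then inverted.
mask = tf.cast(tf.sigmoid(mask_logits), tf.complex64)
stft_sep = tf.expand_dims(stft, 1) * mask                     # broadcast over tracks
separated = tf.signal.inverse_stft(stft_sep, wind_size, hop)  # BS x tracks x samples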