Exemplo n.º 1
0
    def _residual_block(self, x, is_training=False, reuse=False, name="unit"):
        """Identity residual unit: two 3x3 convs plus a skip connection.

        The channel count is read from the input, so both depth and
        spatial size are preserved; the sum is passed through a final
        ReLU before being returned.
        """
        channels = x.get_shape().as_list()[-1]
        with tf.variable_scope(name):
            identity = x
            out = x
            # Two stacked 3x3/stride-1 convolutions (bias-free, with
            # batch norm). Only the first has an in-layer ReLU; the
            # second stays linear so activation happens after the merge.
            for conv_name, act in (('conv_1', tf.nn.relu), ('conv_2', None)):
                out = tfw.conv_2d(out,
                                  channels,
                                  3,
                                  1,
                                  use_bias=False,
                                  use_batch_norm=True,
                                  activation_fn=act,
                                  is_training=is_training,
                                  trainable=is_training,
                                  reuse=reuse,
                                  name=conv_name)
            # Merge with the shortcut and activate.
            out = tf.nn.relu(out + identity)
        return out
Exemplo n.º 2
0
    def block(x,
              is_training,
              b2a,
              b2b,
              b2c,
              b1=None,
              downsample=False,
              name=None,
              reuse=False):
        """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 plus a shortcut.

        ``b2a``/``b2b``/``b2c`` are the channel counts of the three main
        branch convolutions. When ``b1`` is given, the shortcut is a 1x1
        projection to ``b1`` channels; otherwise it is the identity.
        ``downsample`` applies stride 2 to the first conv of each branch.
        """
        with tf.variable_scope(name, values=[x, is_training]):
            stride = 2 if downsample else 1

            def conv(inp, channels, ksize, st, act, scope_name):
                # Thin wrapper: bias-free conv + batch norm with the
                # block-wide training/reuse flags applied.
                return tfw.conv_2d(inp,
                                   channels,
                                   ksize,
                                   st,
                                   use_bias=False,
                                   use_batch_norm=True,
                                   activation_fn=act,
                                   is_training=is_training,
                                   trainable=is_training,
                                   reuse=reuse,
                                   name=scope_name)

            # Shortcut: identity unless a projection width is requested.
            if b1 is None:
                shortcut = x
            else:
                shortcut = conv(x, b1, 1, stride, None, 'branch1')

            # Main branch: the last conv is linear so the ReLU is
            # applied after the merge.
            residual = conv(x, b2a, 1, stride, tf.nn.relu, 'branch2a')
            residual = conv(residual, b2b, 3, 1, tf.nn.relu, 'branch2b')
            residual = conv(residual, b2c, 1, 1, None, 'branch2c')

            return tf.nn.relu(shortcut + residual)
Exemplo n.º 3
0
    def _residual_block_first(self,
                              x,
                              out_channel,
                              strides,
                              is_training=False,
                              reuse=False,
                              name="unit"):
        """First residual unit of a stage; may change depth and stride.

        The shortcut is the identity when neither depth nor stride
        changes, a max-pool when only the stride changes, and a linear
        1x1 projection conv (no batch norm on this path) when the depth
        changes.
        """
        in_channel = x.get_shape().as_list()[-1]
        with tf.variable_scope(name):
            # --- Shortcut path ---
            if in_channel != out_channel:
                # Depth changes: project with a strided 1x1 conv.
                shortcut = tfw.conv_2d(x,
                                       out_channel,
                                       1,
                                       strides,
                                       use_bias=False,
                                       activation_fn=None,
                                       is_training=is_training,
                                       trainable=is_training,
                                       reuse=reuse,
                                       name='shortcut')
            elif strides == 1:
                shortcut = tf.identity(x)
            else:
                # Same depth, different stride: subsample by max-pooling.
                shortcut = tf.nn.max_pool(x, [1, strides, strides, 1],
                                          [1, strides, strides, 1],
                                          'VALID')

            # --- Residual path: strided 3x3 conv then 3x3 conv ---
            out = tfw.conv_2d(x,
                              out_channel,
                              3,
                              strides,
                              use_bias=False,
                              use_batch_norm=True,
                              activation_fn=tf.nn.relu,
                              is_training=is_training,
                              trainable=is_training,
                              reuse=reuse,
                              name='conv_1')
            out = tfw.conv_2d(out,
                              out_channel,
                              3,
                              1,
                              use_bias=False,
                              use_batch_norm=True,
                              activation_fn=None,
                              is_training=is_training,
                              trainable=is_training,
                              reuse=reuse,
                              name='conv_2')

            # --- Merge and activate ---
            out = tf.nn.relu(out + shortcut)
        return out
Exemplo n.º 4
0
    def audio_encoder_ops(self, stft):
        """Build the audio encoder over a cropped STFT magnitude.

        Crops the time axis of ``stft`` to the analysis window implied by
        ``self.snd_contx`` / ``self.snd_dur`` / ``self.wind_size``, takes
        the magnitude, and applies five strided VALID convolutions.

        Args:
            stft: spectrogram tensor; indexed as
                (batch, channels, time, freq) before the transpose below
                — assumed layout, TODO confirm against the caller.

        Returns:
            List of tensors: the magnitude input followed by each conv
            output (the downsampling pyramid, shallowest first).
        """
        n_filters = [32, 64, 128, 256, 512]
        filter_size = [(7, 16), (3, 7), (3, 5), (3, 5), (3, 5)]
        stride = [(4, 8), (2, 4), (2, 2), (1, 1), (1, 1)]

        inp_dim = 95.  # Encoder Dim=1

        # Crop bounds on the time axis: start (ss) and end (tt), with tt
        # rounded up so the cropped length is a multiple of 16 plus the
        # receptive-field size.
        ss = (self.snd_contx / 2.) * (4. / self.wind_size)
        ss = int(ss - (inp_dim - 1) / 2.)

        tt = (self.snd_contx / 2. + self.snd_dur) * (4. / self.wind_size)
        tt = int(tt + (inp_dim - 1) / 2.)
        tt = int((np.ceil((tt - ss - inp_dim) / 16.)) * 16 + inp_dim + ss)

        # Crop and move channels last: (b, ch, t, f) -> (b, t, f, ch).
        # NOTE: the unused `sz = stft.get_shape().as_list()` local from
        # the original was removed.
        stft = tf.transpose(stft[:, :, ss:tt, :], (0, 2, 3, 1))
        log_fmt = ' * {:15s} | {:20s} | {:10s}'
        print(log_fmt.format('Crop', str(stft.get_shape()),
                             str(stft.dtype)))

        x = tf.abs(stft)
        print(log_fmt.format('Magnitude', str(x.get_shape()),
                             str(x.dtype)))

        downsampling_l = [x]
        for l, (nf, fs, st) in enumerate(zip(n_filters, filter_size,
                                             stride)):
            name = 'conv{}'.format(l + 1)
            x = tfw.conv_2d(x,
                            nf,
                            fs,
                            padding='VALID',
                            activation_fn=tf.nn.relu,
                            stride=st,
                            name=name)
            downsampling_l.append(x)
            print(log_fmt.format(name, str(x.get_shape()), str(x.dtype)))
        return downsampling_l
Exemplo n.º 5
0
    def inference_ops(self,
                      x,
                      is_training=True,
                      for_imagenet=True,
                      spatial_squeeze=True,
                      truncate_at=None,
                      reuse=False):
        """Build a ResNet-152-style graph (3/8/36/3 bottleneck units).

        Args:
            x: image tensor of rank 3 (HWC) or 4 (NHWC); a rank-3 input
                is expanded to a singleton batch.
            is_training: training flag forwarded to batch-norm layers.
            for_imagenet: when True, append the 1000-way logits layer.
            spatial_squeeze: squeeze the 1x1 spatial dims after pool5.
            truncate_at: optional end-point name; when a layer with this
                name is built, return early with the graph so far.
            reuse: variable-reuse flag forwarded to every layer.

        Returns:
            (x, ends): the final (or truncated) tensor and an
            OrderedDict mapping end-point names to tensors.
        """
        inp_dims = x.get_shape()
        assert inp_dims.ndims in (3, 4)
        if inp_dims.ndims == 3: x = tf.expand_dims(x, 0)

        ends = OrderedDict()

        # Scale 1: 7x7/2 stem convolution.
        name = 'conv1'
        ends[name] = x = tfw.conv_2d(x,
                                     64,
                                     7,
                                     2,
                                     use_bias=False,
                                     use_batch_norm=True,
                                     activation_fn=tf.nn.relu,
                                     is_training=is_training,
                                     reuse=reuse,
                                     name=name)
        if truncate_at == name: return x, ends

        name = 'pool1'
        ends[name] = x = tfw.max_pool2d(x, 3, 2, padding='SAME', name=name)
        if truncate_at == name: return x, ends

        # Scale 2: res2a-res2c; only res2a projects the shortcut (256).
        for i, c in enumerate(string.ascii_lowercase[:3]):
            name = 'res2' + c
            ends[name] = x = self.block(x,
                                        is_training,
                                        64,
                                        64,
                                        256,
                                        None if i > 0 else 256,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 3: res3a downsamples and projects; res3b1-res3b7 follow.
        name = 'res3a'
        ends[name] = x = self.block(x,
                                    is_training,
                                    128,
                                    128,
                                    512,
                                    512,
                                    downsample=True,
                                    reuse=reuse,
                                    name=name)
        if truncate_at == name: return x, ends

        for i in range(7):
            name = 'res3b%d' % (i + 1)
            ends[name] = x = self.block(x,
                                        is_training,
                                        128,
                                        128,
                                        512,
                                        None,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 4: res4a downsamples and projects; res4b1-res4b35 follow.
        name = 'res4a'
        ends[name] = x = self.block(x,
                                    is_training,
                                    256,
                                    256,
                                    1024,
                                    1024,
                                    downsample=True,
                                    reuse=reuse,
                                    name=name)
        if truncate_at == name: return x, ends

        for i in range(35):
            name = 'res4b%d' % (i + 1)
            ends[name] = x = self.block(x,
                                        is_training,
                                        256,
                                        256,
                                        1024,
                                        None,
                                        downsample=False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Scale 5: res5a downsamples and projects; res5b-res5c follow.
        for i, c in enumerate(string.ascii_lowercase[:3]):
            name = 'res5' + c
            ends[name] = x = self.block(x,
                                        is_training,
                                        512,
                                        512,
                                        2048,
                                        2048 if i == 0 else None,
                                        downsample=True if i == 0 else False,
                                        reuse=reuse,
                                        name=name)
            if truncate_at == name: return x, ends

        # Global 7x7 average pool; optionally drop the 1x1 spatial dims.
        name = 'pool5'
        x = tfw.avg_pool2d(x, 7, 1, 'VALID', name=name)
        # Compute the static shape once; 'axis' replaces the deprecated
        # 'squeeze_dims' keyword of tf.squeeze (identical behavior).
        sp = x.get_shape().as_list()
        if spatial_squeeze and sp[1] == sp[2] == 1:
            x = tf.squeeze(x, axis=[1, 2])
        ends[name] = x
        if truncate_at == name: return x, ends

        # Logits
        if for_imagenet:
            name = 'fc1000'
            ends['logits'] = x = tfw.fully_connected(x,
                                                     1000,
                                                     activation_fn=None,
                                                     is_training=is_training,
                                                     reuse=reuse,
                                                     name=name)

        return x, ends
Exemplo n.º 6
0
    def inference_ops(self,
                      x,
                      is_training=True,
                      spatial_squeeze=True,
                      truncate_at=None,
                      reuse=False):
        """Build a small ResNet: stem conv + four two-unit stages + fc.

        Returns (x, ends) where ``ends`` is an OrderedDict of named
        end-points; when ``truncate_at`` matches an end-point name the
        partially built graph is returned at that point.
        """
        # filters = [128, 128, 256, 512, 1024]
        filters = [64, 64, 128, 256, 512]
        kernels = [7, 3, 3, 3, 3]
        strides = [2, 0, 2, 2, 2]

        log_fmt = ' * {:15s} | {:20s} | {:10s}'
        ends = OrderedDict()

        # Stem: 7x7/2 conv followed by a 3x3/2 max pool.
        with tf.variable_scope('conv1'):
            name = 'conv'
            ends[name] = x = tfw.conv_2d(x,
                                         filters[0],
                                         kernels[0],
                                         strides[0],
                                         use_bias=False,
                                         use_batch_norm=True,
                                         activation_fn=tf.nn.relu,
                                         is_training=is_training,
                                         trainable=is_training,
                                         reuse=reuse,
                                         name=name)
            x = tf.nn.max_pool(x, [1, 3, 3, 1], [1, 2, 2, 1], 'SAME')
            print(log_fmt.format(name, str(x.get_shape()), str(x.dtype)))
            if truncate_at is not None and truncate_at == 'conv1':
                return x, ends

        # Residual stages conv2_x .. conv5_x, two units each. The first
        # unit of stages 3-5 changes depth/stride via
        # _residual_block_first; every other unit preserves shape.
        for stage in range(2, 6):
            for unit in (1, 2):
                name = 'conv{}_{}'.format(stage, unit)
                if unit == 1 and stage > 2:
                    x = self._residual_block_first(x,
                                                   filters[stage - 1],
                                                   strides[stage - 1],
                                                   is_training,
                                                   reuse,
                                                   name=name)
                else:
                    x = self._residual_block(x, is_training, reuse,
                                             name=name)
                ends[name] = x
                print(log_fmt.format(name, str(x.get_shape()),
                                     str(x.dtype)))
                if truncate_at is not None and truncate_at == name:
                    return x, ends

        # Head: global average pool + 1000-way fully connected layer.
        with tf.variable_scope('logits'):
            x = tf.reduce_mean(x, [1, 2])
            name = 'fc'
            ends[name] = x = tfw.fully_connected(x, 1000, name=name)
            print(log_fmt.format(name, str(x.get_shape()), str(x.dtype)))
        return x, ends