コード例 #1
0
    def get_output_shape_for(self, input_shape):
        """Return the output shape of this 2D convolution for ``input_shape``.

        Index 0 of ``input_shape`` is taken as the batch size and indices 2
        and 3 as the row/column extents (index 1 is not read). Each spatial
        size is obtained from ``conv_output_length`` in ``'pad'`` border mode
        using this layer's per-axis ``filter_size``, ``stride`` and ``pad``.

        Returns
        -------
        tuple
            ``(batch_size, num_filters, output_rows, output_columns)``.
        """
        batch = input_shape[0]
        # axis 0 -> rows (input_shape[2]), axis 1 -> columns (input_shape[3])
        rows, columns = (
            conv_output_length(input_shape[2 + axis],
                               self.filter_size[axis],
                               self.stride[axis],
                               'pad', self.pad[axis])
            for axis in range(2)
        )
        return (batch, self.num_filters, rows, columns)
コード例 #2
0
 def get_output_shape_for(self, input_shape):
     """Return ``(batch, num_filters, *spatial)`` for an N-D convolution.

     A non-tuple ``self.pad`` is repeated once per spatial axis
     (``self.n`` times); a tuple ``self.pad`` is used per-axis as given.
     Spatial sizes come from ``conv_output_length`` applied to
     ``input_shape[2:]`` zipped with ``filter_size``, ``stride`` and pad.
     """
     if isinstance(self.pad, tuple):
         pads = self.pad
     else:
         pads = (self.pad, ) * self.n
     head = (input_shape[0], self.num_filters)
     spatial = tuple(
         conv_output_length(length, size, step, amount)
         for length, size, step, amount in zip(
             input_shape[2:], self.filter_size, self.stride, pads))
     return head + spatial
コード例 #3
0
ファイル: conv_recurrent.py プロジェクト: Seb-Leb/Tars
 def get_output_shape_for(self, input_shapes):
     """Return the output shape, derived from the third incoming shape.

     Parameters
     ----------
     input_shapes : sequence of tuples
         Shapes of all incoming layers. Only ``input_shapes[2]`` is read
         (called ``input_shape_h`` here -- presumably the hidden-state
         input; confirm against the layer's constructor).

     Returns
     -------
     tuple
         ``(batch, num_filters, *spatial)`` where each spatial size is
         ``conv_output_length`` of the matching axis of ``input_shapes[2]``.
     """
     # A non-tuple pad is repeated for each of the ``self.n`` spatial axes.
     pad = self.pad if isinstance(self.pad, tuple) else (self.pad,) * self.n
     # Read the shapes passed in, not ``self.input_shapes``: callers may
     # request the output shape for shapes other than those recorded at
     # construction time, and those must not be silently ignored.
     input_shape_h = input_shapes[2]
     batchsize = input_shape_h[0]
     spatial = tuple(
         conv_output_length(length, size, step, p)
         for length, size, step, p in zip(
             input_shape_h[2:], self.filter_size, self.stride, pad))
     return (batchsize, self.num_filters) + spatial
コード例 #4
0
    def get_output_shape_for(self, input_shape):
        """Return ``(batch, n_ch, n_fft // 2 + 1, length)``.

        ``length`` is ``conv_output_length`` of ``input_shape[2]`` with this
        layer's first ``filter_size``/``stride`` entries and zero padding.
        Only the first spatial result is used; a second one (from
        ``input_shape[3]``, if present) is computed but discarded, as in the
        original implementation.

        Raises
        ------
        ValueError
            If ``self.n_ch`` is neither 1 nor 2.
        IndexError
            If ``input_shape`` has no spatial axis (fewer than 3 entries).
        """
        pad = (0, 0)
        batchsize = input_shape[0]
        # Floor division keeps this an integer dimension; ``/`` produces a
        # float (e.g. 257.0, or 257.5 for odd n_fft) under Python 3.
        # Presumably the count of non-negative FFT bins -- confirm.
        n_bins = self.n_fft // 2 + 1
        spatial = tuple(
            conv_output_length(length, size, step, p)
            for length, size, step, p in zip(
                input_shape[2:], self.filter_size, self.stride, pad))

        if self.n_ch == 2:
            n_channels = 2
        elif self.n_ch == 1:
            n_channels = 1
        else:
            # Previously fell through and returned None, which only failed
            # later with a confusing error.
            raise ValueError(
                "n_ch must be 1 or 2, got {!r}".format(self.n_ch))
        return (batchsize, n_channels, n_bins, spatial[0])
コード例 #5
0
ファイル: test_conv.py プロジェクト: colinfang/Lasagne
def test_conv_output_length():
    """Check ``conv_output_length`` for each pad spec and for an invalid one.

    All cases use input length 13, filter size 5 and stride 3.
    """
    from lasagne.layers.conv import conv_output_length

    # (pad spec, expected output length)
    cases = (
        ('valid', 3),
        (0, 3),
        ('full', 6),
        ('same', 5),
        (2, 5),
    )
    for pad, expected in cases:
        assert conv_output_length(13, 5, 3, pad) == expected

    # An unknown pad string must be rejected with a descriptive ValueError.
    with pytest.raises(ValueError) as excinfo:
        conv_output_length(13, 5, 3, '_nonexistent_mode')
    assert "Invalid pad: " in excinfo.value.args[0]
コード例 #6
0
ファイル: test_conv.py プロジェクト: sveitser/Lasagne
def test_conv_output_length():
    """Verify ``conv_output_length`` results and its invalid-pad error.

    Input length 13, filter size 5 and stride 3 are fixed for every case.
    """
    from lasagne.layers.conv import conv_output_length

    length, filter_size, stride = 13, 5, 3
    expectations = [
        ('valid', 3),
        (0, 3),
        ('full', 6),
        ('same', 5),
        (2, 5),
    ]
    for pad_spec, want in expectations:
        got = conv_output_length(length, filter_size, stride, pad_spec)
        assert got == want

    with pytest.raises(ValueError) as raised:
        conv_output_length(length, filter_size, stride, '_nonexistent_mode')
    message = raised.value.args[0]
    assert "Invalid pad: " in message
コード例 #7
0
ファイル: test_conv.py プロジェクト: JackKelly/nntools
 def test_invalid_border_mode(self):
     """An unrecognised ``border_mode`` must raise ``RuntimeError``."""
     from lasagne.layers.conv import conv_output_length
     bad_mode = '_nonexistent_mode'
     with pytest.raises(RuntimeError) as excinfo:
         conv_output_length(5, 3, 1, border_mode=bad_mode)
     message = excinfo.value.args[0]
     assert "Invalid border mode" in message
コード例 #8
0
 def test_invalid_border_mode(self):
     """``conv_output_length`` rejects an unknown border mode with RuntimeError."""
     from lasagne.layers.conv import conv_output_length
     with pytest.raises(RuntimeError) as raised:
         conv_output_length(5, 3, 1, border_mode='_nonexistent_mode')
     first_arg = raised.value.args[0]
     assert "Invalid border mode" in first_arg
コード例 #9
0
    def get_output_shape_for(self, input_shape):
        """Return ``(batch, num_output_channels, length)`` for a 1D convolution.

        ``length`` is ``conv_output_length`` of ``input_shape[2]`` under this
        layer's ``filter_size``, ``stride`` and ``border_mode``; the batch
        size is passed through from ``input_shape[0]``.
        """
        length = conv_output_length(
            input_shape[2], self.filter_size, self.stride, self.border_mode)
        batch = input_shape[0]
        return batch, self.num_output_channels, length