예제 #1
0
    def conv_block(self, inputs, out_filters, ksize):
        """
        Pre-activated convolution block: (optional) BN -> ReLU -> Conv.

        Parameters
        ----------
        inputs: Input tensor
        out_filters: Number of output filters
        ksize: Kernel size. One integer or tuple of two integers

        Returns
        -------
        Result of ``conv`` with ``out_filters`` filters. Batch normalization
        is applied only when ``self.use_bn`` is truthy, and the convolution
        carries a bias only when batch normalization is disabled.
        """
        with_bn = self.use_bn
        pre_activated = (batch_normalization(inputs, training=self.training)
                         if with_bn else inputs)
        return conv(relu(pre_activated), out_filters,
                    ksize=ksize, use_bias=not with_bn)
예제 #2
0
    def forward(self, inputs):
        """
        Run the encoder-decoder network on ``inputs``.

        Parameters
        ----------
        inputs: Input tensor

        Returns
        -------
        ``softmax`` of a final ``conv`` whose filter count is
        ``self.config.get('output_channel')``.
        """
        last_level = len(self.block_filters) - 1
        skips = []
        net = inputs

        # Encoder: residual blocks at every level. Every level except the
        # deepest keeps its features as a skip connection, then is max-pooled.
        for level, n_filters in enumerate(self.block_filters):
            net = self.build_res_blocks(net, n_filters,
                                        self.block_res_nums[level],
                                        level == 0)
            if level == last_level:
                continue
            skips.append(net)
            net = maxpool(net, pool_size=2)

        # Decoder: consume skip connections from deepest to shallowest, reusing
        # the encoder's level index for res-block counts. The filter count is
        # taken from the skip tensor's last dimension.
        for level in reversed(range(len(skips))):
            skip = skips[level]
            n_filters = int(skip.get_shape()[-1])
            net = conv_bn_relu(net, n_filters, 3, 1, 1,
                               use_bn=self.use_bn,
                               training=self.training)
            net = upsample(net, 2, self.interpolation_type)
            net = concat([skip, net])
            # NOTE(review): level 0 passes first_blocks=True here, as in the
            # encoder's first level. Confirm build_res_blocks intends that for
            # the decoder too.
            net = self.build_res_blocks(net, n_filters,
                                        self.block_res_nums[level],
                                        level == 0)

        net = conv(net, self.config.get('output_channel'))
        return softmax(net)