def conv_block(self, inputs, out_filters, ksize):
    """Apply a pre-activation convolution block (BN -> ReLU -> Conv).

    Batch normalization is applied only when ``self.use_bn`` is set; in
    that case the convolution omits its bias (BN's beta subsumes it).
    ReLU and the convolution are always applied.

    Parameters
    ----------
    inputs: Input tensor
    out_filters: Number of output filters
    ksize: Kernel size. One integer or tuple of two integers
    """
    # Normalize first when BN is enabled, otherwise pass the input through.
    normalized = (
        batch_normalization(inputs, training=self.training)
        if self.use_bn
        else inputs
    )
    # Bias is redundant under BN, so enable it only when BN is off.
    return conv(relu(normalized), out_filters, ksize=ksize, use_bias=not self.use_bn)
def forward(self, inputs):
    """Run the encoder-decoder forward pass.

    The encoder applies residual blocks per level, stashing each level's
    output as a skip connection and downsampling (max-pool) between
    levels; the deepest level is not stashed or pooled. The decoder
    walks the skips in reverse: project, upsample, concatenate with the
    skip, then apply residual blocks. A final 1-channel-mapping conv and
    softmax produce the output.

    Parameters
    ----------
    inputs: Input tensor
    """
    x = inputs
    skips = []
    deepest = len(self.block_filters) - 1

    # --- encoder ---
    for idx, n_filters in enumerate(self.block_filters):
        x = self.build_res_blocks(x, n_filters, self.block_res_nums[idx], idx == 0)
        if idx < deepest:
            # Keep this resolution for the decoder, then downsample.
            skips.append(x)
            x = maxpool(x, pool_size=2)

    # --- decoder (deepest skip first) ---
    for idx in reversed(range(len(skips))):
        skip = skips[idx]
        # Match the skip's channel count before merging.
        n_filters = int(skip.get_shape()[-1])
        x = conv_bn_relu(x, n_filters, 3, 1, 1, use_bn=self.use_bn, training=self.training)
        x = upsample(x, 2, self.interpolation_type)
        x = concat([skip, x])
        x = self.build_res_blocks(x, n_filters, self.block_res_nums[idx], idx == 0)

    # --- output head ---
    x = conv(x, self.config.get('output_channel'))
    return softmax(x)