def build_res_blocks(self, inputs, out_filters, block_res_num, first_blocks):
    """
    Build a stack of ``block_res_num`` blocks on top of ``inputs``

    Parameters
    ----------
    inputs: Input tensor
    out_filters: Number of output filters of every block in the stack
    block_res_num: Number of blocks to stack, must be >= 1
    first_blocks: Whether this is the first stack. The first stack is built
        from plain ``conv_bn_relu`` layers instead of residual blocks

    Returns
    -------
    The output of the last block in the stack

    Raises
    ------
    ValueError: If ``block_res_num`` is smaller than 1
    """
    if block_res_num < 1:
        # The message names the parameter that is actually validated here.
        raise ValueError('block_res_num must be >= 1')

    outputs = inputs
    if first_blocks:
        # Plain conv layers: no shortcut branch for the first stack.
        for _ in range(block_res_num):
            outputs = conv_bn_relu(outputs, out_filters, 3, 1, 1,
                                   use_bn=self.use_bn, training=self.training)
    else:
        # Read once; the same shrink factor is used for every residual block.
        shrink_factor = self.config.get('shrink_factor')
        for _ in range(block_res_num):
            outputs = self.build_one_res_block(outputs, out_filters, shrink_factor)
    return outputs
def build_dense_blocks(self, inputs, out_filters, block_dense_num, growth_rate, first_blocks):
    """
    Build a stack of ``block_dense_num`` dense blocks on top of ``inputs``

    Parameters
    ----------
    inputs: Input tensor
    out_filters: Number of output filters of the stack
    block_dense_num: Number of blocks to stack, must be >= 1
    growth_rate: Value forwarded to ``build_one_dense_block`` for each dense
        block (documented by the original author as the channel growth per
        dense block)
    first_blocks: Whether this is the first stack (before the first pooling
        layer). The first stack is built from plain ``conv_bn_relu`` layers
        instead of dense blocks

    Returns
    -------
    The output tensor of the stack

    Raises
    ------
    ValueError: If ``block_dense_num`` is smaller than 1
    """
    if block_dense_num < 1:
        raise ValueError('block_dense_num must be >= 1')

    tensor = inputs
    if not first_blocks:
        for _ in range(block_dense_num):
            tensor = self.build_one_dense_block(tensor, growth_rate)
        # 1x1 conv_block applied after the dense blocks with out_filters
        # filters (presumably to bring the channel count to out_filters).
        return self.conv_block(tensor, out_filters, ksize=1)

    # First stack: plain conv layers only.
    for _ in range(block_dense_num):
        tensor = conv_bn_relu(tensor, out_filters, 3, 1, 1,
                              use_bn=self.use_bn, training=self.training)
    return tensor
def forward(self, inputs):
    """
    Forward process: encoder stages, then decoder stages, then the output head

    Parameters
    ----------
    inputs: Input tensor

    Returns
    -------
    The result of ``softmax`` applied to the final ``conv`` layer, whose
    channel count is read from ``config['output_channel']``
    """
    # ---- encoder ----
    # Every stage except the last one stores its output for the decoder and
    # is followed by a max pooling with pool_size=2.
    last_stage = len(self.block_filters) - 1
    skips = []
    tensor = inputs
    for stage, stage_filters in enumerate(self.block_filters):
        tensor = self.build_res_blocks(tensor, stage_filters,
                                       self.block_res_nums[stage], stage == 0)
        if stage == last_stage:
            continue
        skips.append(tensor)
        tensor = maxpool(tensor, pool_size=2)

    # ---- decoder ----
    # Walk the stored encoder outputs from the deepest back to the first one;
    # `stage` is the index of the encoder stage being mirrored.
    for stage in range(len(skips) - 1, -1, -1):
        skip = skips[stage]
        # Use the channel count of the stored encoder output for this stage.
        skip_filters = int(skip.get_shape()[-1])
        tensor = conv_bn_relu(tensor, skip_filters, 3, 1, 1,
                              use_bn=self.use_bn, training=self.training)
        tensor = upsample(tensor, 2, self.interpolation_type)
        tensor = concat([skip, tensor])
        tensor = self.build_res_blocks(tensor, skip_filters,
                                       self.block_res_nums[stage], stage == 0)

    # ---- output head ----
    projected = conv(tensor, self.config.get('output_channel'))
    return softmax(projected)
def build_blocks(self, inputs, out_filters, block_conv_num):
    """
    Build a plain conv block: ``block_conv_num`` stacked ``conv_bn_relu`` layers

    Parameters
    ----------
    inputs: Input tensor
    out_filters: Number of output filters of every conv layer
    block_conv_num: Number of conv layers in the block, must be >= 1

    Returns
    -------
    The output of the last conv layer

    Raises
    ------
    ValueError: If ``block_conv_num`` is smaller than 1
    """
    # Batch-norm switch comes from the config here (not from self.use_bn).
    bn_enabled = self.config.get('use_bn')

    if block_conv_num < 1:
        raise ValueError('block_conv_num must be >= 1')

    tensor = inputs
    for _layer in range(block_conv_num):
        tensor = conv_bn_relu(
            tensor, out_filters, 3, 1, 1,
            use_bn=bn_enabled,
            training=self.training,
        )
    return tensor
def build_one_res_block(self, inputs, out_filters, shrink_factor):
    """
    Build one residual conv block: relu(shortcut branch + conv branch)

    Parameters
    ----------
    inputs: Input tensor
    out_filters: Number of output filters
    shrink_factor: Divisor applied to out_filters to get the filter count of
        the intermediate conv layers, i.e. int(out_filters / shrink_factor)

    Returns
    -------
    The output tensor of the residual block
    """
    bn_enabled = self.config.get('use_bn')
    input_channels = int(inputs.get_shape()[-1])
    mid_channels = int(out_filters / shrink_factor)

    # Shortcut branch: pass the input through unchanged when the channel
    # count already matches, otherwise apply a conv without relu to get
    # out_filters channels.
    shortcut = inputs
    if input_channels != out_filters:
        shortcut = conv_bn_relu(inputs, out_filters, 1, 1, 1,
                                use_bn=bn_enabled, use_relu=False,
                                training=self.training)

    # Conv branch: the third positional argument of conv_bn_relu
    # (presumably the kernel size) for the intermediate layers, followed by
    # the one for the closing layer. Bottleneck blocks use 1 -> 3 -> 1,
    # regular blocks use 3 -> 3.
    if self.config.get('bottleneck'):
        hidden_ksizes, closing_ksize = (1, 3), 1
    else:
        hidden_ksizes, closing_ksize = (3,), 3

    residual = inputs
    for ksize in hidden_ksizes:
        residual = conv_bn_relu(residual, mid_channels, ksize, 1, 1,
                                use_bn=bn_enabled, training=self.training)
    # The closing layer skips relu; activation is applied after the addition.
    residual = conv_bn_relu(residual, out_filters, closing_ksize, 1, 1,
                            use_bn=bn_enabled, use_relu=False,
                            training=self.training)

    return relu(add([shortcut, residual]))