Example no. 1
0
    def _build_graph(self, inputs):
        """Build the SE-ResNet training graph: preprocess the input,
        run the backbone with SE bottleneck blocks, and set ``self.cost``
        to classification loss plus L2 weight decay."""
        image, label = inputs
        # BGR preprocessing, then NHWC -> NCHW for the channels-first backbone.
        image = tf.transpose(image_preprocess(image, bgr=True), [0, 3, 1, 2])

        def bottleneck_se(l, ch_out, stride, preact):
            # Pre-activation bottleneck block with a squeeze-and-excitation gate.
            l, shortcut = apply_preactivation(l, preact)
            l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
            l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
            l = Conv2D('conv3', l, ch_out * 4, 1)

            # Squeeze: global pooling; excitation: bottlenecked FC pair
            # ending in a sigmoid, giving one gate value per channel.
            gate = GlobalAvgPooling('gap', l)
            gate = FullyConnected('fc1', gate, ch_out // 4, nl=tf.identity)
            gate = FullyConnected('fc2', gate, ch_out * 4, nl=tf.nn.sigmoid)
            # Broadcast the per-channel gate over the spatial dims (NCHW).
            l = l * tf.reshape(gate, [-1, ch_out * 4, 1, 1])
            return l + resnet_shortcut(shortcut, ch_out * 4, stride)

        defs = RESNET_CONFIG[DEPTH]
        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                      data_format='NCHW'):
            logits = resnet_backbone(image, defs, bottleneck_se)

        loss = compute_loss_and_error(logits, label)
        wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4),
                                  name='l2_regularize_loss')
        add_moving_summary(loss, wd_loss)
        self.cost = tf.add_n([loss, wd_loss], name='cost')
Example no. 2
0
    def _build_graph(self, inputs):
        """Build a pre-activation ResNet (depth 18 or 34) training graph
        and set ``self.cost`` to classification loss plus L2 weight decay.

        Fix: removed a leftover debug ``print(convmaps)`` that spammed
        stdout with the tensor repr during graph construction.
        """
        image, label = inputs
        image = image_preprocess(image, bgr=True)
        # NHWC -> NCHW for the channels-first convolutions below.
        image = tf.transpose(image, [0, 3, 1, 2])

        # blocks-per-group and block function for each supported depth
        cfg = {
            18: ([2, 2, 2, 2], preresnet_basicblock),
            34: ([3, 4, 6, 3], preresnet_basicblock),
        }
        defs, block_func = cfg[DEPTH]

        with argscope(Conv2D, nl=tf.identity, use_bias=False,
                      W_init=variance_scaling_initializer(mode='FAN_OUT')), \
                argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
            convmaps = (LinearWrap(image)
                        .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
                        .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
                        .apply(preresnet_group, 'group0', block_func, 64, defs[0], 1)
                        .apply(preresnet_group, 'group1', block_func, 128, defs[1], 2)
                        .apply(preresnet_group, 'group2', block_func, 256, defs[2], 2)
                        .apply(preresnet_group, 'group3new', block_func, 512, defs[3], 1)())
            logits = (LinearWrap(convmaps)
                      .GlobalAvgPooling('gap')
                      .FullyConnected('linearnew', 1000, nl=tf.identity)())

        loss = compute_loss_and_error(logits, label)
        wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
        add_moving_summary(loss, wd_cost)
        self.cost = tf.add_n([loss, wd_cost], name='cost')
Example no. 3
0
    def _build_graph(self, inputs):
        """Assemble the ResNet graph — preprocessing, backbone, losses —
        and set ``self.cost`` to classification loss plus L2 weight decay."""
        image, label = inputs
        image = image_preprocess(image, bgr=True)
        if self.data_format == 'NCHW':
            # Backbone expects channels-first when configured for NCHW.
            image = tf.transpose(image, [0, 3, 1, 2])

        defs, block_func = RESNET_CONFIG[DEPTH]
        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                      data_format=self.data_format):
            logits = resnet_backbone(image, defs, block_func)

        loss = compute_loss_and_error(logits, label)
        wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4),
                                  name='l2_regularize_loss')
        add_moving_summary(loss, wd_loss)
        self.cost = tf.add_n([loss, wd_loss], name='cost')
Example no. 4
0
    def _build_graph(self, inputs):
        """Assemble the ResNet graph — preprocessing, backbone, losses —
        and set ``self.cost`` to classification loss plus L2 weight decay."""
        image, label = inputs
        # Deliberately kept as bgr=False: correcting it to bgr=True would
        # break compatibility with existing pretrained checkpoints.
        image = image_preprocess(image, bgr=False)
        if self.data_format == 'NCHW':
            # Backbone expects channels-first when configured for NCHW.
            image = tf.transpose(image, [0, 3, 1, 2])

        defs, block_func = RESNET_CONFIG[DEPTH]
        with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                      data_format=self.data_format):
            logits = resnet_backbone(image, defs, block_func)

        loss = compute_loss_and_error(logits, label)
        wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4),
                                  name='l2_regularize_loss')
        add_moving_summary(loss, wd_loss)
        self.cost = tf.add_n([loss, wd_loss], name='cost')