Example 1
    def _build_graph(self, inputs):
        image, label = inputs
        image = image / 255.0

        fw, fa, fg = get_dorefa(BITW, BITA, BITG)

        old_get_variable = tf.get_variable

        # monkey-patch tf.get_variable to apply fw
        def new_get_variable(name, shape=None, **kwargs):
            v = old_get_variable(name, shape, **kwargs)
            # don't binarize first and last layer
            if name != 'W' or 'conv0' in v.op.name or 'fct' in v.op.name:
                return v
            else:
                logger.info("Binarizing weight {}".format(v.op.name))
                return fw(v)

        def nonlin(x):
            if BITA == 32:
                return tf.nn.relu(x)  # still use relu for 32bit cases
            return tf.clip_by_value(x, 0.0, 1.0)

        def activate(x):
            return fa(nonlin(x))

        with replace_get_variable(new_get_variable), \
                argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
                argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
            logits = (LinearWrap(image)
                      .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
                      .apply(activate)
                      .Conv2D('conv1', 256, 5, padding='SAME', split=2)
                      .apply(fg)
                      .BatchNorm('bn1')
                      .MaxPooling('pool1', 3, 2, padding='SAME')
                      .apply(activate)
                      .Conv2D('conv2', 384, 3)
                      .apply(fg)
                      .BatchNorm('bn2')
                      .MaxPooling('pool2', 3, 2, padding='SAME')
                      .apply(activate)
                      .Conv2D('conv3', 384, 3, split=2)
                      .apply(fg)
                      .BatchNorm('bn3')
                      .apply(activate)
                      .Conv2D('conv4', 256, 3, split=2)
                      .apply(fg)
                      .BatchNorm('bn4')
                      .MaxPooling('pool4', 3, 2, padding='VALID')
                      .apply(activate)
                      .FullyConnected('fc0', 4096)
                      .apply(fg)
                      .BatchNorm('bnfc0')
                      .apply(activate)
                      .FullyConnected('fc1', 4096)
                      .apply(fg)
                      .BatchNorm('bnfc1')
                      .apply(nonlin)
                      .FullyConnected('fct', 1000, use_bias=True)())

        prob = tf.nn.softmax(logits, name='output')

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))

        # weight decay on all W of fc layers
        wd_cost = regularize_cost('fc.*/W',
                                  l2_regularizer(5e-6),
                                  name='regularize_cost')

        add_param_summary(('.*/W', ['histogram', 'rms']))
        self.cost = tf.add_n([cost, wd_cost], name='cost')
        add_moving_summary(cost, wd_cost, self.cost)
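
This example (like the two that follow) gets its quantizers from get_dorefa(BITW, BITA, BITG): fw quantizes weights, fa quantizes activations, fg quantizes gradients. As a rough idea of what the weight and activation quantizers compute, here is a minimal sketch of a k-bit uniform quantizer with a straight-through gradient estimator in the DoReFa style. The function names and structure below are illustrative assumptions, not the actual tensorpack implementation, and gradient quantization (fg) is omitted.

# Illustrative sketch only -- assumed behavior of the fw/fa quantizers, not the
# actual get_dorefa() source. Uses the TF 1.x API seen in the examples above.
import tensorflow as tf

def quantize_k(x, k):
    # uniform k-bit quantization of x in [0, 1], with a straight-through
    # estimator: forward pass uses the rounded value, backward pass is identity
    n = float(2 ** k - 1)
    xq = tf.round(x * n) / n
    return x + tf.stop_gradient(xq - x)

def make_fw(bitw):
    # weight quantizer: squash weights with tanh, rescale to [0, 1],
    # quantize, then map back to [-1, 1]
    def fw(w):
        if bitw == 32:
            return w
        w = tf.tanh(w)
        w = w / (2.0 * tf.reduce_max(tf.abs(w))) + 0.5
        return 2.0 * quantize_k(w, bitw) - 1.0
    return fw

def make_fa(bita):
    # activation quantizer: assumes the input was already clipped to [0, 1],
    # which is what nonlin/cabs do before fa is applied
    def fa(x):
        if bita == 32:
            return x
        return quantize_k(x, bita)
    return fa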
Example 2
    def _build_graph(self, inputs):
        image, label = inputs
        is_training = get_current_tower_context().is_training

        fw, fa, fg = get_dorefa(BITW, BITA, BITG)

        old_get_variable = tf.get_variable

        # monkey-patch tf.get_variable to apply fw
        def new_get_variable(name, shape=None, **kwargs):
            v = old_get_variable(name, shape, **kwargs)
            # don't binarize first and last layer
            if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name:
                return v
            else:
                logger.info("Binarizing weight {}".format(v.op.name))
                return fw(v)

        def cabs(x):
            return tf.minimum(1.0, tf.abs(x), name='cabs')

        def activate(x):
            return fa(cabs(x))

        image = image / 256.0

        with replace_get_variable(new_get_variable), \
                argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
                argscope(Conv2D, use_bias=False, nl=tf.identity):
            logits = (LinearWrap(image)
                      .Conv2D('conv0', 48, 5, padding='VALID', use_bias=True)
                      .MaxPooling('pool0', 2, padding='SAME')
                      .apply(activate)
                      # 18
                      .Conv2D('conv1', 64, 3, padding='SAME')
                      .apply(fg)
                      .BatchNorm('bn1')
                      .apply(activate)
                      .Conv2D('conv2', 64, 3, padding='SAME')
                      .apply(fg)
                      .BatchNorm('bn2')
                      .MaxPooling('pool1', 2, padding='SAME')
                      .apply(activate)
                      # 9
                      .Conv2D('conv3', 128, 3, padding='VALID')
                      .apply(fg)
                      .BatchNorm('bn3')
                      .apply(activate)
                      # 7
                      .Conv2D('conv4', 128, 3, padding='SAME')
                      .apply(fg)
                      .BatchNorm('bn4')
                      .apply(activate)
                      .Conv2D('conv5', 128, 3, padding='VALID')
                      .apply(fg)
                      .BatchNorm('bn5')
                      .apply(activate)
                      # 5
                      .tf.nn.dropout(0.5 if is_training else 1.0)
                      .Conv2D('conv6', 512, 5, padding='VALID')
                      .apply(fg)
                      .BatchNorm('bn6')
                      .apply(cabs)
                      .FullyConnected('fc1', 10, nl=tf.identity)())
        prob = tf.nn.softmax(logits, name='output')

        # compute the number of failed samples
        wrong = prediction_incorrect(logits, label)
        # monitor training error
        add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
        # weight decay on all W of fc layers
        wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))

        add_param_summary(('.*/W', ['histogram', 'rms']))
        self.cost = tf.add_n([cost, wd_cost], name='cost')
        add_moving_summary(cost, wd_cost, self.cost)
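
Both image examples feed logits and labels to prediction_incorrect to get a per-sample 0/1 error indicator, which is then averaged into the error summaries. A plausible definition, written here as an assumption rather than the library's exact code, is:

import tensorflow as tf

def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
    # 1.0 where the true label is NOT among the top-k predictions, else 0.0
    correct = tf.nn.in_top_k(logits, label, topk)
    return tf.cast(tf.logical_not(correct), tf.float32, name=name)

Taking tf.reduce_mean over this vector gives the top-k error rate that add_moving_summary tracks.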
Example 3
    def _build_graph(self, input_vars):
        image, label = input_vars
        image = image / 256.0

        fw, fa, fg = get_dorefa(BITW, BITA, BITG)
        old_get_variable = tf.get_variable

        def new_get_variable(name, shape=None, **kwargs):
            v = old_get_variable(name, shape, **kwargs)
            # don't binarize first and last layer
            if name != 'W' or 'conv1' in v.op.name or 'fct' in v.op.name:
                return v
            else:
                logger.info("Binarizing weight {}".format(v.op.name))
                return fw(v)

        def nonlin(x):
            return tf.clip_by_value(x, 0.0, 1.0)

        def activate(x):
            return fa(nonlin(x))

        def resblock(x, channel, stride):
            def get_stem_full(x):
                return (LinearWrap(x).Conv2D(
                    'c3x3a', channel,
                    3).BatchNorm('stembn').apply(activate).Conv2D(
                        'c3x3b', channel, 3)())

            channel_mismatch = channel != x.get_shape().as_list()[3]
            if stride != 1 or channel_mismatch or 'pool1' in x.name:
                # handling pool1 is to work around an architecture bug in our model
                if stride != 1 or 'pool1' in x.name:
                    x = AvgPooling('pool', x, stride, stride)
                x = BatchNorm('bn', x)
                x = activate(x)
                shortcut = Conv2D('shortcut', x, channel, 1)
                stem = get_stem_full(x)
            else:
                shortcut = x
                x = BatchNorm('bn', x)
                x = activate(x)
                stem = get_stem_full(x)
            return shortcut + stem

        def group(x, name, channel, nr_block, stride):
            with tf.variable_scope(name + 'blk1'):
                x = resblock(x, channel, stride)
            for i in range(2, nr_block + 1):
                with tf.variable_scope(name + 'blk{}'.format(i)):
                    x = resblock(x, channel, 1)
            return x

        with replace_get_variable(new_get_variable), \
                argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
                argscope(Conv2D, use_bias=False, nl=tf.identity):
            logits = (LinearWrap(image)
                      # use explicit padding here, because our training framework has
                      # different padding mechanisms from TensorFlow
                      .tf.pad([[0, 0], [3, 2], [3, 2], [0, 0]])
                      .Conv2D('conv1', 64, 7, stride=2, padding='VALID', use_bias=True)
                      .tf.pad([[0, 0], [1, 1], [1, 1], [0, 0]], 'SYMMETRIC')
                      .MaxPooling('pool1', 3, 2, padding='VALID')
                      .apply(group, 'conv2', 64, 2, 1)
                      .apply(group, 'conv3', 128, 2, 2)
                      .apply(group, 'conv4', 256, 2, 2)
                      .apply(group, 'conv5', 512, 2, 2)
                      .BatchNorm('lastbn')
                      .apply(nonlin)
                      .GlobalAvgPooling('gap')
                      .tf.multiply(49)  # this is due to a bug in our model design
                      .FullyConnected('fct', 1000)())
        prob = tf.nn.softmax(logits, name='output')
        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
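
All three examples wrap graph construction in replace_get_variable(new_get_variable) so that every weight a layer creates through tf.get_variable is routed through the quantizing getter. A minimal sketch of such a context manager, assuming the TF 1.x API used above (the real tensorpack utility may patch additional entry points, e.g. variable_scope's copy of get_variable):

import contextlib
import tensorflow as tf

@contextlib.contextmanager
def replace_get_variable(new_getter):
    # temporarily swap tf.get_variable for the custom getter, and always
    # restore the original even if graph construction raises
    old_getter = tf.get_variable
    tf.get_variable = new_getter
    try:
        yield
    finally:
        tf.get_variable = old_getter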