Example #1
    def _build_graph(self, inputs):
        inp, label = inputs
        is_training = get_current_tower_context().is_training

        fw, fa = get_dorefa(self.bitw, self.bita)

        def binarize_weight(v):
            name = v.op.name
            # Skip anything that is not a W/b parameter and, unless
            # quant_ends is set, the first and last layers of the network.
            skip_ends = not self.quant_ends and (
                'conv0' in name or 'last_linear' in name or
                (self.net_fn in (fcn1_net, fcn2_net) and 'linear0' in name))
            if not (name.endswith('W') or name.endswith('b')) or skip_ends:
                logger.info("Not quantizing {}".format(name))
                return v
            logger.info("Quantizing weight {}".format(name))
            return fw(v)

        def nonlin(x, name="activate"):
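            # fa comes from get_dorefa above; when bita == 32 it is expected to
            # be the identity, so this branch reduces to BN followed by plain ReLU.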
            if self.bita == 32:
                return fa(tf.nn.relu(BNWithTrackedMults(x)))
            else:
                return fa(tf.clip_by_value(BNWithTrackedMults(x), 0.0, 1.0))

        with remap_variables(binarize_weight), \
                argscope([FullyConnectedWithTrackedMults], network_complexity=self.network_complexity), \
                argscope([Conv2DWithTrackedMults], network_complexity=self.network_complexity), \
                argscope([BNReLUWithTrackedMults], network_complexity=self.network_complexity), \
                argscope([BNWithTrackedMults], network_complexity=self.network_complexity), \
                argscope(BatchNorm, decay=0.9, epsilon=1e-4):
            l = self.net_fn(inp, nonlin, self.n_context)
            logits = FullyConnectedWithTrackedMults('last_linear', l, out_dim=self.n_spks, nl=tf.identity)

        prob = tf.nn.softmax(logits, name='output')

        # used for validation accuracy of utterance
        identity_guesses = flatten(tf.argmax(prob, axis=1))
        uniq_identities, _, count = tf.unique_with_counts(identity_guesses)
        idx_to_identity_with_most_votes = tf.argmax(count)
        chosen_identity = tf.gather(uniq_identities, idx_to_identity_with_most_votes)
        wrong = tf.expand_dims(tf.not_equal(chosen_identity, tf.cast(label[0], tf.int64)), axis=0, name='utt-wrong')

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
        add_moving_summary(cost)

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

        with tf.name_scope('original-weight-summaries'):
            add_param_summary(('.*/W', ['rms', 'histogram']))
            add_param_summary(('.*/b', ['rms', 'histogram']))

        with tf.name_scope('activation-summaries'):
            def fn(name):
                return (name.endswith('output') or name.endswith('output:0')) and "Inference" not in name and 'quantized' not in name
            tensors = get_tensors_from_graph(tf.get_default_graph(), fn) 
            logger.info("Adding activation tensors to summary: {}".format(tensors))
            for tensor in tensors:
                add_tensor_summary(tensor, ['rms', 'histogram'])

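        # decaying weight decay applied to all W kernels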
        wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(), 480000, 0.2, True)
        wd_cost = tf.multiply(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
        add_moving_summary(wd_cost)
        self.cost = tf.add_n([cost, wd_cost], name='cost')

        tf.constant([self.network_complexity['mults']], name='TotalMults')
        tf.constant([self.network_complexity['weights']], name='TotalWeights')
        logger.info("Parameter count: {}".format(self.network_complexity))
Example #2
    def _build_graph(self, inputs):
        inp, label = inputs

        with argscope([Conv2DWithTrackedMults, BatchNorm], data_format='NHWC'), \
                argscope([Conv2DWithTrackedMults], nl=tf.identity, \
                         W_init=variance_scaling_initializer(mode='FAN_OUT')), \
                argscope([DepthwiseSeparableConvWithTrackedMults, \
                            Conv2DWithTrackedMults, \
                            FullyConnectedWithTrackedMults], \
                            network_complexity=self.network_complexity):
            l = self.net_fn(inp, self.batchnorm, self.n_context)

        logits = FullyConnectedWithTrackedMults(
            'last_linear',
            l,
            out_dim=self.n_spks,
            nl=tf.identity,
            network_complexity=self.network_complexity)

        prob = tf.nn.softmax(logits, name='output')

        # used for validation accuracy of utterance
        identity_guesses = flatten(tf.argmax(prob, axis=1))
        uniq_identities, _, count = tf.unique_with_counts(identity_guesses)
        idx_to_identity_with_most_votes = tf.argmax(count)
        chosen_identity = tf.gather(uniq_identities,
                                    idx_to_identity_with_most_votes)
        wrong = tf.expand_dims(tf.not_equal(chosen_identity,
                                            tf.cast(label[0], tf.int64)),
                               axis=0,
                               name='utt-wrong')

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
        add_moving_summary(cost)

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

        with tf.name_scope('original-weight-summaries'):
            add_param_summary(('.*/W', ['rms', 'histogram']))
            add_param_summary(('.*/b', ['rms', 'histogram']))

        with tf.name_scope('activation-summaries'):

            def fn(name):
                return (name.endswith('output') or name.endswith('output:0'))

            tensors = get_tensors_from_graph(tf.get_default_graph(), fn)
            print("Adding activation tensors to summary:", tensors)
            for tensor in tensors:
                add_tensor_summary(tensor, ['rms', 'histogram'])

        if self.regularize:
            # decreasing regularization on all W of fc layers
            wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
                                              480000, 0.2, True)
            wd_cost = tf.multiply(wd_w,
                                  regularize_cost('.*/W', tf.nn.l2_loss),
                                  name='wd_cost')
            add_moving_summary(wd_cost)
            self.cost = tf.add_n([cost, wd_cost], name='cost')
        else:
            self.cost = tf.identity(cost, name='cost')

        tf.constant([self.network_complexity['mults']], name='TotalMults')
        tf.constant([self.network_complexity['weights']], name='TotalWeights')
        logger.info("Parameter count: {}".format(self.network_complexity))
Example #3
    def _build_graph(self, inputs):
        inp, label = inputs

        if self.twn:
            old_get_variable = tf.get_variable

            def new_get_variable(name, shape=None, **kwargs):
                v = old_get_variable(name, shape, **kwargs)
                if name == 'W':
                    logger.info("Ternarizing weight {}".format(v.op.name))
                    return tw_ternarize(v, 0.05)
                else:
                    logger.info("NOT ternarizing weight {}".format(v.op.name))
                    return v

            tf.get_variable = new_get_variable
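            # NOTE: tf.get_variable stays patched until it is restored after
            # the network is built below; any exception in between leaks the patch.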

        with argscope([Conv2DWithTrackedMults, BatchNorm], data_format='NHWC'), \
                argscope([Conv2DWithTrackedMults], nl=tf.identity, use_bias=False, kernel_shape=3,
                         W_init=variance_scaling_initializer(mode='FAN_OUT')), \
                argscope(Conv2DWithTrackedMults, network_complexity=self.network_complexity), \
                argscope(FullyConnectedWithTrackedMults, network_complexity=self.network_complexity), \
                argscope(DepthwiseSeparableConvWithTrackedMults, network_complexity=self.network_complexity):
            l = self.net_fn(inp)

        logits = FullyConnectedWithTrackedMults(
            'last_linear',
            l,
            out_dim=self.n_spks,
            nl=tf.identity,
            network_complexity=self.network_complexity)

        if self.twn:
            tf.get_variable = old_get_variable

        prob = tf.nn.softmax(logits, name='output')

        # used for validation accuracy of utterance
        identity_guesses = flatten(tf.argmax(prob, axis=1))
        uniq_identities, _, count = tf.unique_with_counts(identity_guesses)
        idx_to_identity_with_most_votes = tf.argmax(count)
        chosen_identity = tf.gather(uniq_identities,
                                    idx_to_identity_with_most_votes)
        wrong = tf.expand_dims(tf.not_equal(chosen_identity,
                                            tf.cast(label[0], tf.int64)),
                               axis=0,
                               name='utt-wrong')

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                              labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

        # weight decay on all W of fc layers
        wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
                                          480000, 0.2, True)
        wd_cost = tf.multiply(wd_w,
                              regularize_cost('.*/W', tf.nn.l2_loss),
                              name='wd_cost')
        add_moving_summary(cost, wd_cost)

        with tf.name_scope('original-weight-summaries'):
            add_param_summary(('.*/W', ['rms', 'histogram']))
            add_param_summary(('.*/b', ['rms', 'histogram']))

        if self.twn:

            def fn(name):
                return ('ternarized_W' in name and 'InferenceTower' in name)

            tensors = get_tensors_from_graph(tf.get_default_graph(), fn)
            self.ternary_weight_tensors = tensors
            print("yolo", self.ternary_weight_tensors)
            with tf.name_scope('scalar-factor-summaries'):
                add_param_summary(('.*/Wp', ['scalar']))
                add_param_summary(('.*/Wn', ['scalar']))

        with tf.name_scope('activation-summaries'):

            def fn(name):
                return (name.endswith('output') or
                        name.endswith('output:0')) and 'InferenceTower' in name

            tensors = get_tensors_from_graph(tf.get_default_graph(), fn)
            for tensor in tensors:
                add_tensor_summary(tensor, ['rms', 'histogram'])

        self.cost = tf.add_n([cost, wd_cost], name='cost')
        tf.constant([self.network_complexity['mults']], name='TotalMults')
        tf.constant([self.network_complexity['weights']], name='TotalWeights')
        logger.info("Parameter count: {}".format(self.network_complexity))