Example #1
def __call__(self, **kwargs):
    # Pop the scope name so it is not forwarded to call().
    name = kwargs.pop('name')
    # Build the loss inside its own variable scope, then log it.
    with tf.variable_scope(name):
        loss = self.call(**kwargs)
    summary.scalar(name, loss)
    return loss
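
For context, here is a minimal, self-contained sketch of how this wrapper pattern might be exercised; the MeanLoss subclass, the tensor values, and the tf.compat.v1 graph-mode setup are illustrative assumptions, not part of the original example.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # the snippet above relies on TF1 graph-mode APIs

class MeanLoss:
    # Hypothetical subclass: call() builds the actual loss tensor.
    def call(self, logits=None, labels=None):
        return tf.reduce_mean(tf.squared_difference(logits, labels))

    def __call__(self, **kwargs):
        name = kwargs.pop('name')
        with tf.variable_scope(name):
            loss = self.call(**kwargs)
        tf.summary.scalar(name, loss)
        return loss

loss = MeanLoss()(name='mse_loss',
                  logits=tf.constant([0.2, 0.8]),
                  labels=tf.constant([0.0, 1.0]))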
Example #2
def l2_loss(self, tvars=None):
    _l2_loss = 0.0
    weight_decay = self.config['solver']['optimizer'].get(
        'weight_decay', None)
    if weight_decay:
        logging.info(f"add L2 Loss with decay: {weight_decay}")
        with tf.name_scope('l2_loss'):
            # Default to all trainable variables and skip bias terms.
            tvars = tvars if tvars else tf.trainable_variables()
            tvars = [v for v in tvars if 'bias' not in v.name]
            # Sum the per-variable L2 norms and scale by the decay factor.
            _l2_loss = weight_decay * tf.add_n(
                [tf.nn.l2_loss(v) for v in tvars])
            summary_lib.scalar('l2_loss', _l2_loss)
    return _l2_loss
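
The decay factor is read from the solver section of the model config. A minimal sketch of the layout implied by the lookup above follows; the sibling key and the 1e-4 value are illustrative assumptions.

config = {
    'solver': {
        'optimizer': {
            'name': 'adam',        # illustrative sibling key
            'weight_decay': 1e-4,  # L2 penalty coefficient; omit or set falsy to disable
        },
    },
}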
Example #3
def get_train_hooks(self, labels, logits, alpha=None):
    nclass = self.config['data']['task']['classes']['num']
    # Tensors to be logged periodically during training.
    metric_tensor = {
        "batch_accuracy": metrics_lib.accuracy(logits, labels),
        'global_step': tf.train.get_or_create_global_step(),
    }
    if nclass > 100:
        logging.info(
            'Too many classes, disable confusion matrix in train: %d' %
            nclass)
    else:
        metric_tensor['batch_confusion'] = \
            metrics_lib.confusion_matrix(logits, labels, nclass)
    summary_lib.scalar('batch_accuracy', metric_tensor['batch_accuracy'])
    if alpha is not None:
        # Also log the alignment tensor when one is provided.
        metric_tensor.update({"alignment": alpha})

    # Plot the PR curve for the batch, using the softmax probability
    # of the last class as the prediction score.
    true_label_bool = tf.cast(labels, tf.bool)
    softmax = tf.nn.softmax(logits)
    pr_summary.op(name='pr_curve_train_batch',
                  labels=true_label_bool,
                  predictions=softmax[:, -1],
                  num_thresholds=16,
                  weights=None)

    train_hooks = [
        # Report steps/sec every 100 steps.
        tf.train.StepCounterHook(every_n_steps=100,
                                 every_n_secs=None,
                                 output_dir=None,
                                 summary_writer=None),
        # Evaluate the global step once more when training ends.
        tf.train.FinalOpsHook(
            final_ops=[tf.train.get_or_create_global_step()],
            final_ops_feed_dict=None),
        # Log the metric tensors every 100 iterations.
        tf.train.LoggingTensorHook(tensors=metric_tensor,
                                   every_n_iter=100,
                                   every_n_secs=None,
                                   at_end=False,
                                   formatter=None),
    ]
    return train_hooks
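
For context, here is a minimal, self-contained sketch of how hooks shaped like these plug into a training loop; the toy graph, the MonitoredTrainingSession setup, and the StopAtStepHook are illustrative assumptions, not part of the original example.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy graph standing in for the model (illustrative only).
x = tf.random.normal([32, 10])
labels = tf.random.uniform([32], maxval=2, dtype=tf.int32)
logits = tf.layers.dense(x, units=2)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

# Hooks shaped like those returned by get_train_hooks().
hooks = [
    tf.train.StepCounterHook(every_n_steps=100),
    tf.train.LoggingTensorHook({'loss': loss, 'global_step': global_step},
                               every_n_iter=100),
    tf.train.StopAtStepHook(last_step=500),  # ends the loop below
]

with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)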