Example No. 1
def get_metric_function(metric, output_shape=None, loss_fn=None):
    """Returns the metric function corresponding to the given metric input.

  Arguments:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric
          will be calculated for.
      loss_fn: The loss function used.

  Returns:
      The metric function.
  """
    if metric in ['accuracy', 'acc']:
        if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
            return metrics_module.binary_accuracy  # case: binary accuracy
        elif loss_fn == losses.sparse_categorical_crossentropy:
            # case: categorical accuracy with sparse targets
            return metrics_module.sparse_categorical_accuracy
        return metrics_module.categorical_accuracy  # case: categorical accuracy
    elif metric in ['crossentropy', 'ce']:
        if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
            return metrics_module.binary_crossentropy  # case: binary cross-entropy
        elif loss_fn == losses.sparse_categorical_crossentropy:
            # case: categorical cross-entropy with sparse targets
            return metrics_module.sparse_categorical_crossentropy
        # case: categorical cross-entropy
        return metrics_module.categorical_crossentropy
    return metrics_module.get(metric)
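
Usage note (not from the source): a minimal sketch of how the alias resolution above behaves, assuming `losses` and `metrics_module` are the snippet's aliases for tf.keras.losses and tf.keras.metrics.

from tensorflow.keras import losses
from tensorflow.keras import metrics as metrics_module

# A single sigmoid unit in the output resolves 'accuracy' to binary accuracy.
fn = get_metric_function('accuracy', output_shape=(None, 1))
assert fn is metrics_module.binary_accuracy

# Integer targets trained with a sparse loss resolve to the sparse variant.
fn = get_metric_function('acc', output_shape=(None, 10),
                         loss_fn=losses.sparse_categorical_crossentropy)
assert fn is metrics_module.sparse_categorical_accuracy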
Example No. 2
def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Returns the metric function corresponding to the given metric input.

  Arguments:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric
          will be calculated for.
      loss_fn: The loss function used.

  Returns:
      The metric function.
  """
  if metric in ['accuracy', 'acc']:
    if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
      return metrics_module.binary_accuracy  # case: binary accuracy
    elif loss_fn == losses.sparse_categorical_crossentropy:
      # case: categorical accuracy with sparse targets
      return metrics_module.sparse_categorical_accuracy
    return metrics_module.categorical_accuracy  # case: categorical accuracy
  elif metric in ['crossentropy', 'ce']:
    if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy:
      return metrics_module.binary_crossentropy  # case: binary cross-entropy
    elif loss_fn == losses.sparse_categorical_crossentropy:
      # case: categorical cross-entropy with sparse targets
      return metrics_module.sparse_categorical_crossentropy
    # case: categorical cross-entropy
    return metrics_module.categorical_crossentropy
  return metrics_module.get(metric)
Example No. 3
def get_metric_name(metric, weighted=False):
    """Returns the name corresponding to the given metric input.

  Arguments:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted.

  Returns:
      The metric name.
  """
    metric_name_prefix = 'weighted_' if weighted else ''
    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
        if metric in ('accuracy', 'acc'):
            suffix = 'acc'
        elif metric in ('crossentropy', 'ce'):
            suffix = 'ce'
    else:
        metric_fn = metrics_module.get(metric)
        # Get metric name as string
        if hasattr(metric_fn, 'name'):
            suffix = metric_fn.name
        else:
            suffix = metric_fn.__name__
    metric_name = metric_name_prefix + suffix
    return metric_name
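
A few hypothetical calls (assuming `metrics_module` is tf.keras.metrics as in the snippet): the 'accuracy'/'acc' and 'crossentropy'/'ce' aliases collapse to short suffixes, while any other metric keeps its registered name.

print(get_metric_name('accuracy'))            # acc
print(get_metric_name('ce', weighted=True))   # weighted_ce
print(get_metric_name('mean_squared_error'))  # mean_squared_error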
Example No. 4
    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object.
    """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Convenience feature for selecting b/t binary, categorical,
        # and sparse categorical.
        if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
            metric_obj = metrics_mod.get(metric)
        else:
            y_t_rank = len(y_t.shape.as_list())
            y_p_rank = len(y_p.shape.as_list())
            y_t_last_dim = y_t.shape.as_list()[-1]
            y_p_last_dim = y_p.shape.as_list()[-1]

            is_binary = y_p_last_dim == 1
            is_sparse_categorical = (y_t_rank < y_p_rank
                                     or y_t_last_dim == 1 and y_p_last_dim > 1)

            if metric in ['accuracy', 'acc']:
                if is_binary:
                    metric_obj = metrics_mod.binary_accuracy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_accuracy
                else:
                    metric_obj = metrics_mod.categorical_accuracy
            else:
                if is_binary:
                    metric_obj = metrics_mod.binary_crossentropy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_crossentropy
                else:
                    metric_obj = metrics_mod.categorical_crossentropy

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

        if not isinstance(metric_obj, metrics_mod.Metric):
            if isinstance(metric, six.string_types):
                metric_name = metric
            else:
                metric_name = get_custom_object_name(metric)
                if metric_name is None:
                    raise ValueError(
                        'Metric should be a callable, found: {}'.format(
                            metric))

            metric_obj = metrics_mod.MeanMetricWrapper(metric_obj,
                                                       name=metric_name)

        return metric_obj
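
A standalone sketch of the rank/shape checks above (hypothetical; it mirrors the logic rather than calling the method itself), showing why integer labels scored against a 10-way softmax are treated as sparse categorical.

import tensorflow as tf

y_t = tf.zeros([8, 1])   # integer class ids, shape (batch, 1)
y_p = tf.zeros([8, 10])  # 10-way probabilities, shape (batch, 10)

y_t_rank = len(y_t.shape.as_list())
y_p_rank = len(y_p.shape.as_list())
is_binary = y_p.shape.as_list()[-1] == 1
is_sparse_categorical = (y_t_rank < y_p_rank or
                         y_t.shape.as_list()[-1] == 1 and y_p.shape.as_list()[-1] > 1)
print(is_binary, is_sparse_categorical)  # False True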
Example No. 5
def get_metric_name(metric, weighted=False):
  """Returns the name corresponding to the given metric input.

  Arguments:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted.

  Returns:
    The metric name.
  """
  metric_name_prefix = 'weighted_' if weighted else ''
  if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
    if metric in ('accuracy', 'acc'):
      suffix = 'acc'
    elif metric in ('crossentropy', 'ce'):
      suffix = 'ce'
  else:
    metric_fn = metrics_module.get(metric)
    # Get metric name as string
    if hasattr(metric_fn, 'name'):
      suffix = metric_fn.name
    else:
      suffix = metric_fn.__name__
  metric_name = metric_name_prefix + suffix
  return metric_name
Example No. 6
def get_base_metric_name(metric, weighted=False):
  """Returns the metric name given the metric function.

  Arguments:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the metric for which we are adding
          names is weighted.

  Returns:
      a metric name.
  """
  metric_name_prefix = 'weighted_' if weighted else ''
  if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
    if metric in ('accuracy', 'acc'):
      suffix = 'acc'
    elif metric in ('crossentropy', 'ce'):
      suffix = 'ce'
    metric_name = metric_name_prefix + suffix
  else:
    metric_fn = metrics_module.get(metric)
    # Get metric name as string
    if hasattr(metric_fn, 'name'):
      metric_name = metric_fn.name
    else:
      metric_name = metric_fn.__name__
    metric_name = metric_name_prefix + metric_name

  return metric_name
Example No. 7
def get_base_metric_name(metric, weighted=False):
    """Returns the metric name given the metric function.

  Arguments:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the metric for which we are adding
          names is weighted.

  Returns:
      a metric name.
  """
    metric_name_prefix = 'weighted_' if weighted else ''
    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
        if metric in ('accuracy', 'acc'):
            suffix = 'acc'
        elif metric in ('crossentropy', 'ce'):
            suffix = 'ce'
        metric_name = metric_name_prefix + suffix
    else:
        metric_fn = metrics_module.get(metric)
        # Get metric name as string
        if hasattr(metric_fn, 'name'):
            metric_name = metric_fn.name
        else:
            metric_name = metric_fn.__name__
        metric_name = metric_name_prefix + metric_name

    return metric_name
Example No. 8
def get_model(depth, optim, mets):
    """
    Build a U-Net model
    Parameters:
        depth (int): number of training features (i.e. bands)
        optim (tf.keras.optimizers.Optimizer): keras optimizer
        mets (list<tf.keras.metrics.Metric>): list of keras metrics
    Returns:
        tf.keras.Model: compiled U-Net model
    """
    # Encoder path halves spatial size at each block; comments track it.
    inputs = layers.Input(shape=[None, None, len(BANDS)])  # 256
    encoder0_pool, encoder0 = encoder_block(inputs, 32)  # 128
    encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)  # 64
    encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)  # 32
    encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)  # 16
    encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)  # 8
    center = conv_block(encoder4_pool, 1024)  # center
    # Decoder path upsamples and concatenates the matching encoder output.
    decoder4 = decoder_block(center, encoder4, 512)  # 16
    decoder3 = decoder_block(decoder4, encoder3, 256)  # 32
    decoder2 = decoder_block(decoder3, encoder2, 128)  # 64
    decoder1 = decoder_block(decoder2, encoder1, 64)  # 128
    decoder0 = decoder_block(decoder1, encoder0, 32)  # 256
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)

    model = models.Model(inputs=[inputs], outputs=[outputs])

    model.compile(
        optimizer=optim,
        loss=weighted_bce,
        # loss=losses.get(LOSS),
        metrics=[metrics.get(metric) for metric in mets])

    return model
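
A hypothetical invocation (not from the source): BANDS, the encoder/decoder block helpers, and weighted_bce must already be defined in the surrounding module.

import tensorflow as tf

model = get_model(depth=len(BANDS),
                  optim=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  mets=['accuracy'])
model.summary()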
Example No. 9
    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

    Arguments:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object.
    """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Convenience feature for selecting b/t binary, categorical,
        # and sparse categorical.
        if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
            metric_obj = metrics_mod.get(metric)
        else:
            y_t_rank = len(y_t.shape.as_list())
            y_p_rank = len(y_p.shape.as_list())
            y_t_last_dim = y_t.shape.as_list()[-1]
            y_p_last_dim = y_p.shape.as_list()[-1]

            is_binary = y_p_last_dim == 1
            is_sparse_categorical = (y_t_rank < y_p_rank
                                     or y_t_last_dim == 1 and y_p_last_dim > 1)

            if metric in ['accuracy', 'acc']:
                if is_binary:
                    metric_obj = metrics_mod.binary_accuracy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_accuracy
                else:
                    metric_obj = metrics_mod.categorical_accuracy
            else:
                if is_binary:
                    metric_obj = metrics_mod.binary_crossentropy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_crossentropy
                else:
                    metric_obj = metrics_mod.categorical_crossentropy

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

        if not isinstance(metric_obj, metrics_mod.Metric):
            if isinstance(metric, six.string_types):
                metric_name = metric
            elif hasattr(metric, 'name'):
                metric_name = metric.name  # TODO(omalleyt): Is this needed?
            else:
                # function was passed.
                metric_name = metric.__name__

            metric_obj = metrics_mod.MeanMetricWrapper(metric_obj,
                                                       name=metric_name)

        return metric_obj
Example No. 10
    def get_cost(*inputs):
        ctx = get_current_tower_context()
        input_tensors = list(inputs[:nr_inputs])
        target_tensors = list(inputs[nr_inputs:])
        # TODO mapping between target tensors & output tensors

        outputs = model_caller(input_tensors)

        if isinstance(outputs, tf.Tensor):
            outputs = [outputs]
        assert len(outputs) == len(target_tensors), \
            "len({}) != len({})".format(str(outputs), str(target_tensors))
        assert len(outputs) == len(loss), \
            "len({}) != len({})".format(str(outputs), str(loss))

        loss_tensors = []
        for idx, loss_name in enumerate(loss):
            with cached_name_scope('keras_loss', top_level=False):
                loss_fn = keras.losses.get(loss_name)
                curr_loss = loss_fn(target_tensors[idx], outputs[idx])
            curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
            _check_name(curr_loss, loss_name)
            loss_tensors.append(curr_loss)

        loss_reg = regularize_cost_from_collection()
        if loss_reg is not None:
            total_loss = tf.add_n(loss_tensors + [loss_reg],
                                  name=TOTAL_LOSS_NAME)
            add_moving_summary(loss_reg, total_loss, *loss_tensors)
        else:
            total_loss = tf.add_n(loss_tensors, name=TOTAL_LOSS_NAME)
            add_moving_summary(total_loss, *loss_tensors)

        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
            # for list: one metric for each output
            metric_tensors = []
            for oid, metric_name in enumerate(metrics):
                output_tensor = outputs[oid]
                target_tensor = target_tensors[
                    oid]  # TODO may not have the same mapping?
                with cached_name_scope('keras_metric', top_level=False):
                    metric_fn = metrics_module.get(metric_name)
                    metric_tensor = metric_fn(target_tensor, output_tensor)
                metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
                _check_name(metric_tensor, metric_name)
                # check name conflict here
                metric_tensors.append(metric_tensor)
            add_moving_summary(*metric_tensors)

        return total_loss
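
The lookups above rely on Keras's string-to-callable registries; a minimal sketch of that resolution in isolation (hypothetical example names):

from tensorflow import keras

loss_fn = keras.losses.get('categorical_crossentropy')
metric_fn = keras.metrics.get('categorical_accuracy')
# Both resolve to plain functions, callable as fn(y_true, y_pred).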
Example No. 11
    def get_cost(*inputs):
        ctx = get_current_tower_context()
        input_tensors = list(inputs[:nr_inputs])
        target_tensors = list(inputs[nr_inputs:])
        # TODO mapping between target tensors & output tensors

        outputs = model_caller(input_tensors)

        if isinstance(outputs, tf.Tensor):
            outputs = [outputs]
        assert len(outputs) == len(target_tensors), \
            "len({}) != len({})".format(str(outputs), str(target_tensors))
        assert len(outputs) == len(loss), \
            "len({}) != len({})".format(str(outputs), str(loss))

        loss_tensors = []
        for idx, loss_name in enumerate(loss):
            with cached_name_scope('keras_loss', top_level=False):
                loss_fn = keras.losses.get(loss_name)
                curr_loss = loss_fn(target_tensors[idx], outputs[idx])
            curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
            _check_name(curr_loss, loss_name)
            loss_tensors.append(curr_loss)

        loss_reg = regularize_cost_from_collection()
        if loss_reg is not None:
            total_loss = tf.add_n(loss_tensors + [loss_reg], name=TOTAL_LOSS_NAME)
            add_moving_summary(loss_reg, total_loss, *loss_tensors)
        else:
            add_moving_summary(*loss_tensors)
            total_loss = tf.add_n(loss_tensors, name=TOTAL_LOSS_NAME)

        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
            # for list: one metric for each output
            metric_tensors = []
            for oid, metric_name in enumerate(metrics):
                output_tensor = outputs[oid]
                target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
                with cached_name_scope('keras_metric', top_level=False):
                    metric_fn = metrics_module.get(metric_name)
                    metric_tensor = metric_fn(target_tensor, output_tensor)
                metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
                _check_name(metric_tensor, metric_name)
                # check name conflict here
                metric_tensors.append(metric_tensor)
            add_moving_summary(*metric_tensors)

        return total_loss
Example No. 12
def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
    if metric == 'accuracy' or metric == 'acc':
        # custom handling of accuracy
        # (because of class mode duality)
        output_shape = internal_output_shapes
        if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy:
            # case: binary accuracy
            acc_fn = metrics_module.binary_accuracy
        elif loss_func == losses.sparse_categorical_crossentropy:
            # case: categorical accuracy with sparse targets
            acc_fn = metrics_module.sparse_categorical_accuracy
        else:
            acc_fn = metrics_module.categorical_accuracy

        metric_name = 'acc'
        return metric_name, acc_fn
    else:
        metric_fn = metrics_module.get(metric)
        metric_name = metric_fn.__name__
        return metric_name, metric_fn
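
Hypothetical calls (assuming the snippet's `losses` / `metrics_module` aliases for tf.keras.losses / tf.keras.metrics):

name, fn = _get_metrics_info('accuracy', internal_output_shapes=(None, 1))
# name == 'acc'; fn is metrics_module.binary_accuracy

name, fn = _get_metrics_info('mae')
# resolved via metrics_module.get; name == 'mean_absolute_error'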
Example No. 13
def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
  if metric == 'accuracy' or metric == 'acc':
    # custom handling of accuracy
    # (because of class mode duality)
    output_shape = internal_output_shapes
    if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy:
      # case: binary accuracy
      acc_fn = metrics_module.binary_accuracy
    elif loss_func == losses.sparse_categorical_crossentropy:
      # case: categorical accuracy with sparse targets
      acc_fn = metrics_module.sparse_categorical_accuracy
    else:
      acc_fn = metrics_module.categorical_accuracy

    metric_name = 'acc'
    return metric_name, acc_fn
  else:
    metric_fn = metrics_module.get(metric)
    metric_name = metric_fn.__name__
    return metric_name, metric_fn