Exemplo n.º 1
0
def augment(ap: AugmentParameters,
            input_lt: lt.LabeledTensor,
            target_lt: lt.LabeledTensor,
            name: str = None) -> Tuple[lt.LabeledTensor, lt.LabeledTensor]:
  """Jointly flip/rotate an input/target pair, then corrupt the input signal.

  Both tensors are transposed to the canonical axis order. Their 'z',
  'channel' and 'mask' axes are folded into a single 'channel' axis, and the
  two are concatenated along it, so that one call to
  `_random_flip_and_rotation` transforms input and target together. The
  result is split back into input and target. Finally the mask=False (signal)
  slice of the input is passed through `corrupt`, using the offset,
  multiplier and noise standard deviations from `ap`. The mask=True slice of
  the input and the whole target only receive the geometric transform.

  Args:
    ap: AugmentParameters supplying the corruption standard deviations.
    input_lt: An input tensor with canonical axes.
    target_lt: A target tensor with canonical axes.
    name: Optional op name.

  Returns:
    A pair (input, target) of augmented tensors, named `<scope>input` and
    `<scope>target` respectively.
  """

  def _channel_coder():
    # A fresh coder per tensor: each coder is used to decode exactly the
    # tensor it encoded.
    return lt.ReshapeCoder(['z', 'channel', 'mask'], ['channel'])

  with tf.name_scope(name, 'augment', [input_lt, target_lt]) as scope:
    canonical_input = lt.transpose(input_lt, util.CANONICAL_AXIS_ORDER)
    canonical_target = lt.transpose(target_lt, util.CANONICAL_AXIS_ORDER)

    input_coder = _channel_coder()
    flat_input = input_coder.encode(canonical_input)
    target_coder = _channel_coder()
    flat_target = target_coder.encode(canonical_target)

    # One combined tensor => one shared random flip/rotation for both parts.
    transformed = _random_flip_and_rotation(
        lt.concat([flat_input, flat_target], 'channel'))

    # Input channels come first in the concatenation above. The positional
    # slice assumes the folded 'channel' axis is the 4th axis after encoding
    # (TODO confirm against util.CANONICAL_AXIS_ORDER).
    split_at = len(flat_input.axes['channel'])
    moved_input = input_coder.decode(transformed[:, :, :, :split_at])
    moved_target = target_coder.decode(transformed[:, :, :, split_at:])

    # Corrupt only the signal (mask=False) half of the input.
    signal = lt.select(moved_input, {'mask': util.slice_1(False)})
    signal_coder = _channel_coder()
    noisy_signal = signal_coder.decode(
        corrupt(ap.offset_standard_deviation,
                ap.multiplier_standard_deviation,
                ap.noise_standard_deviation,
                signal_coder.encode(signal)))
    untouched_mask = lt.select(moved_input, {'mask': util.slice_1(True)})
    result_input = lt.concat([noisy_signal, untouched_mask],
                             'mask',
                             name=scope + 'input')

    result_target = lt.identity(moved_target, name=scope + 'target')

    return result_input, result_target
Exemplo n.º 2
0
def itemize_losses(
    loss: Callable,
    target_lt: lt.LabeledTensor,
    predict_lt: lt.LabeledTensor,
    name: str = None,
) -> Dict[str, lt.LabeledTensor]:
    """Create itemized losses for each prediction task.

    A prediction task is one (z, channel) pair taken from the labels of the
    'z' and 'channel' axes of `target_lt`. For each task, the matching slices
    of the target and the prediction are passed to `add_loss`. A scalar
    summary of the loss and histogram summaries of both slices are recorded
    under 'loss/<z>/<channel>'.

    Args:
      loss: Loss function to use.
        Arguments should be (target, mask, prediction, name).
      target_lt: Tensor with ground truth values, in canonical format.
      predict_lt: Tensor with predicted logits, in canonical prediction format.
      name: Optional op name.

    Returns:
      A dictionary mapping loss names of the form '<z>/<channel>' to loss
      tensors.
    """
    with tf.name_scope(name, 'itemize_losses',
                       [target_lt, predict_lt]) as scope:
        loss_lts = {}
        axes = target_lt.axes
        for z in axes['z'].labels:
            for channel in axes['channel'].labels:
                # The same (z, channel) selection applies to target and
                # prediction.
                selection = {
                    'z': util.slice_1(z),
                    'channel': util.slice_1(channel),
                }
                target_selection_lt = lt.select(target_lt, selection)
                predict_selection_lt = lt.select(predict_lt, selection)
                tag = '%s/%s' % (z, channel)
                loss_lt = add_loss(loss,
                                   target_selection_lt,
                                   predict_selection_lt,
                                   name=scope + tag)
                # TF summary names are always '/'-separated. os.path.join
                # would insert '\\' on Windows, so build the names explicitly.
                summary_name = 'loss/' + tag
                tf.summary.scalar(name=summary_name, tensor=loss_lt.tensor)
                tf.summary.histogram(name=summary_name + '/target',
                                     values=target_selection_lt.tensor)
                tf.summary.histogram(name=summary_name + '/predict',
                                     values=predict_selection_lt.tensor)
                loss_lts[tag] = loss_lt

        return loss_lts
Exemplo n.º 3
0
    def setUp(self):
        super(CorruptTest, self).setUp()

        # Signal half (mask=False) of the fixture input provided by the base
        # test class.
        self.signal_lt = lt.select(self.input_lt,
                                   {'mask': util.slice_1(False)})
        # Fold 'z', 'channel' and 'mask' into a single 'channel' axis so the
        # signal can be fed to corrupt(), then unfold the corrupted result.
        coder = lt.ReshapeCoder(['z', 'channel', 'mask'], ['channel'])
        encoded_signal = coder.encode(self.signal_lt)
        # Presumably (offset, multiplier, noise) standard deviations, matching
        # the argument order augment() uses. Verify against corrupt().
        self.corrupt_coded_lt = augment.corrupt(0.1, 0.05, 0.1, encoded_signal)
        self.corrupt_lt = coder.decode(self.corrupt_coded_lt)
Exemplo n.º 4
0
def add_neurite_confocal(
    target_lt: lt.LabeledTensor,
    name: str = None,
) -> lt.LabeledTensor:
    """Append a synthetic 'NEURITE_CONFOCAL' channel to a target tensor.

    The new channel is derived from the 'NFH_CONFOCAL' and 'MAP2_CONFOCAL'
    channels:

    * signal (mask=False): arithmetic mean of the two signals, which acts
      like a logical OR;
    * mask (mask=True): geometric mean of the two masks, which acts like a
      logical AND.

    Args:
      target_lt: Input target tensor.
      name: Optional op name.

    Returns:
      The target tensor, transposed to the canonical axis order, with the
      'NEURITE_CONFOCAL' channel concatenated after the existing channels.
    """

    def relabeled_slice(source_lt, channel_name, mask_value):
        # Take one channel at one mask value and relabel it as the neurite
        # channel, so slices from different source channels line up.
        picked_lt = lt.select(source_lt, {
            'channel': util.slice_1(channel_name),
            'mask': util.slice_1(mask_value)
        })
        return lt.reshape(picked_lt, ['channel'],
                          [('channel', ['NEURITE_CONFOCAL'])])

    with tf.name_scope(name, 'add_neurite_confocal', [target_lt]) as scope:
        canonical_lt = lt.transpose(target_lt, util.CANONICAL_AXIS_ORDER)

        # Signal: average of the two source signals.
        nfh_signal_lt = relabeled_slice(canonical_lt, 'NFH_CONFOCAL', False)
        map2_signal_lt = relabeled_slice(canonical_lt, 'MAP2_CONFOCAL', False)
        signal_lt = (nfh_signal_lt + map2_signal_lt) / 2.0

        # Mask: geometric mean of the two source masks.
        nfh_mask_lt = relabeled_slice(canonical_lt, 'NFH_CONFOCAL', True)
        map2_mask_lt = relabeled_slice(canonical_lt, 'MAP2_CONFOCAL', True)
        mask_lt = lt.pow(nfh_mask_lt * map2_mask_lt, 0.5)

        neurite_lt = lt.concat([signal_lt, mask_lt], 'mask')
        return lt.concat([canonical_lt, neurite_lt], 'channel', name=scope)