Exemplo n.º 1
0
    def setUp(self):
        """Creates augmented copies of the fixture's input and target.

        Runs the parent fixture first, then applies `augment.augment` with an
        offset standard deviation of 0.1, a multiplier standard deviation of
        0.05 and a noise standard deviation of 0.1. The two results are stored
        as `self.input_augment_lt` and `self.target_augment_lt`.
        `self.input_lt` / `self.target_lt` are presumably set up by the parent
        fixture -- NOTE(review): confirm.
        """
        super(AugmentTest, self).setUp()

        augment_params = augment.AugmentParameters(
            offset_standard_deviation=0.1,
            multiplier_standard_deviation=0.05,
            noise_standard_deviation=0.1)
        augmented_pair = augment.augment(
            augment_params, self.input_lt, self.target_lt)
        self.input_augment_lt, self.target_augment_lt = augmented_pair
Exemplo n.º 2
0
    def setUp(self):
        """Configures shared data, augmentation, patching and model parameters.

        Data is read from the single RecordIO shard returned by
        `self.recordio_path()`. Reads are deterministic and use a one-thread
        reader batch. The z/channel selections come from the `input_*` and
        `target_*` attributes on `self`, which are presumably supplied by
        subclasses -- NOTE(review): confirm. Shuffling is disabled.
        """
        super(Base, self).setUp()

        read_params = data_provider.ReadTableParameters(
            shard_paths=[self.recordio_path()],
            is_recordio=True,
            bp=util.BatchParameters(2, num_threads=1, capacity=2),
            is_deterministic=True,
            pad_size=51,
            crop_size=110 + 24)
        z_and_channel_selection = dict(
            input_z_values=self.input_z_values,
            input_channel_values=self.input_channel_values,
            target_z_values=self.target_z_values,
            target_channel_values=self.target_channel_values)
        self.dp = data_provider.DataParameters(
            read_params, **z_and_channel_selection)

        self.ap = augment.AugmentParameters(
            offset_standard_deviation=0.1,
            multiplier_standard_deviation=0.05,
            noise_standard_deviation=0.1)

        # Patch geometry: extraction patch size, stride, stitch patch size.
        self.extract_patch_size, self.stride, self.stitch_patch_size = 80, 8, 8

        self.bp = util.BatchParameters(4, 1, 4)

        self.core_model = functools.partial(fovea.core, 50)
        self.add_head = functools.partial(
            model_util.add_head, is_residual_conv=True)

        self.num_classes = 16
        self.shuffle = False
Exemplo n.º 3
0
def parameters() -> controller.GetInputTargetAndPredictedParameters:
    """Creates the network parameters for the given inputs and flags.

    The result is driven by the command-line flags: `FLAGS.model`,
    `FLAGS.metric`, `FLAGS.mode` and `FLAGS.loss`, together with the
    `augment_*` and `preprocess_*` flags.

    Returns:
      A GetInputTargetAndPredictedParameters containing network parameters for the
      given mode, metric, and other flags.

    Raises:
      NotImplementedError: If `FLAGS.model` is not MODEL_CONCORDANCE.
      ValueError: If `FLAGS.loss` is not a supported loss. This is raised only
        when `logging.fatal` returns instead of terminating the process.
    """
    # Only the concordance model is implemented. Validate it once here rather
    # than re-checking in each model-dependent branch below.
    if FLAGS.model != MODEL_CONCORDANCE:
        raise NotImplementedError('Unsupported model: %s' % FLAGS.model)

    if FLAGS.metric == METRIC_LOSS:
        stride = FLAGS.loss_patch_stride
        shuffle = True
    else:
        stride = CONCORDANCE_STITCH_STRIDE
        # Shuffling breaks stitching.
        shuffle = False

    is_train = FLAGS.mode == MODE_TRAIN

    core_model = functools.partial(
        concordance.core, base_depth=FLAGS.base_depth, num_gpus=FLAGS.num_gpus)
    add_head = functools.partial(model_util.add_head, is_residual_conv=True)
    extract_patch_size = CONCORDANCE_EXTRACT_PATCH_SIZE
    stitch_patch_size = CONCORDANCE_STITCH_PATCH_SIZE

    dp = data_parameters()

    if shuffle:
        preprocess_num_threads = FLAGS.preprocess_shuffle_batch_num_threads
    else:
        # Thread racing is an additional source of shuffling, so we can only
        # use 1 thread per queue.
        preprocess_num_threads = 1

    # Augmentation is configured for training and for the jitter-stitch
    # metric; otherwise no augmentation parameters are passed on.
    if is_train or FLAGS.metric == METRIC_JITTER_STITCH:
        ap = augment.AugmentParameters(FLAGS.augment_offset_std,
                                       FLAGS.augment_multiplier_std,
                                       FLAGS.augment_noise_std)
    else:
        ap = None

    # No batch parameters are supplied for the full-inference metric.
    if FLAGS.metric == METRIC_INFER_FULL:
        bp = None
    else:
        bp = util.BatchParameters(FLAGS.preprocess_batch_size,
                                  preprocess_num_threads,
                                  FLAGS.preprocess_batch_capacity)

    if FLAGS.loss == LOSS_CROSS_ENTROPY:
        loss = util.softmax_cross_entropy
    elif FLAGS.loss == LOSS_RANKED_PROBABILITY_SCORE:
        loss = util.ranked_probability_score
    else:
        logging.fatal('Invalid loss: %s', FLAGS.loss)
        # Some logging backends only log on `fatal` (the stdlib alias of
        # `critical` returns normally). Raise so `loss` is never left unbound
        # and the failure does not surface later as an UnboundLocalError.
        raise ValueError('Invalid loss: %s' % FLAGS.loss)

    return controller.GetInputTargetAndPredictedParameters(
        dp, ap, extract_patch_size, stride, stitch_patch_size, bp, core_model,
        add_head, shuffle, NUM_CLASSES, loss, is_train)