コード例 #1
0
    def build_inputs(self, params, input_context=None):
        """Builds the semantic-segmentation input pipeline.

        Creates a `segmentation_input.Decoder` and a `segmentation_input.Parser`
        configured from `params`, and hands both to an
        `input_reader.InputReader` whose dataset function is chosen from
        `params.file_type`.

        Args:
          params: Data config for this split. Fields read here:
            `output_size`, `train_on_crops`, `resize_eval_groundtruth`,
            `groundtruth_padded_size`, `aug_scale_min`, `aug_scale_max`,
            `aug_rand_hflip`, `dtype`, `file_type` and `is_training`. The
            whole object is also passed to `InputReader`.
          input_context: Optional context forwarded unchanged to
            `InputReader.read`; presumably a `tf.distribute.InputContext`
            used for per-replica sharding (TODO confirm).

        Returns:
          Whatever `InputReader.read` returns (presumably a `tf.data.Dataset`).
        """
        # Taken from the loss config (not from `params`) so that the parser
        # and the loss presumably agree on which label value is ignored.
        ignore_label = self.task_config.losses.ignore_label

        decoder = segmentation_input.Decoder()
        parser = segmentation_input.Parser(
            output_size=params.output_size,
            train_on_crops=params.train_on_crops,
            ignore_label=ignore_label,
            resize_eval_groundtruth=params.resize_eval_groundtruth,
            groundtruth_padded_size=params.groundtruth_padded_size,
            aug_scale_min=params.aug_scale_min,
            aug_scale_max=params.aug_scale_max,
            aug_rand_hflip=params.aug_rand_hflip,
            dtype=params.dtype)

        # `params.is_training` selects the train vs. eval parsing function.
        reader = input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))

        dataset = reader.read(input_context=input_context)

        return dataset
コード例 #2
0
ファイル: retinanet.py プロジェクト: yghlc/models
    def build_inputs(self, params, input_context=None):
        """Builds the RetinaNet detection input pipeline.

        Selects a TF-Example decoder from `params.decoder`, builds a
        `retinanet_input.Parser` from the task's model/anchor config and
        `params.parser`, and hands both to an `input_reader.InputReader`.

        Args:
          params: Data config for this split. Fields read here: `decoder`
            (its `type` must be 'simple_decoder' or 'label_map_decoder';
            `get()` returns the active sub-config), `parser` (matching and
            augmentation settings), `dtype`, `file_type` and `is_training`.
            The whole object is also passed to `InputReader`.
          input_context: Optional context forwarded unchanged to
            `InputReader.read`; presumably a `tf.distribute.InputContext`
            (TODO confirm).

        Returns:
          Whatever `InputReader.read` returns (presumably a `tf.data.Dataset`).

        Raises:
          ValueError: If `params.decoder.type` is not a supported decoder type.
        """
        # The decoder is selected and constructed exactly once.
        decoder_cfg = params.decoder.get()
        decoder_type = params.decoder.type
        if decoder_type == 'simple_decoder':
            decoder = tf_example_decoder.TfExampleDecoder(
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        elif decoder_type == 'label_map_decoder':
            # The label-map decoder class comes from its own module,
            # tf_example_label_map_decoder.
            decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
                label_map=decoder_cfg.label_map,
                regenerate_source_id=decoder_cfg.regenerate_source_id)
        else:
            raise ValueError('Unknown decoder type: {}!'.format(decoder_type))

        model_cfg = self.task_config.model
        parser_cfg = params.parser
        parser = retinanet_input.Parser(
            # Only the first two entries of `input_size` are used as the
            # parser's output size.
            output_size=model_cfg.input_size[:2],
            min_level=model_cfg.min_level,
            max_level=model_cfg.max_level,
            num_scales=model_cfg.anchor.num_scales,
            aspect_ratios=model_cfg.anchor.aspect_ratios,
            anchor_size=model_cfg.anchor.anchor_size,
            dtype=params.dtype,
            match_threshold=parser_cfg.match_threshold,
            unmatched_threshold=parser_cfg.unmatched_threshold,
            aug_rand_hflip=parser_cfg.aug_rand_hflip,
            aug_scale_min=parser_cfg.aug_scale_min,
            aug_scale_max=parser_cfg.aug_scale_max,
            skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
            max_num_instances=parser_cfg.max_num_instances)

        # `params.is_training` selects the train vs. eval parsing function.
        reader = input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=decoder.decode,
            parser_fn=parser.parse_fn(params.is_training))
        dataset = reader.read(input_context=input_context)

        return dataset
コード例 #3
0
    def build_inputs(self, params, input_context=None):
        """Creates the image-classification input pipeline.

        Args:
          params: Data config for this split; `aug_policy`, `dtype`,
            `file_type` and `is_training` are read from it, and the whole
            object is passed to `InputReader`.
          input_context: Optional context passed straight through to
            `InputReader.read` (presumably a `tf.distribute.InputContext`).

        Returns:
          The result of `InputReader.read` (presumably a `tf.data.Dataset`).
        """
        model_config = self.task_config.model
        class_count = model_config.num_classes
        image_size = model_config.input_size

        example_decoder = classification_input.Decoder()
        # Only the first two entries of `input_size` feed the parser.
        example_parser = classification_input.Parser(
            output_size=image_size[:2],
            num_classes=class_count,
            aug_policy=params.aug_policy,
            dtype=params.dtype)

        # Training vs. eval parsing is chosen by `params.is_training`.
        pipeline_reader = input_reader.InputReader(
            params,
            dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
            decoder_fn=example_decoder.decode,
            parser_fn=example_parser.parse_fn(params.is_training))
        return pipeline_reader.read(input_context=input_context)