def build_inputs(self, params, input_context=None):
    """Assemble the input pipeline described by ``params``.

    A decoder is chosen based on ``params.tfds_name``, a parser is
    configured from ``params.parser`` (including its augmentation
    options), and both are handed to an ``InputReader`` that reads
    through ``tf.data.TFRecordDataset``.

    Args:
        params: Data configuration. This method reads ``tfds_name``,
            ``decoder.decode_label``, the ``parser`` sub-config, ``dtype``
            and ``is_training`` from it, and also passes the whole object
            to ``InputReader``.
        input_context: Forwarded unchanged to ``reader.read``; defaults to
            ``None``.

    Returns:
        Whatever ``InputReader.read`` returns (presumably a
        ``tf.data.Dataset`` -- confirm against ``input_reader``).
    """
    model_input_size = self.task_config.model.input_size

    # A truthy TFDS name selects the TFDS-specific decoder.
    decoder_cls = (
        simclr_input.TFDSDecoder if params.tfds_name else simclr_input.Decoder
    )
    decoder = decoder_cls(params.decoder.decode_label)

    # Parser options copied verbatim from the parser sub-config, under the
    # same keyword names.
    forwarded_options = (
        "aug_rand_crop",
        "aug_rand_hflip",
        "aug_color_distort",
        "aug_color_jitter_strength",
        "aug_color_jitter_impl",
        "aug_rand_blur",
        "parse_label",
        "test_crop",
        "mode",
    )
    parser_cfg = params.parser
    # Only the first two entries of the model input size are used as the
    # parser's output size.
    parser_kwargs = {"output_size": model_input_size[:2]}
    for option in forwarded_options:
        parser_kwargs[option] = getattr(parser_cfg, option)
    parser_kwargs["dtype"] = params.dtype
    parser = simclr_input.Parser(**parser_kwargs)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
    )
    return reader.read(input_context=input_context)
def build_inputs(self, params, input_context=None):
    """Create the dataset for this task from the data config ``params``.

    Steps, in order: choose a decoder depending on whether
    ``params.tfds_name`` is set, configure a parser without any
    augmentation options, then read through an ``InputReader`` backed by
    ``tf.data.TFRecordDataset``.

    Args:
        params: Data configuration. This method reads ``tfds_name``,
            ``decoder.decode_label``, ``parser.parse_label``,
            ``parser.test_crop``, ``parser.mode``, ``dtype`` and
            ``is_training`` from it, and also passes the whole object to
            ``InputReader``.
        input_context: Forwarded unchanged to ``reader.read``; defaults to
            ``None``.

    Returns:
        Whatever ``InputReader.read`` returns (presumably a
        ``tf.data.Dataset`` -- confirm against ``input_reader``).
    """
    full_input_size = self.task_config.model.input_size

    # A truthy TFDS name switches to the TFDS-specific decoder.
    use_tfds = bool(params.tfds_name)
    decoder_factory = (
        simclr_input.TFDSDecoder if use_tfds else simclr_input.Decoder
    )
    label_decoder = decoder_factory(params.decoder.decode_label)

    parser_options = params.parser
    example_parser = simclr_input.Parser(
        # Only the first two entries of the model input size are used.
        output_size=full_input_size[:2],
        parse_label=parser_options.parse_label,
        test_crop=parser_options.test_crop,
        mode=parser_options.mode,
        dtype=params.dtype,
    )

    # parse_fn is called with the training flag and its result is handed to
    # the reader as the parser function.
    parse_example = example_parser.parse_fn(params.is_training)

    return input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=label_decoder.decode,
        parser_fn=parse_example,
    ).read(input_context=input_context)