def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'):
    """
    Bind readers and loaded train/predict data to trainers. The `num_epochs`
    argument only works on the `sampling_reference` task (trainer), and the
    num_epochs of other tasks are inferred from their `mix_ratio`.

    Args:
        readers: a dict or list of Reader objects. For dict case, each key is
            a trainer's name, and the mapped value is the reader object to
            bind to that trainer. For list case, each reader is bound to the
            trainer at the same position in `self._trainers`.
        sampling_reference: a trainer name. The task (trainer) selected as
            baseline for task sampling.
        num_epochs: training epochs of the sampling_reference task (trainer).
        phase: running phase. Currently support: train, predict.
    """
    self._check_phase(phase)

    if isinstance(readers, list):
        reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
    elif isinstance(readers, dict):
        reader_dict = readers
    else:
        raise ValueError("readers should be a list or a dict of Reader objects.")

    num_heads = len(self._trainers)
    assert len(reader_dict) == num_heads, \
        "received number of readers is not consistent with trainers."

    trainer_dict = {t.name: t for t in self._trainers}
    assert sampling_reference in trainer_dict
    # Fit the reference trainer first: its steps-per-epoch is the baseline
    # from which every other task's expected train steps are derived.
    trainer_dict[sampling_reference]._set_task_id(self._task_id_var)
    trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference])
    base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch

    self._finish_steps = {}
    self._finish = {}
    input_names = []
    name_to_pos = []
    joint_shape_and_dtypes = []
    iterators = []
    prefixes = []
    mrs = []
    net_inputs = []
    global_steps = 0
    for t in self._trainers:
        assert t.name in reader_dict
        assert reader_dict[t.name].num_epochs is None, \
            "{}: num_epochs is not None. To run with multi-head mode, " \
            "num_epochs of each Trainer should be set as None.".format(t.name)
        # Expected train steps of this task, scaled by its mix ratio
        # relative to the sampling-reference task.
        max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
        if not t._as_auxilary:
            print('{}: expected train steps {}.'.format(t.name, max_train_steps))
            sys.stdout.flush()
            self._finish_steps[t.name] = max_train_steps
            self._finish[t.name] = False
        else:
            # Auxiliary tasks never terminate training on their own;
            # mark them finished with an effectively-infinite step budget.
            self._finish_steps[t.name] = 9999999999
            self._finish[t.name] = True

        global_steps += max_train_steps
        # The sampling-reference trainer was already fitted above.
        if t.name != sampling_reference:
            t._set_task_id(self._task_id_var)
            t.fit_reader(reader_dict[t.name])
        net_inputs.append(t._net_inputs)
        prefixes.append(t.name)
        mrs.append(t.mix_ratio)
        iterators.append(t._raw_iterator_fn())
        input_names.append(t._input_names)
        name_to_pos.append(t._name_to_position)
        joint_shape_and_dtypes.append(t._shape_and_dtypes)

    print('Estimated overall train steps {}.'.format(global_steps))
    sys.stdout.flush()
    self._overall_train_steps = global_steps

    iterator_fn = reader_helper.create_multihead_iterator_fn(
        iterators, prefixes, joint_shape_and_dtypes,
        mrs, input_names, name_to_pos, dev_count=dev_count)
    feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)

    if gpu_dev_count > 1:
        distribute_feeder_fn = data_feeder(
            iterator_fn, feed_batch_process_fn, phase=phase, is_multi=True)
    else:
        distribute_feeder_fn = iterator_fn()

    if phase == 'train':
        self._train_reader = distribute_feeder_fn
        self._feed_batch_process_fn = feed_batch_process_fn
    elif phase == 'predict':
        self._predict_reader = distribute_feeder_fn
        self._pred_feed_batch_process_fn = feed_batch_process_fn
def fit_reader(self, reader, phase='train'):
    """
    Bind a reader and loaded train/predict data to trainer.

    Args:
        reader: a Reader object. The running phase of the reader should be
            consistent with the `phase` argument of this method.
        phase: running phase. Currently support: train, predict.
    """
    self._check_phase(phase)
    if phase == 'train':
        assert self._shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
    else:
        assert self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."

    batch_size = reader._batch_size
    self._num_epochs = reader.num_epochs
    if phase == 'train':
        self._train_reader = reader
        # NOTE(review): floor division — trailing examples that do not fill
        # a whole batch are not counted toward steps per epoch.
        self._steps_pur_epoch = reader.num_examples // batch_size
        shape_and_dtypes = self._shape_and_dtypes
        name_to_position = self._name_to_position
        if self._task_id is not None:
            self._net_inputs['__task_id'] = self._task_id
        net_inputs = self._net_inputs
        self._train_batch_size = batch_size
        self._num_examples = reader.num_examples
        # Validate that reader outputs satisfy backbone and task-head inputs.
        reader_helper.check_io(
            self._backbone.inputs_attr, reader.outputs_attr,
            in_name='backbone', out_name='reader(train)')
        reader_helper.check_io(
            self._task_head.inputs_attrs['reader'], reader.outputs_attr,
            in_name='task_head(reader)', out_name='reader(train)')
        reader_helper.check_io(
            self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr,
            in_name='task_head(backbone, train)', out_name='backbone')
    elif phase == 'predict':
        self._predict_reader = reader
        self._pred_steps_pur_epoch = reader.num_examples // batch_size
        shape_and_dtypes = self._pred_shape_and_dtypes
        name_to_position = self._pred_name_to_position
        net_inputs = self._pred_net_inputs
        self._predict_batch_size = batch_size
        self._pred_num_examples = reader.num_examples
        reader_helper.check_io(
            self._pred_backbone.inputs_attr, reader.outputs_attr,
            in_name='backbone', out_name='reader(predict)')
        reader_helper.check_io(
            self._pred_head.inputs_attrs['reader'], reader.outputs_attr,
            in_name='task_head(reader)', out_name='reader(predict)')
        reader_helper.check_io(
            self._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr,
            in_name='task_head(backbone, predict)', out_name='backbone')
    else:
        raise NotImplementedError()

    # merge dataset iterators and create net input vars
    # (the original contained this pair twice back-to-back plus a debug
    # print; the redundant copy and the print were removed)
    iterator = reader._iterator()
    prefix = self.name

    # runtime check and adaptation of the data yielded by the iterator
    iterator_fn = reader_helper.create_iterator_fn(
        iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
    self._raw_iterator_fn = iterator_fn
    feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
    if gpu_dev_count > 1:
        distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
    else:
        distribute_feeder_fn = iterator_fn()

    if phase == 'train':
        self._train_iterator = distribute_feeder_fn
        self._feed_batch_process_fn = feed_batch_process_fn
    elif phase == 'predict':
        self._predict_iterator = distribute_feeder_fn
        self._pred_feed_batch_process_fn = feed_batch_process_fn
def merge_inference_readers(self, readers):
    """
    Bind predict-phase readers to all trainers and build a joint multi-head
    inference iterator.

    Args:
        readers: a dict or list of predict-phase Reader objects. For dict
            case, each key is a trainer's name and the mapped value is the
            reader to bind to it. For list case, readers are matched to
            `self._trainers` by position.

    Returns:
        The merged iterator function used to feed inference batches.
    """
    # Every reader must be built for the predict phase. Iterate values when
    # a dict is given (iterating the dict itself would yield name strings,
    # which have no `_phase` attribute).
    for r in (readers.values() if isinstance(readers, dict) else readers):
        assert r._phase == 'predict'

    if isinstance(readers, list):
        reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
    elif isinstance(readers, dict):
        reader_dict = readers
    else:
        raise ValueError("readers should be a list or a dict of Reader objects.")

    num_heads = len(self._trainers)
    assert len(reader_dict) == num_heads, \
        "received number of readers is not consistent with trainers."

    trainer_dict = {t.name: t for t in self._trainers}
    task_name2id = {t.name: idx for idx, t in enumerate(self._trainers)}
    self._task_name2id = task_name2id

    self._finish_steps = {}
    self._finish = {}
    input_names = []
    name_to_pos = []
    joint_shape_and_dtypes = []
    iterators = []
    prefixes = []
    net_inputs = []

    for t in self._trainers:
        assert t.name in reader_dict
        assert reader_dict[t.name].num_epochs is None, \
            "{}: num_epochs is not None. To run with multi-head mode, " \
            "num_epochs of each Trainer should be set as None.".format(t.name)
        # Inference never terminates by step count, so every task is marked
        # finished with an effectively-infinite step budget.
        self._finish_steps[t.name] = 9999999999
        self._finish[t.name] = True

        t.fit_reader(reader_dict[t.name], phase='predict')
        net_inputs.append(t._pred_net_inputs)
        prefixes.append(t.name)
        iterators.append(t._raw_iterator_fn())
        input_names.append(t._pred_input_names)
        name_to_pos.append(t._pred_name_to_position)
        joint_shape_and_dtypes.append(t._pred_shape_and_dtypes)

    iterator_fn = reader_helper.create_multihead_inference_fn(
        iterators, prefixes, joint_shape_and_dtypes,
        input_names, name_to_pos, task_name2id, dev_count=dev_count)
    feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)

    if gpu_dev_count > 1:
        raise NotImplementedError(
            'currently only single-gpu mode has been supported running with multi-task mode.'
        )
    else:
        distribute_feeder_fn = iterator_fn

    self._predict_iterator_fn = distribute_feeder_fn
    self._pred_feed_batch_process_fn = feed_batch_process_fn
    return distribute_feeder_fn