def pred(self, task_instance, inference_model_dir=None):
    """Run prediction for one task instance with a saved inference model.

    Args:
        task_instance: name of the task instance to predict with; must match
            the ``name`` of one of ``self.instances``.
        inference_model_dir: directory holding the exported inference model.
            When None, falls back to
            ``<mtl_conf['save_path']>/<task_instance>/infer_model``.

    Raises:
        Exception: if this controller was built with ``for_train=True``.
        ValueError: if ``task_instance`` does not name a known instance.
    """
    if self._for_train:
        raise Exception(
            'This controller is a trainer. Please build a new controller with for_train=False for predicting.'
        )
    assert isinstance(task_instance, str)
    if isinstance(inference_model_dir, str):
        assert os.path.exists(
            inference_model_dir), inference_model_dir + " not found."
    if inference_model_dir is None:
        assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
        inference_model_dir = os.path.join(self.mtl_conf['save_path'],
                                           task_instance, 'infer_model')

    # Locate the instance object whose name matches the requested task.
    instance = None
    for inst in self.instances:
        if inst.name == task_instance:
            instance = inst
            break
    if instance is None:
        raise ValueError(task_instance + ' is not a valid task_instance.')

    pred_prog = self._init_pred(instance, inference_model_dir)

    inst = instance
    print(inst.name + ": loading data...")
    inst.reader['pred'].load_data()
    fetch_names, fetch_vars = inst.pred_fetch_list

    print('predicting...')
    feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
    distribute_feeder = data_feeder(
        inst.reader['pred'].iterator,
        feed_batch_process_fn,
        prefetch_steps=1,
        phase='pred')

    for feed, mask, branch_id in distribute_feeder:
        rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
        # The feeder pads the last batch with fake examples; drop the
        # corresponding fake results from the tail of every fetched output.
        nums_fake = decode_fake(len(rt_outputs[0]), mask, self.batch_size)
        for _ in range(nums_fake):
            for item in rt_outputs:
                item.pop()
        rt_outputs = dict(zip(fetch_names, rt_outputs))
        inst.postprocess(rt_outputs, phase='pred')

    # Some task heads consume whole-epoch reader outputs at epoch end.
    if inst.task_layer['pred'].epoch_inputs_attrs:
        reader_outputs = inst.reader['pred'].get_epoch_outputs()
    else:
        reader_outputs = None
    inst.epoch_postprocess({'reader': reader_outputs}, phase='pred')
def fit_reader(self, reader, phase='train'):
    """Bind a reader and loaded train/predict data to trainer.

    Args:
        reader: a Reader object. The running phase of the reader should be
          consistent with `phase` argument of this method.
        phase: running phase. Currently support: train, predict.
    """
    self._check_phase(phase)
    if phase == 'train':
        assert self._shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
    else:
        assert self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."

    batch_size = reader._batch_size
    self._num_epochs = reader.num_epochs

    if phase == 'train':
        self._train_reader = reader
        self._steps_pur_epoch = reader.num_examples // batch_size
        shape_and_dtypes = self._shape_and_dtypes
        name_to_position = self._name_to_position
        if self._task_id is not None:
            self._net_inputs['__task_id'] = self._task_id
        net_inputs = self._net_inputs
        self._train_batch_size = batch_size
        self._num_examples = reader.num_examples
        # Verify that reader outputs satisfy both the backbone and the head.
        reader_helper.check_io(
            self._backbone.inputs_attr,
            reader.outputs_attr,
            in_name='backbone',
            out_name='reader(train)')
        reader_helper.check_io(
            self._task_head.inputs_attrs['reader'],
            reader.outputs_attr,
            in_name='task_head(reader)',
            out_name='reader(train)')
        reader_helper.check_io(
            self._task_head.inputs_attrs['backbone'],
            self._backbone.outputs_attr,
            in_name='task_head(backbone, train)',
            out_name='backbone')
    elif phase == 'predict':
        self._predict_reader = reader
        self._pred_steps_pur_epoch = reader.num_examples // batch_size
        shape_and_dtypes = self._pred_shape_and_dtypes
        name_to_position = self._pred_name_to_position
        net_inputs = self._pred_net_inputs
        self._predict_batch_size = batch_size
        self._pred_num_examples = reader.num_examples
        reader_helper.check_io(
            self._pred_backbone.inputs_attr,
            reader.outputs_attr,
            in_name='backbone',
            out_name='reader(predict)')
        reader_helper.check_io(
            self._pred_head.inputs_attrs['reader'],
            reader.outputs_attr,
            in_name='task_head(reader)',
            out_name='reader(predict)')
        reader_helper.check_io(
            self._pred_head.inputs_attrs['backbone'],
            self._pred_backbone.outputs_attr,
            in_name='task_head(backbone, predict)',
            out_name='backbone')
    else:
        raise NotImplementedError()

    print('ok!')

    # merge dataset iterators and create net input vars
    # NOTE: create the iterator exactly once — the original code duplicated
    # these two statements, binding a second, unused generator.
    iterator = reader._iterator()
    prefix = self.name

    # perform runtime checks and adaptation on the batches yielded by the reader
    iterator_fn = reader_helper.create_iterator_fn(
        iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
    self._raw_iterator_fn = iterator_fn

    feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
    if gpu_dev_count > 1:
        distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
    else:
        distribute_feeder_fn = iterator_fn()

    if phase == 'train':
        self._train_iterator = distribute_feeder_fn
        self._feed_batch_process_fn = feed_batch_process_fn
    elif phase == 'predict':
        self._predict_iterator = distribute_feeder_fn
        self._pred_feed_batch_process_fn = feed_batch_process_fn
def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'):
    """Bind readers and loaded train/predict data to trainers.

    The `num_epochs` argument only works on `sampling_reference` task(trainer),
    and num_epochs of other tasks are infered from their `mix_ratio`.

    Args:
        readers: a dict or list of Reader objects. For dict case, each key is
          a trainer's name, and the mapped value is the reader object to bind
          to the trainer. For list case, each reader is bound to the trainer
          at the same position in ``self._trainers``.
        sampling_reference: a trainer name. The task(trainer) selected as
          baseline for task sampling.
        num_epochs: training epochs of the sampling_reference task (trainer).
        phase: running phase. Currently support: train, predict.
    """
    self._check_phase(phase)

    if isinstance(readers, list):
        reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
    elif isinstance(readers, dict):
        reader_dict = readers
    else:
        raise ValueError('readers should be a list or a dict of Reader objects.')

    num_heads = len(self._trainers)
    assert len(
        reader_dict
    ) == num_heads, "received number of readers is not consistent with trainers."

    trainer_dict = {t.name: t for t in self._trainers}
    assert sampling_reference in trainer_dict

    # The reference trainer defines how long one "epoch" is; every other
    # trainer's step budget is scaled from it by its mix_ratio.
    trainer_dict[sampling_reference]._set_task_id(self._task_id_var)
    trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference])
    base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch

    # Auxiliary tasks never terminate training on their own; give them an
    # effectively infinite step budget and mark them finished up front.
    AUXILIARY_FINISH_STEPS = 9999999999

    self._finish_steps = {}
    self._finish = {}
    input_names = []
    name_to_pos = []
    joint_shape_and_dtypes = []
    iterators = []
    prefixes = []
    mrs = []
    net_inputs = []
    global_steps = 0
    for t in self._trainers:
        assert t.name in reader_dict
        assert reader_dict[t.name].num_epochs is None, \
            "{}: num_epochs is not None. To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name)
        max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
        if not t._as_auxilary:
            print('{}: expected train steps {}.'.format(t.name, max_train_steps))
            sys.stdout.flush()
            self._finish_steps[t.name] = max_train_steps
            self._finish[t.name] = False
        else:
            self._finish_steps[t.name] = AUXILIARY_FINISH_STEPS
            self._finish[t.name] = True

        global_steps += max_train_steps
        # The sampling_reference trainer was already bound above.
        if t.name != sampling_reference:
            t._set_task_id(self._task_id_var)
            t.fit_reader(reader_dict[t.name])
        net_inputs.append(t._net_inputs)
        prefixes.append(t.name)
        mrs.append(t.mix_ratio)
        iterators.append(t._raw_iterator_fn())
        input_names.append(t._input_names)
        name_to_pos.append(t._name_to_position)
        joint_shape_and_dtypes.append(t._shape_and_dtypes)

    print('Estimated overall train steps {}.'.format(global_steps))
    sys.stdout.flush()
    self._overall_train_steps = global_steps

    iterator_fn = reader_helper.create_multihead_iterator_fn(
        iterators, prefixes, joint_shape_and_dtypes,
        mrs, input_names, name_to_pos, dev_count=dev_count)
    feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)

    if gpu_dev_count > 1:
        distribute_feeder_fn = data_feeder(
            iterator_fn, feed_batch_process_fn, phase=phase, is_multi=True)
    else:
        distribute_feeder_fn = iterator_fn()

    if phase == 'train':
        self._train_reader = distribute_feeder_fn
        self._feed_batch_process_fn = feed_batch_process_fn
    elif phase == 'predict':
        self._predict_reader = distribute_feeder_fn
        self._pred_feed_batch_process_fn = feed_batch_process_fn
def train(self):
    """Run multi-task training until every target task reaches its expected steps.

    Pulls joint batches from the multi-task feeder, routes each batch to its
    task branch via the `branch` feed, and periodically logs loss, exports
    inference models, and persists checkpoints according to the main task's
    config. A final checkpoint is saved after all target tasks finish.
    """
    if not self.has_init_train:
        self._init_train()
        self.has_init_train = True

    instances = self.instances
    main_conf = self.main_inst.config
    train_program = self.train_program
    saver_program = self.saver_program

    # Target tasks with a zero (or negative) step budget are finished before
    # training even starts; export their inference model right away.
    for inst in instances:
        if inst.is_target and inst.expected_train_steps <= 0:
            print(inst.name + ': train finished!')
            inst.save()

    def train_finish():
        # Training is done once every target task reports finished.
        for inst in instances:
            if inst.is_target and not inst.train_finish:
                return False
        return True

    fetch_names = {}
    fetch_list = []
    # Fetch the switched loss once per run. NOTE: the original code appended
    # inside the while loop, growing fetch_list by one entry per step.
    fetch_list.append(self._switched_loss)

    global_step = 0  # counts steps over all tasks
    time_begin = time.time()

    feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
    distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)

    while not train_finish():
        feed, mask, task_id = next(distribute_feeder)
        # Tell every device which task branch this batch belongs to.
        for i in range(self.dev_count):
            feed[i].update({'branch': np.array([task_id], dtype='int64')})
        rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
        rt_loss = rt_outputs.pop()
        rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}

        cur_task = instances[task_id]
        global_step += 1
        cur_task.cur_train_step += 1

        cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
        if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
            cur_task.save(
                suffix='.step' + str(cur_task_global_step),
                prog=self._pred_prog)

        if global_step % main_conf.get('print_every_n_steps', 5) == 0:
            loss = np.mean(np.squeeze(rt_loss)).tolist()
            time_end = time.time()
            time_cost = time_end - time_begin
            print(
                "Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s"
                .format(global_step, cur_task.name, cur_task.cur_train_step,
                        cur_task.steps_pur_epoch, cur_task.cur_train_epoch,
                        loss,
                        main_conf.get('print_every_n_steps', 5) / time_cost))
            time_begin = time.time()

        if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
            print(cur_task.name + ': train finished!')
            cur_task.save(prog=self._pred_prog)

        if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf[
                'save_ckpt_every_n_steps'] == 0:
            save_path = os.path.join(main_conf['save_path'], 'ckpt',
                                     "step_" + str(global_step))
            fluid.io.save_persistables(self.exe, save_path, saver_program)
            print('checkpoint has been saved at ' + save_path)

    # Final checkpoint once all target tasks have finished.
    save_path = os.path.join(main_conf['save_path'], 'ckpt',
                             "step_" + str(global_step))
    fluid.io.save_persistables(self.exe, save_path, saver_program)
    print('checkpoint has been saved at ' + save_path)

    print("ALL tasks train finished, exiting...")