import os

import paddle.fluid as fluid

# NOTE: helpers such as create_feed_batch_process_fn, data_feeder,
# decode_fake and gpu_dev_count are assumed to be imported from the
# package's utility modules.


def pred(self, task_instance, inference_model_dir=None):
    if self._for_train:
        raise Exception(
            'This controller is a trainer. Please build a new controller '
            'with for_train=False for predicting.')

    assert isinstance(task_instance, str)
    if isinstance(inference_model_dir, str):
        assert os.path.exists(inference_model_dir), \
            inference_model_dir + " not found."
    if inference_model_dir is None:
        assert 'save_path' in self.mtl_conf, \
            "one of `inference_model_dir` and `save_path` should be set " \
            "to load the inference model."
        inference_model_dir = os.path.join(
            self.mtl_conf['save_path'], task_instance, 'infer_model')

    # Look up the task instance by name.
    instance = None
    for inst in self.instances:
        if inst.name == task_instance:
            instance = inst
            break
    if instance is None:
        raise ValueError(task_instance + ' is not a valid task_instance.')

    pred_prog = self._init_pred(instance, inference_model_dir)

    inst = instance
    print(inst.name + ": loading data...")
    inst.reader['pred'].load_data()
    fetch_names, fetch_vars = inst.pred_fetch_list

    print('predicting...')
    feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
    distribute_feeder = data_feeder(
        inst.reader['pred'].iterator, feed_batch_process_fn,
        prefetch_steps=1, phase='pred')

    for feed, mask, _ in distribute_feeder:
        rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)

        # Drop the padding ("fake") examples appended so that the last
        # batch fills every device.
        num_fakes = decode_fake(len(rt_outputs[0]), mask, self.batch_size)
        while num_fakes:
            for item in rt_outputs:
                item.pop()
            num_fakes -= 1

        rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
        inst.postprocess(rt_outputs, phase='pred')

    if inst.task_layer['pred'].epoch_inputs_attrs:
        reader_outputs = inst.reader['pred'].get_epoch_outputs()
    else:
        reader_outputs = None
    inst.epoch_postprocess({'reader': reader_outputs}, phase='pred')
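# Hedged usage sketch for pred(): the controller class name, config file and
# task/instance names below are illustrative assumptions, not confirmed API.
#
#   controller = Controller('config.yaml', for_train=False)
#   controller.pred('mrqa', inference_model_dir='output/mrqa/infer_model')
#
# When inference_model_dir is omitted, the model is loaded from
# <save_path>/<task_instance>/infer_model as derived above.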
def train_one_step(self, batch):
    exe = self._exe
    distribute_train_prog = self._distribute_train_prog
    fetch_list = self._fetch_list

    if gpu_dev_count > 1:
        # Multi-GPU: the feeder yields (feed, mask); trim the trailing
        # padding ("fake") examples appended to fill the last batch.
        feed, mask = batch
        rt_outputs = exe.run(distribute_train_prog, feed=feed,
                             fetch_list=fetch_list)
        num_fakes = decode_fake(len(rt_outputs[0]), mask,
                                self._train_batch_size)
        if num_fakes:
            rt_outputs = [i[:-num_fakes] for i in rt_outputs]
    else:
        feed = self._feed_batch_process_fn(batch)
        rt_outputs = exe.run(distribute_train_prog, feed=feed,
                             fetch_list=fetch_list)

    rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}

    self._cur_train_step += 1
    self._check_save()
    self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
    return rt_outputs
def train_one_step(self, batch):
    # Variant with lazy initialization: compile the train program for
    # data-parallel execution on first use.
    if not self._dist_train_init:
        self._distribute_train_prog = fluid.CompiledProgram(
            self._train_prog).with_data_parallel(
                loss_name=self._loss_var.name)
        self._dist_train_init = True

    exe = self._exe
    distribute_train_prog = self._distribute_train_prog
    fetch_list = self._fetch_list

    if gpu_dev_count > 1:
        feed, mask = batch
        rt_outputs = exe.run(distribute_train_prog, feed=feed,
                             fetch_list=fetch_list)
        num_fakes = decode_fake(len(rt_outputs[0]), mask,
                                self._train_batch_size)
        if num_fakes:
            rt_outputs = [i[:-num_fakes] for i in rt_outputs]
    else:
        feed = self._feed_batch_process_fn(batch)
        rt_outputs = exe.run(distribute_train_prog, feed=feed,
                             fetch_list=fetch_list)

    rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}

    self._cur_train_step += 1
    self._check_save()
    self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
    return rt_outputs
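# The epoch bookkeeping above uses a 1-based step counter, so
# (step - 1) // steps_per_epoch keeps the last step of an epoch inside it
# (toy value of 100 steps per epoch assumed for illustration):
#
#   >>> steps_pur_epoch = 100
#   >>> [(s, (s - 1) // steps_pur_epoch) for s in (1, 100, 101, 200, 201)]
#   [(1, 0), (100, 0), (101, 1), (200, 1), (201, 2)]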
def predict_one_batch(self, batch):
    if gpu_dev_count > 1:
        # Multi-GPU: trim the padding ("fake") examples before decoding.
        feed, mask = batch
        rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed,
                                   fetch_list=self._pred_fetch_list,
                                   use_prune=True)
        num_fakes = decode_fake(len(rt_outputs[0]), mask,
                                self._predict_batch_size)
        if num_fakes:
            rt_outputs = [i[:-num_fakes] for i in rt_outputs]
    else:
        feed = self._pred_feed_batch_process_fn(batch)
        rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed,
                                   fetch_list=self._pred_fetch_list,
                                   use_prune=True)

    rt_outputs = {k: v for k, v in zip(self._pred_fetch_name_list,
                                       rt_outputs)}
    return rt_outputs
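# A minimal runnable illustration (toy data, hypothetical num_fakes) of the
# padding-trim step used throughout: each fetched array is sliced back to
# real examples before fetch names are attached.
if __name__ == '__main__':
    fetch_name_list = ['logits', 'loss']
    rt_outputs = [[0.1, 0.9, 0.5, 0.0], [1.2, 0.7, 0.3, 0.0]]  # 4 rows each
    num_fakes = 1  # pretend the last row of every output is padding
    if num_fakes:
        rt_outputs = [i[:-num_fakes] for i in rt_outputs]
    print({k: v for k, v in zip(fetch_name_list, rt_outputs)})
    # -> {'logits': [0.1, 0.9, 0.5], 'loss': [1.2, 0.7, 0.3]}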