Ejemplo n.º 1
0
    def pred(self, task_instance, inference_model_dir=None):
        """Run prediction with one task instance and post-process its outputs.

        Args:
            task_instance: name (str) of the task instance to predict with.
            inference_model_dir: optional path of a saved inference model.
                When None, the path is derived from
                ``self.mtl_conf['save_path']/<task_instance>/infer_model``.

        Raises:
            Exception: if this controller was built with for_train=True.
            ValueError: if `task_instance` matches no known instance name.
        """
        if self._for_train:
            raise Exception(
                'This controller is a trainer. Please build a new controller with for_train=False for predicting.'
            )

        assert isinstance(task_instance, str)
        if isinstance(inference_model_dir, str):
            assert os.path.exists(
                inference_model_dir), inference_model_dir + " not found."
        if inference_model_dir is None:
            # fall back to the conventional save location for this task
            assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model."
            inference_model_dir = os.path.join(self.mtl_conf['save_path'],
                                               task_instance, 'infer_model')

        # locate the task instance by name
        inst = next(
            (i for i in self.instances if i.name == task_instance), None)
        if inst is None:
            raise ValueError(task_instance + ' is not a valid task_instance.')

        pred_prog = self._init_pred(inst, inference_model_dir)

        print(inst.name + ": loading data...")
        inst.reader['pred'].load_data()
        fetch_names, fetch_vars = inst.pred_fetch_list

        print('predicting...')
        feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
        distribute_feeder = data_feeder(inst.reader['pred'].iterator,
                                        feed_batch_process_fn,
                                        prefetch_steps=1,
                                        phase='pred')

        for feed, mask, _ in distribute_feeder:

            rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)

            # the feeder pads the last batch with fake examples to fill all
            # devices; drop the padded tail from every fetched output
            nums_fake = decode_fake(len(rt_outputs[0]), mask, self.batch_size)
            for item in rt_outputs:
                for _unused in range(nums_fake):
                    item.pop()

            rt_outputs = dict(zip(fetch_names, rt_outputs))
            inst.postprocess(rt_outputs, phase='pred')

        # some prediction heads also consume whole-epoch reader outputs
        if inst.task_layer['pred'].epoch_inputs_attrs:
            reader_outputs = inst.reader['pred'].get_epoch_outputs()
        else:
            reader_outputs = None

        inst.epoch_postprocess({'reader': reader_outputs}, phase='pred')
Ejemplo n.º 2
0
    def fit_reader(self, reader, phase='train'):
        """
        Bind a reader and loaded train/predict data to trainer.

        Args:
            reader: a Reader object. The running phase of the reader should be consistent with `phase` argument of this method.
            phase: running phase. Currently support: train, predict.

        """

        self._check_phase(phase)
        if phase == 'train':
            assert self._shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
        else:
            assert self._pred_shape_and_dtypes is not None, "You need to build_forward     or build_predict_head first to prepare input features."

        batch_size = reader._batch_size

        self._num_epochs = reader.num_epochs
        if phase == 'train':
            self._train_reader = reader
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            # in multi-head mode a task-id var is injected into the net inputs
            if self._task_id is not None:
                self._net_inputs['__task_id'] = self._task_id
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
            # verify reader outputs satisfy backbone and head input contracts
            reader_helper.check_io(self._backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
        elif phase == 'predict':
            self._predict_reader = reader
            self._pred_steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
            reader_helper.check_io(self._pred_backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr, in_name='task_head(backbone, predict)', out_name='backbone')
        else:
            # unreachable after _check_phase, kept as a defensive guard
            raise NotImplementedError()

        print('ok!')

        # merge dataset iterators and create net input vars
        # NOTE: the original built this iterator twice back to back; a single
        # _iterator() call is enough and avoids the duplicate reader setup.
        iterator = reader._iterator()
        prefix = self.name

        # runtime-check and adapt every batch yielded by the reader
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        # with multiple GPU devices, wrap the iterator in a distributing feeder
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_iterator = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_iterator = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn
Ejemplo n.º 3
0
    def fit_readers_with_mixratio(self,
                                  readers,
                                  sampling_reference,
                                  num_epochs,
                                  phase='train'):
        """
        Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only 
            works on `sampling_reference` task(trainer), and num_epochs of other tasks are infered from 
            their `mix_ratio`.

        Args:
            readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each reader is matched to a trainer by position (zipped with `self._trainers`).
            sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling. 
            num_epochs: training epochs of the sampling_reference task (trainer). 
        """
        self._check_phase(phase)

        # normalize `readers` into a {trainer_name: reader} dict
        if isinstance(readers, list):
            reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
        elif isinstance(readers, dict):
            reader_dict = readers
        else:
            raise ValueError()

        num_heads = len(self._trainers)
        assert len(
            reader_dict
        ) == num_heads, "received number of readers is not consistent with trainers."

        trainer_dict = {t.name: t for t in self._trainers}
        assert sampling_reference in trainer_dict

        # fit the reference trainer first: its steps-per-epoch is the baseline
        # from which every other trainer's train-step budget is derived
        trainer_dict[sampling_reference]._set_task_id(self._task_id_var)
        trainer_dict[sampling_reference].fit_reader(
            reader_dict[sampling_reference])
        base_steps_pur_epoch = trainer_dict[
            sampling_reference]._steps_pur_epoch

        self._finish_steps = {}
        self._finish = {}
        input_names = []
        name_to_pos = []
        joint_shape_and_dtypes = []
        iterators = []
        prefixes = []
        mrs = []
        net_inputs = []
        global_steps = 0
        for t in self._trainers:
            assert t.name in reader_dict
            # in multi-head mode each per-task reader must cycle indefinitely;
            # the overall step budget (not epochs) decides when a task is done
            assert reader_dict[
                t.name].num_epochs is None, "{}: num_epochs is not None. \
                To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(
                    t.name)
            # print(num_epochs, t.mix_ratio, base_steps_pur_epoch)
            # expected steps for this task, scaled by its sampling mix_ratio
            max_train_steps = int(num_epochs * t.mix_ratio *
                                  base_steps_pur_epoch)
            if not t._as_auxilary:
                print('{}: expected train steps {}.'.format(
                    t.name, max_train_steps))
                sys.stdout.flush()
                self._finish_steps[t.name] = max_train_steps
                self._finish[t.name] = False
            else:
                # auxiliary tasks never gate training completion:
                # effectively-infinite step budget and marked finished up front
                self._finish_steps[t.name] = 9999999999
                self._finish[t.name] = True

            global_steps += max_train_steps
            # the reference trainer was already fitted above; fit the rest
            if t.name != sampling_reference:
                t._set_task_id(self._task_id_var)
                t.fit_reader(reader_dict[t.name])
            net_inputs.append(t._net_inputs)
            prefixes.append(t.name)
            mrs.append(t.mix_ratio)
            iterators.append(t._raw_iterator_fn())
            input_names.append(t._input_names)
            name_to_pos.append(t._name_to_position)
            joint_shape_and_dtypes.append(t._shape_and_dtypes)

        print('Estimated overall train steps {}.'.format(global_steps))
        sys.stdout.flush()
        self._overall_train_steps = global_steps

        # interleave all per-task iterators, sampling tasks by mix_ratio
        iterator_fn = reader_helper.create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
            mrs, input_names, name_to_pos, dev_count=dev_count)
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(
            net_inputs)

        # with multiple GPU devices, wrap the iterator in a distributing feeder
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn,
                                               feed_batch_process_fn,
                                               phase=phase,
                                               is_multi=True)
        else:
            distribute_feeder_fn = iterator_fn()

        # NOTE(review): here the feeder is stored on `_train_reader` /
        # `_predict_reader`, while fit_reader stores it on `_train_iterator` /
        # `_predict_iterator` — looks inconsistent; confirm against consumers.
        if phase == 'train':
            self._train_reader = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_reader = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn
Ejemplo n.º 4
0
    def train(self):
        """Run the joint multi-task training loop until all target tasks finish.

        Each step pulls one batch from the joint feeder (which also yields the
        id of the sampled task), runs the switched train program on that task's
        branch, and periodically prints loss/speed, exports per-task inference
        models and saves checkpoints per the main task's configuration.
        """

        if not self.has_init_train:
            self._init_train()
            self.has_init_train = True

        instances = self.instances
        main_inst = self.main_inst
        main_conf = main_inst.config

        train_program = self.train_program
        saver_program = self.saver_program

        # a target task with no expected train steps is finished immediately
        for inst in instances:
            if inst.is_target and inst.expected_train_steps <= 0:
                print(inst.name + ': train finished!')
                inst.save()

        def train_finish():
            # training ends only when every target task reports finished
            for inst in instances:
                if inst.is_target:
                    if not inst.train_finish:
                        return False
            return True

        # do training
        # NOTE: fetch_names is intentionally empty here, so the zip below
        # discards all fetched outputs except the popped loss.
        fetch_names = {}
        # hoisted out of the loop: the original appended the loss var on every
        # iteration, growing fetch_list (and the per-step fetch work) unboundedly
        fetch_list = [self._switched_loss]
        global_step = 0  # count for all tasks
        time_begin = time.time()

        feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
        distribute_feeder = data_feeder(self._joint_iterator_fn,
                                        feed_batch_process_fn)

        while not train_finish():
            feed, mask, task_id = next(distribute_feeder)
            # tell the switched program which task branch to execute
            for i in range(self.dev_count):
                feed[i].update({'branch': np.array([task_id], dtype='int64')})
            rt_outputs = self.exe.run(train_program,
                                      feed=feed,
                                      fetch_list=fetch_list)
            rt_loss = rt_outputs.pop()

            rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
            cur_task = instances[task_id]

            global_step += 1
            cur_task.cur_train_step += 1

            # periodically export the current task's inference model
            cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
            if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
                cur_task.save(suffix='.step' + str(cur_task_global_step),
                              prog=self._pred_prog)

            if global_step % main_conf.get('print_every_n_steps', 5) == 0:
                loss = np.mean(np.squeeze(rt_loss)).tolist()

                time_end = time.time()
                time_cost = time_end - time_begin

                print(
                    "Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s"
                    .format(
                        global_step, cur_task.name, cur_task.cur_train_step,
                        cur_task.steps_pur_epoch, cur_task.cur_train_epoch,
                        loss,
                        main_conf.get('print_every_n_steps', 5) / time_cost))
                time_begin = time.time()

            # export the final model the moment a task reaches its last step
            if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
                print(cur_task.name + ': train finished!')
                cur_task.save(prog=self._pred_prog)

            if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf[
                    'save_ckpt_every_n_steps'] == 0:
                save_path = os.path.join(main_conf['save_path'], 'ckpt',
                                         "step_" + str(global_step))
                fluid.io.save_persistables(self.exe, save_path, saver_program)
                print('checkpoint has been saved at ' + save_path)

        # final checkpoint after the loop exits
        save_path = os.path.join(main_conf['save_path'], 'ckpt',
                                 "step_" + str(global_step))
        fluid.io.save_persistables(self.exe, save_path, saver_program)
        print('checkpoint has been saved at ' + save_path)

        print("ALL tasks train finished, exiting...")