Example 1
    def fit_readers_with_mixratio(self,
                                  readers,
                                  sampling_reference,
                                  num_epochs,
                                  phase='train'):
        """
        Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only
        works on the `sampling_reference` task (trainer); the num_epochs of every other task is
        inferred from its `mix_ratio`.

        Args:
            readers: a dict or list of Reader objects. For the dict case, each key is a
                trainer's name and the mapped value is the reader object to bind to that
                trainer. For the list case, readers are matched to `self._trainers` by
                position.
            sampling_reference: a trainer name. The task (trainer) selected as the baseline
                for task sampling.
            num_epochs: training epochs of the sampling_reference task (trainer).
            phase: running phase. Currently supports: train, predict.
        """
        self._check_phase(phase)

        # Normalize `readers` into a {trainer_name: reader} mapping.
        if isinstance(readers, list):
            reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
        elif isinstance(readers, dict):
            reader_dict = readers
        else:
            raise ValueError(
                "readers should be a list or dict of Reader objects, got {}.".
                format(type(readers)))

        num_heads = len(self._trainers)
        assert len(
            reader_dict
        ) == num_heads, "received number of readers is not consistent with trainers."

        trainer_dict = {t.name: t for t in self._trainers}
        assert sampling_reference in trainer_dict

        # Fit the reference trainer first: its steps-per-epoch is the baseline used to
        # compute every other trainer's expected number of train steps.
        trainer_dict[sampling_reference]._set_task_id(self._task_id_var)
        trainer_dict[sampling_reference].fit_reader(
            reader_dict[sampling_reference])
        base_steps_pur_epoch = trainer_dict[
            sampling_reference]._steps_pur_epoch

        self._finish_steps = {}
        self._finish = {}
        input_names = []
        name_to_pos = []
        joint_shape_and_dtypes = []
        iterators = []
        prefixes = []
        mrs = []
        net_inputs = []
        global_steps = 0
        for t in self._trainers:
            assert t.name in reader_dict
            assert reader_dict[
                t.name].num_epochs is None, "{}: num_epochs is not None. \
                To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(
                    t.name)
            # Expected steps for this task, scaled by its mix ratio against the baseline.
            max_train_steps = int(num_epochs * t.mix_ratio *
                                  base_steps_pur_epoch)
            if not t._as_auxilary:
                print('{}: expected train steps {}.'.format(
                    t.name, max_train_steps))
                sys.stdout.flush()
                self._finish_steps[t.name] = max_train_steps
                self._finish[t.name] = False
            else:
                # Auxiliary tasks never terminate training on their own: mark them
                # finished up front and give them an effectively unlimited step budget.
                self._finish_steps[t.name] = 9999999999
                self._finish[t.name] = True

            global_steps += max_train_steps
            if t.name != sampling_reference:
                t._set_task_id(self._task_id_var)
                t.fit_reader(reader_dict[t.name])
            net_inputs.append(t._net_inputs)
            prefixes.append(t.name)
            mrs.append(t.mix_ratio)
            iterators.append(t._raw_iterator_fn())
            input_names.append(t._input_names)
            name_to_pos.append(t._name_to_position)
            joint_shape_and_dtypes.append(t._shape_and_dtypes)

        print('Estimated overall train steps {}.'.format(global_steps))
        sys.stdout.flush()
        self._overall_train_steps = global_steps

        # Merge the per-task iterators into one multi-head iterator that samples
        # tasks according to their mix ratios.
        iterator_fn = reader_helper.create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
            mrs, input_names, name_to_pos, dev_count=dev_count)
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(
            net_inputs)

        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn,
                                               feed_batch_process_fn,
                                               phase=phase,
                                               is_multi=True)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_reader = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_reader = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn
Example 2
    def fit_reader(self, reader, phase='train'):
        """
        Bind a reader and loaded train/predict data to trainer.

        Args:
            reader: a Reader object. The running phase of the reader should be consistent
                with the `phase` argument of this method.
            phase: running phase. Currently support: train, predict.
        """

        self._check_phase(phase)
        if phase == 'train':
            assert self._shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
        else:
            assert self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."

        batch_size = reader._batch_size

        self._num_epochs = reader.num_epochs
        if phase == 'train':
            self._train_reader = reader
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            # In multi-task mode a task-id input is injected so the net can tell
            # which head the batch belongs to.
            if self._task_id is not None:
                self._net_inputs['__task_id'] = self._task_id
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
            # Validate that reader outputs satisfy both backbone and head inputs.
            reader_helper.check_io(self._backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
        elif phase == 'predict':
            self._predict_reader = reader
            self._pred_steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
            reader_helper.check_io(self._pred_backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr, in_name='task_head(backbone, predict)', out_name='backbone')
        else:
            raise NotImplementedError()

        # merge dataset iterators and create net input vars
        iterator = reader._iterator()
        prefix = self.name

        # Runtime check and adaptation of the data yielded by the iterator.
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_iterator = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_iterator = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn
Example 3
    def merge_inference_readers(self, readers):
        """
        Bind predict-phase readers of all trainers and merge them into a single
        multi-head inference feeder.

        Args:
            readers: a dict or list of predict-phase Reader objects. For the dict case,
                each key is a trainer's name and the mapped value is the reader to bind
                to that trainer. For the list case, readers are matched to
                `self._trainers` by position.

        Returns:
            The merged feeder function; also stored on `self._predict_iterator_fn`.
        """
        # Normalize `readers` into a {trainer_name: reader} mapping first, so the
        # phase check below iterates Reader objects in both the list and dict cases
        # (iterating a dict directly would yield its string keys).
        if isinstance(readers, list):
            reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
        elif isinstance(readers, dict):
            reader_dict = readers
        else:
            raise ValueError(
                "readers should be a list or dict of Reader objects, got {}.".
                format(type(readers)))

        for r in reader_dict.values():
            assert r._phase == 'predict'

        num_heads = len(self._trainers)
        assert len(
            reader_dict
        ) == num_heads, "received number of readers is not consistent with trainers."

        task_name2id = {t.name: idx for idx, t in enumerate(self._trainers)}
        self._task_name2id = task_name2id

        self._finish_steps = {}
        self._finish = {}
        input_names = []
        name_to_pos = []
        joint_shape_and_dtypes = []
        iterators = []
        prefixes = []
        net_inputs = []
        for t in self._trainers:
            assert t.name in reader_dict
            assert reader_dict[
                t.name].num_epochs is None, "{}: num_epochs is not None. \
                To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(
                    t.name)
            # Inference never terminates by step count: mark every task finished
            # and give it an effectively unlimited step budget.
            self._finish_steps[t.name] = 9999999999
            self._finish[t.name] = True

            t.fit_reader(reader_dict[t.name], phase='predict')
            net_inputs.append(t._pred_net_inputs)
            prefixes.append(t.name)
            iterators.append(t._raw_iterator_fn())
            input_names.append(t._pred_input_names)
            name_to_pos.append(t._pred_name_to_position)
            joint_shape_and_dtypes.append(t._pred_shape_and_dtypes)

        # Merge the per-task iterators into one multi-head inference iterator.
        iterator_fn = reader_helper.create_multihead_inference_fn(iterators, prefixes, joint_shape_and_dtypes, \
            input_names, name_to_pos, task_name2id, dev_count=dev_count)
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(
            net_inputs)

        if gpu_dev_count > 1:
            raise NotImplementedError(
                'currently only single-gpu mode has been supported running with multi-task mode.'
            )
        else:
            distribute_feeder_fn = iterator_fn

        self._predict_iterator_fn = distribute_feeder_fn
        self._pred_feed_batch_process_fn = feed_batch_process_fn
        return distribute_feeder_fn