Example #1
class TrainingRunner:
    def __init__(self, conf: TRConf, model: ZModel, train_stream: Streamer,
                 train_batch_f: Callable, train_discard_batch_f: Callable,
                 dev_runners: List[TestingRunner], **kwargs):
        self.conf = conf
        self.kwargs = kwargs
        # --
        self.model = model
        self.train_stream = train_stream
        self.train_batch_f = train_batch_f if train_batch_f is not None else lambda x: 1  # inst -> int
        self.train_discard_batch_f = train_discard_batch_f if train_discard_batch_f is not None else lambda x: False  # List[inst] -> bool
        self.dev_runners = dev_runners
        # some records
        self.tp = TrainingProgressRecord()
        self.train_recorder = StatRecorder(timing=True)
        # -- special!!
        # store all insts for future usage (do not use if input is large)
        self.stored_all_insts = None
        if conf.store_all_insts:
            _all_insts = []
            for _one_insts in self.train_stream:
                _all_insts.extend(_one_insts)
            self.stored_all_insts = _all_insts
        # --
        # scheduled values
        self.lrate = ScheduledValue("lrate", conf.lrate)
        self.lrate_warmup_steps = self._determine_warmup_steps(
            conf.lrate_warmup_eidx, conf.lrate_warmup_uidx, train_stream)
        if self.lrate_warmup_steps > 0:
            self.lrate_warmup_factor = 1. / (
                self.lrate_warmup_steps ** conf.lrate_decrease_alpha)
            zlog(
                f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
                f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
                func="plain")
        else:
            self.lrate_warmup_factor = 1.
        self.scheduled_values = OrderedDict([("_lrate", self.lrate)])
        DictHelper.update_dict(self.scheduled_values,
                               model.get_scheduled_values()
                               )  # values to be scheduled as training goes
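        # --
        # note: assuming conf.lrate_decrease_alpha is negative (e.g. -0.5 for an
        # inverse-sqrt schedule), the effective lrate used in run() first ramps up
        # linearly, act_lrate = lrate.value * (uidx / warmup_steps) for uidx < warmup_steps,
        # and afterwards decays as lrate.value * warmup_factor * uidx ** alpha,
        # which equals lrate.value * (uidx / warmup_steps) ** alpha.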

    def _determine_warmup_steps(self, eidx: int, uidx: int, stream: Streamer):
        # set warmup steps if using eidx
        if eidx > 0:
            _epoch_step_size = len(list(stream))
            _warmup_steps = _epoch_step_size * eidx
            zlog(
                f"Determine warmup steps: {eidx} x {_epoch_step_size} = {_warmup_steps}",
                func="plain")
        else:
            _warmup_steps = 0
        _warmup_steps = max(_warmup_steps, uidx)
        return _warmup_steps
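    # e.g. (illustrative numbers): with lrate_warmup_eidx=2 and a train_stream that
    # yields 1000 batches per epoch, _epoch_step_size=1000 and _warmup_steps=2000;
    # the final max() lets lrate_warmup_uidx act as a lower bound on the warmup length.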

    def current_name(self):
        return self.tp.current_suffix()

    def add_scheduled_value(self, v: ScheduledValue, key_prefix=''):
        key = key_prefix + v.name
        assert key not in self.scheduled_values
        self.scheduled_values[key] = v

    def adjust_scheduled_values(self):
        # adjust scheduled values
        ss = self.current_name()
        for one_name, one_sv in self.scheduled_values.items():
            if one_sv.changeable:
                one_sv.adjust_at_ckp(ss, self.tp, extra_info=one_name)

    # -----
    # saving and loading related
    def save_progress(self, file: str):
        default_json_serializer.to_file(self.tp, file)
        zlog(f"Save training progress to {file}", func="io")

    def load_progress(self, file: str, forward_stream=False):
        old_uidx = self.tp.uidx
        d = default_json_serializer.from_file(file)
        self.tp.from_json(d)
        if forward_stream:
            if old_uidx > self.tp.uidx:
                zwarn(
                    f"Cannot go to the past: {old_uidx} -> {self.tp.uidx}, skip this!"
                )
            else:
                _s = self.train_stream
                for _ in range(self.tp.uidx - old_uidx):
                    _, _eos = _s.next_and_check()
                    if _eos:  # restart and get one
                        _s.restart()
                        _s.next()
                zlog(f"Forward to the future: {old_uidx} -> {self.tp.uidx}!",
                     func="io")
        zlog(f"Load training progress from {file}", func="io")
        self.adjust_scheduled_values()  # also adjust values!

    def save(self, prefix: str):
        self.save_progress(prefix + ".tp.json")
        self.model.save(prefix + ".m")

    def load(self,
             prefix: str,
             load_progress=False,
             forward_stream=False,
             load_strict=None):
        if prefix.endswith(".m"):
            prefix = prefix[:-2]
        if load_progress:
            self.load_progress(prefix + ".tp.json", forward_stream)
        self.model.load(prefix + ".m", strict=load_strict)

    def get_train_stream(self):
        return self.train_stream

    # =====
    # run until the end of training
    def run(self):
        conf = self.conf
        last_report_uidx, last_dev_uidx = 0, 0
        # --
        if conf.valid_first:  # valid before training
            self.validate()
        # --
        _lrate_warmup_factor, _lrate_warmup_steps = self.lrate_warmup_factor, self.lrate_warmup_steps
        _skip_batch = conf.skip_batch
        _gen0 = Random.get_generator("train")
        _gen = Random.stream(_gen0.random_sample)
        # --
        _accu_checker = 0
        _accu_batch = conf.accu_batch
        # --
        # start before loop
        self.adjust_scheduled_values()
        # loop
        act_lrate = None
        while True:  # loop over and over
            _train_stream = self.get_train_stream()  # we may change train_stream!!
            # --
            if _train_stream.is_inactive():  # check to avoid restart after load_progress
                _train_stream.restart()
            insts, _eos = _train_stream.next_and_check()
            if _eos:  # end of epoch
                zlog(
                    f"End of epoch at {self.tp.current_suffix(False)}: Current act_lrate is {act_lrate}.",
                    func="plain",
                    timed=True)
                if conf.valid_epoch:
                    last_dev_uidx = self.tp.uidx
                    self.validate()
                    # todo(+N): do we need to adjust sv at a finer granularity?
                    self.adjust_scheduled_values()  # adjust after validation
                if self._finished():
                    break
                self.tp.update_eidx(1)
                continue
            # skip batch?
            if _skip_batch > 0 and next(_gen) < _skip_batch:
                continue
            if self.train_discard_batch_f(insts):
                continue  # discard this batch for some specific reason (e.g. noevt)
            # run fb (possibly split batch)
            self.fb_batch(insts, 1. / _accu_batch)
            self.tp.update_iidx(len(insts))
            # ==
            # only update for certain accu fb runs
            _accu_checker += 1
            if _accu_checker % _accu_batch == 0:
                self.tp.update_uidx(1)
                cur_uidx = self.tp.uidx
                # get the effective lrate and update
                act_lrate = float(self.lrate.value)  # start with the lrate.value
                if cur_uidx < _lrate_warmup_steps:  # linear increase
                    act_lrate *= (cur_uidx / _lrate_warmup_steps)
                else:  # decrease
                    act_lrate *= _lrate_warmup_factor * (
                        cur_uidx**conf.lrate_decrease_alpha)
                self._run_update(act_lrate, 1.)
                # --
                # report on training process
                if conf.flag_verbose and (
                        cur_uidx - last_report_uidx) >= conf.report_ufreq:
                    zlog(
                        f"Report at {self.tp.current_suffix(False)}: Current act_lrate is {act_lrate}.",
                        func="plain",
                        timed=True)
                    self._run_train_report()
                    last_report_uidx = cur_uidx
                # valid?
                if (cur_uidx - last_dev_uidx) >= conf.valid_ufreq:
                    last_dev_uidx = self.tp.uidx
                    self.validate()
                    # todo(+N): do we need to adjust sv at a finer granularity?
                    self.adjust_scheduled_values()  # adjust after validation
                    if self._finished():
                        break
            # =====
        zlog(f"Finish training because of: {self._reach_ends()}", func="plain")
        zlog(
            f"zzzzzfinal: After training, the best point is: {self.tp.info_best()}.",
            func="report")

    # end conditions: training can only finish once any of these is reached
    def _reach_ends(self):
        conf = self.conf
        cur_eidx, cur_uidx, cur_aidx = self.tp.eidx, self.tp.uidx, self.tp.aidx
        return cur_eidx >= conf.max_eidx, cur_uidx >= conf.max_uidx, cur_aidx >= conf.anneal_times

    def _finished(self):
        conf = self.conf
        cur_eidx, cur_uidx, cur_aidx = self.tp.eidx, self.tp.uidx, self.tp.aidx
        return (cur_eidx >= conf.min_eidx) and (
            cur_uidx >= conf.min_uidx) and any(self._reach_ends())

    # training for one batch
    def fb_batch(self, insts: List, loss_factor: float):
        with self.train_recorder.go():
            res = self._run_fb(insts, loss_factor)
            self.train_recorder.record(res)
        # --

    # do validation and record checkpoints
    def validate(self):
        conf = self.conf
        # report & reset training stat
        if self.tp.uidx > 0:
            train_result = self._run_train_report()  # first report training stat
            self.train_recorder.reset()  # reset training stat
        else:  # for validate_first
            train_result = None
        # dev
        ss, cur_cidx = self.current_name(), self.tp.cidx
        zlog("", func="plain")  # empty line
        with Timer(info=f"Valid {ss}",
                   print_date=True), self.model.ema_wrap_dev():
            # no validation if specified
            if (self.tp.eidx < conf.valid_start_eidx) or (
                    self.tp.uidx < conf.valid_start_uidx):
                zlog("No validation since not the time yet!\n", func="plain")
                return
            # validate
            if len(self.dev_runners) == 0:  # simply use train if there is no dev
                zlog("Use training results for dev since no dev set is provided!",
                     func="warn")
                dev_result = train_result
            else:
                dev_result = self._run_validate(self.dev_runners)
            # record
            cur_no_bad = (self.tp.eidx < conf.bad_start_eidx) or (
                self.tp.uidx < conf.bad_start_uidx)
            cur_record_best = (self.tp.cidx >= conf.record_best_cidx)
            if_overall_best, if_best, if_anneal = self.tp.update_checkpoint(
                train_result, dev_result, cur_no_bad, cur_record_best,
                conf.anneal_patience)
            # save curr & best
            self.save(conf.model_prefix + conf.model_suffix_curr)
            if if_overall_best:
                zlog("Curr is overall best " +
                     str(self.tp.info_overall_best()),
                     func="result")
            else:
                zlog("Curr not overall best, the overall best is " +
                     str(self.tp.info_overall_best()),
                     func="result")
            if if_best:
                self.save(conf.model_prefix + conf.model_suffix_best)
                zlog("Curr is best: " + str(self.tp.info_best()),
                     func="result")
            else:
                zlog("Curr not best, the best is " + str(self.tp.info_best()),
                     func="result")
            if cur_cidx >= conf.save_start_cidx and cur_cidx % conf.save_cfreq == 0:
                self.save(conf.model_prefix + ss)  # special save
            if if_anneal and conf.anneal_restore:
                zlog("Restore from previous best model!!", func="plain")
                self.load(conf.model_prefix + conf.model_suffix_best, False)
        zlog("", func="plain")  # empty line

    # =====
    # template methods which can be overridden

    # return one train result
    def _run_fb(self, insts: List, loss_factor: float):
        res = self.model.loss_on_batch(insts, loss_factor=loss_factor)
        return res

    # return None
    def _run_update(self, lrate: float, grad_factor: float):
        self.model.update(lrate, grad_factor)

    # print and return train summary
    def _run_train_report(self) -> ResultRecord:
        x = self.train_recorder.summary()
        zlog(f"Train-Info: {OtherHelper.printd_str(x, sep=' ')}")
        return ResultRecord(results=x, description=None)

    # run and return dev results
    def _run_validate(self, dev_runners: List[TestingRunner]) -> ResultRecord:
        if len(dev_runners) == 0: return ResultRecord.get_nil()
        all_records: List[ResultRecord] = [r.run() for r in dev_runners]
        # note: use devs[0] as the criterion, assuming that is the dev itself!!
        r = ResultRecord(
            results={f"res{ii}": v.results for ii, v in enumerate(all_records)},
            description={f"res{ii}": str(v) for ii, v in enumerate(all_records)},
            score=all_records[0].score)
        return r
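
# --
# Minimal usage sketch: `make_conf`, `make_model`, `make_train_stream`, and
# `make_dev_runners` are hypothetical factory helpers; the surrounding project is
# assumed to provide the real wiring.
def _example_training_run():
    conf, model = make_conf(), make_model()
    runner = TrainingRunner(conf, model, make_train_stream(),
                            train_batch_f=None, train_discard_batch_f=None,
                            dev_runners=make_dev_runners())
    runner.run()  # train until _finished() is reached
    # reload the best checkpoint that validate() saved along the way
    runner.load(conf.model_prefix + conf.model_suffix_best)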
Example #2
class RunCenter:
    def __init__(self, conf: RunCenterConf, model, t_center: TaskCenter,
                 d_center: DataCenter):
        self.conf = conf
        self.model = model
        self.t_center = t_center
        self.d_center = d_center
        # ==
        # for train
        self.tp = TrainingProgressRecord()
        self.train_recorder = StatRecorder(timing=True)
        # --
        self.lrate = ScheduledValue("lrate", conf.lrate)
        self.scheduled_values = OrderedDict([("_lrate", self.lrate)])  # add all scheduled values
        DictHelper.update_dict(self.scheduled_values,
                               model.get_scheduled_values())
        DictHelper.update_dict(self.scheduled_values,
                               d_center.get_scheduled_values())
        # --
        self.lrate_warmup_steps = conf.lrate_warmup_uidx
        if self.lrate_warmup_steps > 0:
            self.lrate_warmup_factor = 1. / (
                self.lrate_warmup_steps ** conf.lrate_decrease_alpha)
            zlog(
                f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
                f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
                func="plain")
        else:
            self.lrate_warmup_factor = 1.
        # ==
        # for ddp
        self.ddp_world_size = BK.ddp_world_size()
        self.ddp_rank = BK.ddp_rank()
        self.is_main_process = BK.is_main_process()  # the main process handles the saving!
        if self.ddp_world_size > 1:
            assert conf.accu_batch == 1, "accu_batch and ddp may conflict!!"
        # --

    # --
    # helpers

    def current_name(self):
        return self.tp.current_suffix()

    def adjust_scheduled_values(self):
        # adjust scheduled values
        ss = self.current_name()
        for one_name, one_sv in self.scheduled_values.items():
            if one_sv.changeable:
                one_sv.adjust_at_ckp(ss, self.tp, extra_info=one_name)

    # saving and loading related
    def save_progress(self, file: str):
        default_json_serializer.to_file(self.tp, file)
        zlog(f"Save training progress to {file}", func="io")

    def load_progress(self, file: str, forward_stream=False):
        d = default_json_serializer.from_file(file)
        self.tp.from_json(d)
        assert not forward_stream, "Error: 'forward_stream' not supported in this mode!!"
        zlog(f"Load training progress from {file}", func="io")
        self.adjust_scheduled_values()  # also adjust values!

    def save(self, prefix: str):
        if self.is_main_process:  # note: do saving only with main process
            self.save_progress(prefix + ".tp.json")
            self.model.save(prefix + ".m")

    def load(self, prefix: str, load_progress=False, forward_stream=False):
        if prefix.endswith(".m"):
            prefix = prefix[:-2]
        if load_progress:
            self.load_progress(prefix + ".tp.json", forward_stream)
        self.model.load(prefix + ".m")

    # go
    # training for one batch
    def fb_batch(self, ibatch, loss_factor: float):
        with self.train_recorder.go():
            res = self.model.loss_on_batch(ibatch, loss_factor)
            self.train_recorder.record(res)
        # --

    def train_finished(self):
        conf = self.conf
        return self.tp.uidx >= conf.max_uidx

    # print and return train summary
    def run_train_report(self) -> ResultRecord:
        x = self.train_recorder.summary()
        zlog(f"Train-Info: {OtherHelper.printd_str(x, sep=' ')}")
        # also report uidx_counter/iidx_counter
        zlog(f"UidxCounter: {self.tp.uidx_counter}")
        zlog(f"IidxCounter: {self.tp.iidx_counter}")
        return ResultRecord(results=x, description=None)

    # run test on one dataset
    def run_test_dataset(self, dataset: ZDataset):
        test_recorder = StatRecorder(timing=True)
        for ibatch in dataset.yield_batches(loop=False):
            with test_recorder.go():
                one_res = self.model.predict_on_batch(ibatch)
                test_recorder.record(one_res)
        # --
        # write output
        if self.is_main_process:  # note: do saving only with main process
            dataset.write_insts(None)  # use dataset's conf
        # --
        # eval
        x = test_recorder.summary()
        zlog(f"Test-Info: {OtherHelper.printd_str(x, sep=' ')}")
        aggr = ResultAggregator()
        for task in self.t_center.tasks.values():
            if task.name not in dataset.tasks:
                continue
            tn_res: ResultRecord = task.eval_insts(dataset.gold_insts,
                                                   dataset.insts,
                                                   quite=False)
            if tn_res is None:
                continue
            aggr.add(task.name, tn_res, task.conf.eval_weight)
        ret = aggr.get_res()
        return ret

    # --
    # main runs

    def do_train(self):
        model, t_center, d_center = self.model, self.t_center, self.d_center
        conf = self.conf
        last_dev_uidx = 0
        # --
        if conf.valid_first:  # valid before training
            self.do_dev()
        # --
        _lrate_warmup_factor, _lrate_warmup_steps = self.lrate_warmup_factor, self.lrate_warmup_steps
        _accu_batch = conf.accu_batch
        self.adjust_scheduled_values()  # once before train
        train_yielder = d_center.yield_train_yielder()
        while not self.train_finished():  # loop over and over
            # sample batch
            cur_yielder = next(train_yielder)
            # fb for accu_batch/ddp steps (use the same dataset/yielder)
            cur_dname = None
            for _i0 in range(self.ddp_world_size):
                # if 1:
                for _i1 in range(_accu_batch):
                    cur_ibatch = next(cur_yielder)
                    cur_dname = cur_ibatch.dataset.name
                    if _i0 == self.ddp_rank:  # only for current rank!!
                        # if 1:
                        self.fb_batch(cur_ibatch, 1. / _accu_batch)
                        self.tp.update_iidx(len(cur_ibatch), cur_dname)
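            # note: every global step draws ddp_world_size * accu_batch ibatches from the
            # same yielder, but this process only runs fb_batch on the slices whose outer
            # index equals its own ddp_rank; the remaining ibatches are drawn and dropped
            # so that all ranks stay aligned on the shared data stream.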
            # update
            self.tp.update_uidx(1, cur_dname)
            cur_uidx = self.tp.uidx
            # get the effective lrate and update
            act_lrate = float(self.lrate.value)  # start with the lrate.value
            if cur_uidx < _lrate_warmup_steps:  # linear increase
                act_lrate *= (cur_uidx / _lrate_warmup_steps)
            else:  # decrease
                act_lrate *= _lrate_warmup_factor * (cur_uidx**
                                                     conf.lrate_decrease_alpha)
            with self.train_recorder.go('update'):  # also record this!
                self.model.update(act_lrate, 1.)
            # valid?
            if (cur_uidx - last_dev_uidx) >= conf.valid_ufreq:
                last_dev_uidx = cur_uidx
                self.do_dev()
                # todo(+N): do we need to adjust sv at a finer granularity?
                self.adjust_scheduled_values()  # adjust after validation
        # --
        zlog(
            f"zzzzzfinal: After training, the best point is: {self.tp.info_best()}.",
            func="report")

    def do_dev(self):
        conf = self.conf
        # report & reset training stat
        if self.tp.uidx > 0:
            train_result = self.run_train_report()  # first report training stat
            self.train_recorder.reset()  # reset training stat
        else:  # for validate_first
            train_result = ResultRecord.get_nil()
        # dev
        ss, cur_cidx = self.current_name(), self.tp.cidx
        zlog("", func="plain")  # empty line
        with Timer(info=f"Valid {ss}", print_date=True):
            # no validation if specified
            if self.tp.uidx < conf.valid_start_uidx:
                zlog("No validation since not the time yet!\n", func="plain")
                return
            # validate
            if len(self.d_center.get_datasets(wset="dev")) == 0:
                # simply use train if there is no dev
                zlog("Use training results for dev since no dev set is provided!",
                     func="warn")
                dev_result = train_result
            else:
                dev_result = self.do_test("dev")
            # record
            cur_record_best = (self.tp.cidx >= conf.record_best_start_cidx)
            if_overall_best, if_best, if_anneal = self.tp.update_checkpoint(
                train_result, dev_result, record_best=cur_record_best)
            # save curr & best
            self.save(conf.model_save_prefix + conf.model_save_suffix_curr)
            if if_overall_best:
                zlog("Curr is overall best " +
                     str(self.tp.info_overall_best()),
                     func="result")
            else:
                zlog("Curr not overall best, the overall best is " +
                     str(self.tp.info_overall_best()),
                     func="result")
            if if_best:
                self.save(conf.model_save_prefix + conf.model_save_suffix_best)
                zlog("Curr is best: " + str(self.tp.info_best()),
                     func="result")
            else:
                zlog("Curr not best, the best is " + str(self.tp.info_best()),
                     func="result")
            if cur_cidx >= conf.save_special_start_cidx and cur_cidx % conf.save_special_cfreq == 0:
                self.save(conf.model_save_prefix + ss)  # special save
        # --
        zlog("", func="plain")  # empty line

    def do_test(self, wset="test"):
        model, t_center, d_center = self.model, self.t_center, self.d_center
        conf = self.conf
        # --
        to_test_datasets = d_center.get_datasets(wset=wset)
        t_center.prepare_datasets(to_test_datasets)  # re-prepare!!
        aggr = ResultAggregator()
        for one_ii, one_dataset in enumerate(to_test_datasets):
            with Timer(info=f"Test({one_ii+1}/{len(to_test_datasets)}): {one_dataset}",
                       print_date=True):
                one_res = self.run_test_dataset(one_dataset)
                aggr.add(one_dataset.name, one_res,
                         one_dataset.conf.group_eval_weight)
        ret = aggr.get_res()
        return ret
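
# --
# Minimal usage sketch: `make_run_conf`, `make_model`, `make_task_center`, and
# `make_data_center` are hypothetical factories standing in for the project's own setup code.
def _example_run_center():
    center = RunCenter(make_run_conf(), make_model(), make_task_center(), make_data_center())
    center.do_train()  # loop until tp.uidx >= conf.max_uidx, validating every valid_ufreq updates
    return center.do_test("test")  # weighted aggregate over all test datasets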