예제 #1
0
def main(*args):
    """Compute corpus statistics (stat mode) or query stored ones (query mode).

    The configuration is a ``MainConf`` updated from *args*.

    * Stat mode (``conf.R.input_path`` is set): read every instance from
      ``conf.R.get_reader()``; if the first instance is a ``Doc`` run
      ``stat_docs``, otherwise ``stat_sents``. The recorder's ``plain_values``
      and ``special_values`` are merged, logged and, when
      ``conf.result_center`` is set, stored in that JSON file under the key
      ``conf.R.input_path`` (the file is created if missing; other keys are kept).
    * Query mode (no input path): load ``conf.result_center``, keep the keys that
      fully match ``conf.key_re_pattern``, and run a small REPL. Each non-empty
      line is evaluated with ``eval`` once per matching key, with that key's
      stored dict bound to ``d`` and the key itself to ``k``. EOF (Ctrl-D) exits.
      Ctrl-C at the prompt is ignored. Ctrl-C during an evaluation skips the
      remaining keys for that expression.

    Args:
        *args: command-line style arguments passed to ``MainConf.update_from_args``.
    """
    conf = MainConf()
    conf.update_from_args(args)
    # --
    if conf.R.input_path:
        # stat mode
        # --
        reader = conf.R.get_reader()
        insts = list(reader)
        stat = StatRecorder()
        if insts:
            # dispatch on the type of the first instance: documents or plain sentences
            if isinstance(insts[0], Doc):
                stat_docs(insts, stat)
            else:
                stat_sents(insts, stat)
        # --
        key = conf.R.input_path
        res = {}
        for ss in [stat.plain_values, stat.special_values]:
            res.update(ss)  # special_values win on key collisions
        show_res = [f"{kk}: {str(res[kk])}\n" for kk in sorted(res.keys())]
        zlog(f"# -- Stat Mode, Read from {key} and updating {conf.result_center}:\n{''.join(show_res)}")
        if conf.result_center:
            # merge into the existing result file (if any) instead of overwriting other keys
            if os.path.isfile(conf.result_center):
                d0 = default_json_serializer.from_file(conf.result_center)
            else:
                d0 = {}
            d0[key] = res
            default_json_serializer.to_file(d0, conf.result_center)
    else:
        # query mode: query across datasets (key)
        data = default_json_serializer.from_file(conf.result_center)
        pattern = re.compile(conf.key_re_pattern)
        hit_keys = sorted(k for k in data.keys() if pattern.fullmatch(k))
        zlog(f"Query for {hit_keys}")
        # --
        while True:
            try:
                code = input(">> ")
            except EOFError:
                break
            except KeyboardInterrupt:
                continue
            code = code.strip()
            if not code:
                continue
            # --
            zlog(f"Eval `{code}':")
            # NOTE: `eval` executes whatever is typed at the prompt, with access to this
            # function's locals. This is a local debugging REPL; never feed it untrusted input.
            for k in hit_keys:
                d = data[k]  # the evaluated expression refers to the current entry as `d`
                try:
                    one_res = eval(code)
                except KeyboardInterrupt:
                    # Ctrl-C must not be reported as an ordinary evaluation error:
                    # stop evaluating this expression and return to the prompt.
                    zlog("Interrupted.")
                    break
                except Exception:
                    one_res = traceback.format_exc()
                zlog(f"#--{k}:\n{one_res}")
예제 #2
0
 def __init__(self, conf: TRConf, model: ZModel, train_stream: Streamer,
              train_batch_f: Callable, train_discard_batch_f: Callable,
              dev_runners: List[TestingRunner], **kwargs):
     """Store the training components and prepare the learning-rate schedule.

     ``train_batch_f`` (inst -> int) defaults to ``lambda x: 1`` and
     ``train_discard_batch_f`` (List[inst] -> bool) defaults to ``lambda x: False``
     when None is given. If ``conf.store_all_insts`` is set, all instances from one
     pass over ``train_stream`` are flattened into ``self.stored_all_insts``
     (avoid for large inputs). Otherwise that attribute stays None.
     """
     self.conf, self.kwargs = conf, kwargs
     # --
     self.model, self.train_stream = model, train_stream
     self.train_batch_f = (lambda x: 1) if train_batch_f is None else train_batch_f  # inst -> int
     self.train_discard_batch_f = (
         (lambda x: False) if train_discard_batch_f is None else train_discard_batch_f
     )  # List[inst] -> bool
     self.dev_runners = dev_runners
     # progress and statistics records
     self.tp = TrainingProgressRecord()
     self.train_recorder = StatRecorder(timing=True)
     # -- special: optionally keep every training instance in memory
     self.stored_all_insts = None
     if conf.store_all_insts:
         self.stored_all_insts = [inst for batch in self.train_stream for inst in batch]
     # -- learning rate and its warmup/decay schedule
     self.lrate = ScheduledValue("lrate", conf.lrate)
     self.lrate_warmup_steps = self._determine_warmup_steps(
         conf.lrate_warmup_eidx, conf.lrate_warmup_uidx, train_stream)
     if not (self.lrate_warmup_steps > 0):
         self.lrate_warmup_factor = 1.
     else:
         # normalizer so that decay starts exactly at the base lrate
         self.lrate_warmup_factor = 1. / (self.lrate_warmup_steps ** conf.lrate_decrease_alpha)
         zlog(
             f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
             f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
             func="plain")
     # values to be scheduled as training goes: the lrate first, then the model's own
     self.scheduled_values = OrderedDict([("_lrate", self.lrate)])
     DictHelper.update_dict(self.scheduled_values, model.get_scheduled_values())
예제 #3
0
 def run_test_dataset(self, dataset: ZDataset):
     """Predict on every batch of *dataset*, write the outputs, then evaluate.

     Prediction statistics are timed and logged. Only the main process writes
     the instances. Each task in ``self.t_center`` that the dataset lists is
     evaluated on (gold, predicted) instances. Non-None results are aggregated
     with the task's ``eval_weight``, and the aggregated result is returned.
     """
     recorder = StatRecorder(timing=True)
     for batch in dataset.yield_batches(loop=False):
         with recorder.go():
             recorder.record(self.model.predict_on_batch(batch))
     # --
     # note: do saving only with main process; None means "use the dataset's own conf"
     if self.is_main_process:
         dataset.write_insts(None)
     # --
     summary = recorder.summary()
     zlog(f"Test-Info: {OtherHelper.printd_str(summary, sep=' ')}")
     aggr = ResultAggregator()
     wanted_tasks = (t for t in self.t_center.tasks.values() if t.name in dataset.tasks)
     for task in wanted_tasks:
         tn_res: ResultRecord = task.eval_insts(dataset.gold_insts, dataset.insts, quite=False)
         if tn_res is not None:
             aggr.add(task.name, tn_res, task.conf.eval_weight)
     return aggr.get_res()
예제 #4
0
 def __init__(self, conf: RunCenterConf, model, t_center: TaskCenter,
              d_center: DataCenter):
     """Set up training state, the scheduled values, the lrate warmup and DDP info.

     Scheduled values start with the lrate (key "_lrate"), followed by those of
     the model and then those of the data center. Under DDP (world size > 1),
     ``conf.accu_batch`` must be 1.
     """
     self.conf, self.model = conf, model
     self.t_center, self.d_center = t_center, d_center
     # == for train
     self.tp = TrainingProgressRecord()
     self.train_recorder = StatRecorder(timing=True)
     # -- add all scheduled values (lrate first, then model's, then data center's)
     self.lrate = ScheduledValue("lrate", conf.lrate)
     self.scheduled_values = OrderedDict([("_lrate", self.lrate)])
     for provider in (model, d_center):
         DictHelper.update_dict(self.scheduled_values, provider.get_scheduled_values())
     # -- lrate warmup is given directly in update steps here
     self.lrate_warmup_steps = conf.lrate_warmup_uidx
     if not (self.lrate_warmup_steps > 0):
         self.lrate_warmup_factor = 1.
     else:
         self.lrate_warmup_factor = 1. / (self.lrate_warmup_steps ** conf.lrate_decrease_alpha)
         zlog(
             f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
             f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
             func="plain")
     # == for ddp; only the main process handles the savings
     self.ddp_world_size, self.ddp_rank = BK.ddp_world_size(), BK.ddp_rank()
     self.is_main_process = BK.is_main_process()
     assert self.ddp_world_size <= 1 or conf.accu_batch == 1, "accu_batch and ddp may conflict!!"
예제 #5
0
def stat_docs(docs: List[Doc], stat: StatRecorder):
    """Record document-level statistics of *docs* into *stat*.

    Each document is counted under "doc". Its sentence count, bucketed by tens,
    goes under "doc_nsent_d10". Its sentences are then passed to ``stat_sents``.
    """
    for one_doc in docs:
        n_sents = len(one_doc.sents)
        stat.record_kv("doc", 1)
        stat.srecord_kv("doc_nsent_d10", n_sents // 10)
        stat_sents(one_doc.sents, stat)
예제 #6
0
def stat_sents(sents: List[Sent], stat: StatRecorder):
    """Record sentence-, frame- and argument-level statistics of *sents* into *stat*.

    * Per sentence: a count ("sent"), its token count ("tok"), the token count
      bucketed by tens ("sent_ntok_d10") and its number of frames ("sent_nframe").
    * Per frame (``sent.events``): a count, the trigger width and the trigger's
      UPOS tags (empty if the sentence has no ``seq_upos``). Also whether its
      span overlaps another frame of the sentence, its type, and the part of the
      type before the first "." (the prefix), plus its number of arguments.
    * Per argument: overall and per-rank counts (rank read from
      ``info["rank"]``, default 1) and the width capped at 30. Also whether it
      overlaps any other argument of the same frame, and any other argument of
      the same rank, plus its role. Finally, for each role in a frame, its
      repeat count ("arg_repeat", weighted by that count) and, for repeated
      roles, a "<count>*<role>" entry.
    """
    def _has_overlap(_f1, _f2):
        # spans [widx, wridx) intersect iff each one starts before the other ends
        s1, e1 = _f1.mention.widx, _f1.mention.wridx
        s2, e2 = _f2.mention.widx, _f2.mention.wridx
        return s1 < e2 and s2 < e1

    def _stat_args(args):
        # per-argument statistics of one frame
        stat.srecord_kv("frame_narg", len(args))
        role_counts = Counter()
        for alink in args:
            rank = alink.info.get("rank", 1)
            stat.record_kv("arg", 1)
            stat.record_kv(f"arg_R{rank}", 1)
            # arg target length (capped at 30)
            stat.srecord_kv("arg_wlen_m30", min(30, alink.mention.wlen))
            # overlap with any other arg, and with other args of the same rank
            others = [a2 for a2 in args if a2 is not alink]
            stat.record_kv("arg_overlapped", int(any(_has_overlap(alink, a2) for a2 in others)))
            stat.record_kv(f"arg_overlapped_R{rank}", int(any(
                _has_overlap(alink, a2) for a2 in others if a2.info.get("rank", 1) == rank)))
            stat.srecord_kv("arg_role", alink.role)
            role_counts[alink.role] += 1
        # repeated roles inside this frame
        for role, cnt in role_counts.items():
            stat.srecord_kv("arg_repeat", cnt, c=cnt)
            if cnt > 1:
                stat.srecord_kv("arg_repeatR", f"{cnt}*{role}")

    def _stat_frame(frame, siblings, pos_list):
        # per-frame statistics; `siblings` are all frames of the same sentence
        widx, wlen = frame.mention.get_span()
        stat.record_kv("frame", 1)
        stat.srecord_kv("frame_wlen", wlen)
        trigger_pos = [] if pos_list is None else pos_list[widx:widx + wlen]
        stat.srecord_kv("frame_trigger_pos", ",".join(trigger_pos))
        stat.record_kv("frame_overlapped", int(any(
            _has_overlap(frame, f2) for f2 in siblings if f2 is not frame)))
        frame_type = frame.type
        stat.srecord_kv("frame_type", frame_type)
        # prefix before the first "."; presumably for PropBank-style types like "lemma.01"
        stat.srecord_kv("frame_type0", frame_type.split(".")[0])
        _stat_args(frame.args)

    for sent in sents:
        n_tok = len(sent)
        stat.record_kv("sent", 1)
        stat.record_kv("tok", n_tok)
        stat.srecord_kv("sent_ntok_d10", n_tok // 10)
        stat.srecord_kv("sent_nframe", len(sent.events))
        pos_list = None if sent.seq_upos is None else sent.seq_upos.vals
        for frame in sent.events:
            _stat_frame(frame, sent.events, pos_list)
예제 #7
0
class TrainingRunner:
    """Run the training loop for a ``ZModel`` over a ``Streamer`` of batches.

    Keeps the training progress in ``self.tp`` (a ``TrainingProgressRecord``).
    Applies a learning-rate warmup/decay schedule, runs forward-backward (fb) and
    update steps, validates with ``dev_runners``, and saves/loads checkpoints.
    The "template methods" at the bottom are the intended override points.
    """

    def __init__(self, conf: TRConf, model: ZModel, train_stream: Streamer,
                 train_batch_f: Callable, train_discard_batch_f: Callable,
                 dev_runners: List[TestingRunner], **kwargs):
        """Store the training components and prepare the lrate schedule.

        Args:
            conf: training configuration (lrate, warmup, frequencies, limits, saving).
            model: the model being trained.
            train_stream: stream of training batches (each batch is iterable).
            train_batch_f: inst -> int; ``lambda x: 1`` if None. It is only stored
                here and not used by this class itself.
            train_discard_batch_f: List[inst] -> bool; ``run`` skips the batches
                for which it returns True. If None, nothing is discarded.
            dev_runners: runners whose ``run()`` results are used for validation.
            **kwargs: extra options, kept as ``self.kwargs``.
        """
        self.conf = conf
        self.kwargs = kwargs
        # --
        self.model = model
        self.train_stream = train_stream
        self.train_batch_f = train_batch_f if train_batch_f is not None else lambda x: 1  # inst -> int
        self.train_discard_batch_f = train_discard_batch_f if train_discard_batch_f is not None else lambda x: False  # List[inst] -> bool
        self.dev_runners = dev_runners
        # some records
        self.tp = TrainingProgressRecord()
        self.train_recorder = StatRecorder(timing=True)
        # -- special!!
        # store all insts for future usage (do not use if input is large)
        self.stored_all_insts = None
        if conf.store_all_insts:
            _all_insts = []
            for _one_insts in self.train_stream:
                _all_insts.extend(_one_insts)
            self.stored_all_insts = _all_insts
        # --
        # scheduled values
        self.lrate = ScheduledValue("lrate", conf.lrate)
        self.lrate_warmup_steps = self._determine_warmup_steps(
            conf.lrate_warmup_eidx, conf.lrate_warmup_uidx, train_stream)
        if self.lrate_warmup_steps > 0:
            # normalizer so that the decay phase starts exactly at lrate.value
            self.lrate_warmup_factor = 1. / (self.lrate_warmup_steps**conf.lrate_decrease_alpha)
            zlog(
                f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
                f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
                func="plain")
        else:
            self.lrate_warmup_factor = 1.
        # values to be scheduled as training goes: the lrate first, then the model's
        self.scheduled_values = OrderedDict([("_lrate", self.lrate)])
        DictHelper.update_dict(self.scheduled_values, model.get_scheduled_values())

    def _determine_warmup_steps(self, eidx: int, uidx: int, stream: Streamer):
        """Return the number of lrate-warmup update steps.

        If ``eidx > 0``, one pass is made over *stream* to count its batches, giving
        ``eidx * n_batches`` steps. The result is the max of that and *uidx*.
        """
        # set warmup steps if using eidx
        if eidx > 0:
            # count the batches lazily instead of materializing a whole epoch in memory
            _epoch_step_size = sum(1 for _ in stream)
            _warmup_steps = _epoch_step_size * eidx
            zlog(
                f"Determine warmup steps: {eidx} x {_epoch_step_size} = {_warmup_steps}",
                func="plain")
        else:
            _warmup_steps = 0
        _warmup_steps = max(_warmup_steps, uidx)
        return _warmup_steps

    def current_name(self):
        """Return the progress record's current suffix, used to name checkpoints and logs."""
        return self.tp.current_suffix()

    def add_scheduled_value(self, v: ScheduledValue, key_prefix=''):
        """Register *v* under ``key_prefix + v.name``; the key must not exist yet."""
        key = key_prefix + v.name
        assert key not in self.scheduled_values
        self.scheduled_values[key] = v

    def adjust_scheduled_values(self):
        """Let every changeable scheduled value adjust itself at the current checkpoint."""
        ss = self.current_name()
        for one_name, one_sv in self.scheduled_values.items():
            if one_sv.changeable:
                one_sv.adjust_at_ckp(ss, self.tp, extra_info=one_name)

    # -----
    # saving and loading related
    def save_progress(self, file: str):
        """Serialize the training progress record ``self.tp`` to the JSON *file*."""
        default_json_serializer.to_file(self.tp, file)
        zlog(f"Save training progress to {file}", func="io")

    def load_progress(self, file: str, forward_stream=False):
        """Restore ``self.tp`` from the JSON *file* and re-adjust the scheduled values.

        If *forward_stream* is set, the train stream is advanced by one batch per
        restored update step (new uidx minus old uidx), restarting it at every
        end-of-stream. Going backwards is refused with a warning.
        """
        old_uidx = self.tp.uidx
        d = default_json_serializer.from_file(file)
        self.tp.from_json(d)
        if forward_stream:
            if old_uidx > self.tp.uidx:
                zwarn(f"Cannot go to the past: {old_uidx} -> {self.tp.uidx}, skip this!")
            else:
                # NOTE(review): `run` consumes accu_batch batches per uidx (plus skipped or
                # discarded ones), so one batch per uidx only approximates the old position.
                _s = self.train_stream
                for _ in range(self.tp.uidx - old_uidx):
                    _, _eos = _s.next_and_check()
                    if _eos:  # restart and get one
                        _s.restart()
                        _s.next()
                zlog(f"Forward to the future: {old_uidx} -> {self.tp.uidx}!", func="io")
        zlog(f"Load training progress from {file}", func="io")
        self.adjust_scheduled_values()  # also adjust values!

    def save(self, prefix: str):
        """Save the progress to ``prefix + ".tp.json"`` and the model to ``prefix + ".m"``."""
        self.save_progress(prefix + ".tp.json")
        self.model.save(prefix + ".m")

    def load(self,
             prefix: str,
             load_progress=False,
             forward_stream=False,
             load_strict=None):
        """Load the model from ``prefix + ".m"``, and optionally the progress.

        A trailing ".m" on *prefix* is stripped first, so either the prefix or the
        model file path can be passed. *load_strict* is forwarded as ``strict``
        to ``model.load``.
        """
        if prefix.endswith(".m"):
            prefix = prefix[:-2]
        if load_progress:
            self.load_progress(prefix + ".tp.json", forward_stream)
        self.model.load(prefix + ".m", strict=load_strict)

    def get_train_stream(self):
        """Return the stream used for training; re-read every step so it may be swapped."""
        return self.train_stream

    # =====
    # run until the end of training
    def run(self):
        """Train until ``_finished()`` holds, then log why training stopped and the best point.

        Each kept batch gets an fb pass with loss factor ``1/accu_batch``. After
        every ``accu_batch`` fb passes, one update is made with the effective lrate:
        ``lrate.value * uidx / warmup_steps`` while ``uidx < warmup_steps``,
        otherwise ``lrate.value * warmup_factor * uidx ** lrate_decrease_alpha``.
        Batches are skipped when a random draw is below ``conf.skip_batch`` or
        when ``train_discard_batch_f`` returns True. Validation runs at each epoch
        end (if ``conf.valid_epoch``) and every ``conf.valid_ufreq`` updates. The
        end condition is checked at each epoch end and after each periodic validation.
        """
        conf = self.conf
        last_report_uidx, last_dev_uidx = 0, 0
        # --
        if conf.valid_first:  # valid before training
            self.validate()
        # --
        _lrate_warmup_factor, _lrate_warmup_steps = self.lrate_warmup_factor, self.lrate_warmup_steps
        _skip_batch = conf.skip_batch
        _gen0 = Random.get_generator("train")
        _gen = Random.stream(_gen0.random_sample)
        # --
        _accu_checker = 0
        _accu_batch = conf.accu_batch
        # --
        # start before loop
        self.adjust_scheduled_values()
        # loop
        act_lrate = None
        while True:  # loop over and over
            _train_stream = self.get_train_stream()  # we may change train_stream!!
            # --
            if _train_stream.is_inactive():  # check to avoid restart after load_progress
                _train_stream.restart()
            insts, _eos = _train_stream.next_and_check()
            if _eos:  # end of epoch
                zlog(
                    f"End of epoch at {self.tp.current_suffix(False)}: Current act_lrate is {act_lrate}.",
                    func="plain",
                    timed=True)
                if conf.valid_epoch:
                    last_dev_uidx = self.tp.uidx
                    self.validate()
                    # todo(+N): do we need to adjust sv at a finer grained?
                    self.adjust_scheduled_values()  # adjust after validation
                if self._finished():
                    break
                self.tp.update_eidx(1)
                continue
            # skip batch? (random draw from the "train" generator below conf.skip_batch)
            if _skip_batch > 0 and next(_gen) < _skip_batch:
                continue
            if self.train_discard_batch_f(insts):
                continue  # discard this batch due to some specific reasons (like noevt)
            # run fb (possibly split batch); the loss factor presumably averages accumulated grads
            self.fb_batch(insts, 1. / _accu_batch)
            self.tp.update_iidx(len(insts))
            # ==
            # only update for certain accu fb runs
            _accu_checker += 1
            if _accu_checker % _accu_batch == 0:
                self.tp.update_uidx(1)
                cur_uidx = self.tp.uidx
                # get the effective lrate and update
                act_lrate = float(self.lrate.value)  # start with the lrate.value
                if cur_uidx < _lrate_warmup_steps:  # linear increase
                    act_lrate *= (cur_uidx / _lrate_warmup_steps)
                else:  # decrease (equals lrate.value at cur_uidx == warmup steps when warmup is used)
                    act_lrate *= _lrate_warmup_factor * (cur_uidx**conf.lrate_decrease_alpha)
                self._run_update(act_lrate, 1.)
                # --
                # report on training process
                if conf.flag_verbose and (cur_uidx - last_report_uidx) >= conf.report_ufreq:
                    zlog(
                        f"Report at {self.tp.current_suffix(False)}: Current act_lrate is {act_lrate}.",
                        func="plain",
                        timed=True)
                    self._run_train_report()
                    last_report_uidx = cur_uidx
                # valid?
                if (cur_uidx - last_dev_uidx) >= conf.valid_ufreq:
                    last_dev_uidx = self.tp.uidx
                    self.validate()
                    # todo(+N): do we need to adjust sv at a finer grained?
                    self.adjust_scheduled_values()  # adjust after validation
                    if self._finished():
                        break
            # =====
        zlog(f"Finish training because of: {self._reach_ends()}", func="plain")
        zlog(
            f"zzzzzfinal: After training, the best point is: {self.tp.info_best()}.",
            func="report")

    # only finish when reaching any of the endings
    def _reach_ends(self):
        """Return (eidx >= max_eidx, uidx >= max_uidx, aidx >= anneal_times)."""
        conf = self.conf
        cur_eidx, cur_uidx, cur_aidx = self.tp.eidx, self.tp.uidx, self.tp.aidx
        return cur_eidx >= conf.max_eidx, cur_uidx >= conf.max_uidx, cur_aidx >= conf.anneal_times

    def _finished(self):
        """True once both minimums (min_eidx, min_uidx) are reached and any end condition holds."""
        conf = self.conf
        cur_eidx, cur_uidx, cur_aidx = self.tp.eidx, self.tp.uidx, self.tp.aidx
        return (cur_eidx >= conf.min_eidx) and (cur_uidx >= conf.min_uidx) and any(self._reach_ends())

    # training for one batch
    def fb_batch(self, insts: List, loss_factor: float):
        """Run one timed forward-backward pass and record its result in the train recorder."""
        with self.train_recorder.go():
            res = self._run_fb(insts, loss_factor)
            self.train_recorder.record(res)
        # --

    # do validation and record checkpoints
    def validate(self):
        """Report training stats, evaluate on dev, update the checkpoint record and save.

        Training stats are reported and reset first (only if an update has been made).
        Nothing more happens until both ``conf.valid_start_eidx`` and
        ``conf.valid_start_uidx`` are reached. Without dev runners, the training
        result stands in for the dev result. The "curr" model is always saved. The
        "best" model is saved when the record says so. An extra save is made every
        ``save_cfreq`` checkpoints from ``save_start_cidx`` on (using the cidx from
        before this update). On annealing with ``conf.anneal_restore``, the best
        model's weights (not its progress) are reloaded.
        """
        conf = self.conf
        # report & reset training stat
        if self.tp.uidx > 0:
            train_result = self._run_train_report()  # first report training stat
            self.train_recorder.reset()  # reset training stat
        else:  # for validate_first
            train_result = None
        # dev
        ss, cur_cidx = self.current_name(), self.tp.cidx
        zlog("", func="plain")  # empty line
        with Timer(info=f"Valid {ss}", print_date=True), self.model.ema_wrap_dev():
            # no validation if specified
            if (self.tp.eidx < conf.valid_start_eidx) or (self.tp.uidx < conf.valid_start_uidx):
                zlog("No validation since not the time yet!\n", func="plain")
                return
            # validate
            if len(self.dev_runners) == 0:  # simply use train if there are no dev
                zlog(
                    "Use training results for dev since there are no dev set provided!",
                    func="warn")
                dev_result = train_result
            else:
                dev_result = self._run_validate(self.dev_runners)
            # record
            cur_no_bad = (self.tp.eidx < conf.bad_start_eidx) or (self.tp.uidx < conf.bad_start_uidx)
            cur_record_best = (self.tp.cidx >= conf.record_best_cidx)
            if_overall_best, if_best, if_anneal = self.tp.update_checkpoint(
                train_result, dev_result, cur_no_bad, cur_record_best,
                conf.anneal_patience)
            # save curr & best
            self.save(conf.model_prefix + conf.model_suffix_curr)
            if if_overall_best:
                zlog("Curr is overall best " + str(self.tp.info_overall_best()), func="result")
            else:
                zlog("Curr not overall best, the overall best is " + str(self.tp.info_overall_best()),
                     func="result")
            if if_best:
                self.save(conf.model_prefix + conf.model_suffix_best)
                zlog("Curr is best: " + str(self.tp.info_best()), func="result")
            else:
                zlog("Curr not best, the best is " + str(self.tp.info_best()), func="result")
            if cur_cidx >= conf.save_start_cidx and cur_cidx % conf.save_cfreq == 0:
                self.save(conf.model_prefix + ss)  # special save
            if if_anneal and conf.anneal_restore:
                zlog("Restore from previous best model!!", func="plain")
                self.load(conf.model_prefix + conf.model_suffix_best, False)
        zlog("", func="plain")  # empty line

    # =====
    # template methods which can be overridden

    # return one train result
    def _run_fb(self, insts: List, loss_factor: float):
        """Compute the loss on *insts* via ``model.loss_on_batch`` and return its result."""
        res = self.model.loss_on_batch(insts, loss_factor=loss_factor)
        return res

    # return None
    def _run_update(self, lrate: float, grad_factor: float):
        """Apply one model update with the effective *lrate* and *grad_factor*."""
        self.model.update(lrate, grad_factor)

    # print and return train summary
    def _run_train_report(self) -> ResultRecord:
        """Log the train recorder's summary and return it wrapped in a ``ResultRecord``."""
        x = self.train_recorder.summary()
        zlog(f"Train-Info: {OtherHelper.printd_str(x, sep=' ')}")
        return ResultRecord(results=x, description=None)

    # run and return dev results
    def _run_validate(self, dev_runners: List[TestingRunner]) -> ResultRecord:
        """Run all dev runners and combine their records.

        Results and descriptions are keyed "res0", "res1", ... The overall score
        is the first runner's score (assumed to be the real dev set). With no
        runners, ``ResultRecord.get_nil()`` is returned.
        """
        if len(dev_runners) == 0:
            return ResultRecord.get_nil()
        all_records: List[ResultRecord] = [r.run() for r in dev_runners]
        # note: use devs[0] as the criterion, assuming that is the dev itself!!
        r = ResultRecord(
            results={f"res{ii}": v.results for ii, v in enumerate(all_records)},
            description={f"res{ii}": str(v) for ii, v in enumerate(all_records)},
            score=all_records[0].score)
        return r
예제 #8
0
class RunCenter:
    """Coordinate training, validation and testing over a TaskCenter and a DataCenter.

    Training batches come from ``d_center.yield_train_yielder()``. Evaluation is
    done per dataset and per task through ``t_center``. Under DDP, every rank
    iterates the same yielder positions but runs forward-backward only on its
    own share. Saving is restricted to the main process.
    """

    def __init__(self, conf: RunCenterConf, model, t_center: TaskCenter,
                 d_center: DataCenter):
        """Set up training state, the scheduled values, the lrate warmup and DDP info.

        Scheduled values start with the lrate (key "_lrate"), followed by the
        model's and then the data center's. Under DDP (world size > 1),
        ``conf.accu_batch`` must be 1.
        """
        self.conf = conf
        self.model = model
        self.t_center = t_center
        self.d_center = d_center
        # ==
        # for train
        self.tp = TrainingProgressRecord()
        self.train_recorder = StatRecorder(timing=True)
        # --
        self.lrate = ScheduledValue("lrate", conf.lrate)
        self.scheduled_values = OrderedDict([("_lrate", self.lrate)])  # add all scheduled values
        DictHelper.update_dict(self.scheduled_values, model.get_scheduled_values())
        DictHelper.update_dict(self.scheduled_values, d_center.get_scheduled_values())
        # --
        self.lrate_warmup_steps = conf.lrate_warmup_uidx
        if self.lrate_warmup_steps > 0:
            # normalizer so that the decay phase starts exactly at lrate.value
            self.lrate_warmup_factor = 1. / (self.lrate_warmup_steps**conf.lrate_decrease_alpha)
            zlog(
                f"For lrate-warmup, first {self.lrate_warmup_steps} steps up to {self.lrate.value}, "
                f"then decrease with lrate*{self.lrate_warmup_factor}*step^{conf.lrate_decrease_alpha}",
                func="plain")
        else:
            self.lrate_warmup_factor = 1.
        # ==
        # for ddp
        self.ddp_world_size = BK.ddp_world_size()
        self.ddp_rank = BK.ddp_rank()
        self.is_main_process = BK.is_main_process()  # handle the savings!
        if self.ddp_world_size > 1:
            assert conf.accu_batch == 1, "accu_batch and ddp may conflict!!"
        # --

    # --
    # helpers

    def current_name(self):
        """Return the progress record's current suffix, used to name checkpoints and logs."""
        return self.tp.current_suffix()

    def adjust_scheduled_values(self):
        """Let every changeable scheduled value adjust itself at the current checkpoint."""
        ss = self.current_name()
        for one_name, one_sv in self.scheduled_values.items():
            if one_sv.changeable:
                one_sv.adjust_at_ckp(ss, self.tp, extra_info=one_name)

    # saving and loading related
    def save_progress(self, file: str):
        """Serialize the training progress record ``self.tp`` to the JSON *file*."""
        default_json_serializer.to_file(self.tp, file)
        zlog(f"Save training progress to {file}", func="io")

    def load_progress(self, file: str, forward_stream=False):
        """Restore ``self.tp`` from the JSON *file* and re-adjust the scheduled values.

        *forward_stream* is not supported here and must be False.
        """
        d = default_json_serializer.from_file(file)
        self.tp.from_json(d)
        assert not forward_stream, "Error: 'forward_stream' not supported in this mode!!"
        zlog(f"Load training progress from {file}", func="io")
        self.adjust_scheduled_values()  # also adjust values!

    def save(self, prefix: str):
        """Save progress (``prefix.tp.json``) and model (``prefix.m``) on the main process only."""
        if self.is_main_process:  # note: do saving only with main process
            self.save_progress(prefix + ".tp.json")
            self.model.save(prefix + ".m")

    def load(self, prefix: str, load_progress=False, forward_stream=False):
        """Load the model from ``prefix + ".m"``, and optionally the progress.

        A trailing ".m" on *prefix* is stripped first, so either the prefix or the
        model file path can be passed.
        """
        if prefix.endswith(".m"):
            prefix = prefix[:-2]
        if load_progress:
            self.load_progress(prefix + ".tp.json", forward_stream)
        self.model.load(prefix + ".m")

    # go
    # training for one batch
    def fb_batch(self, ibatch, loss_factor: float):
        """Run one timed forward-backward pass and record its result in the train recorder."""
        with self.train_recorder.go():
            res = self.model.loss_on_batch(ibatch, loss_factor)
            self.train_recorder.record(res)
        # --

    def train_finished(self):
        """True once the number of updates reaches ``conf.max_uidx``."""
        conf = self.conf
        return self.tp.uidx >= conf.max_uidx

    # print and return train summary
    def run_train_report(self) -> ResultRecord:
        """Log the train summary and the uidx/iidx counters; return the summary as a record."""
        x = self.train_recorder.summary()
        zlog(f"Train-Info: {OtherHelper.printd_str(x, sep=' ')}")
        # also report uidx_counter/iidx_counter
        zlog(f"UidxCounter: {self.tp.uidx_counter}")
        zlog(f"IidxCounter: {self.tp.iidx_counter}")
        return ResultRecord(results=x, description=None)

    # run test on one dataset
    def run_test_dataset(self, dataset: ZDataset):
        """Predict on every batch of *dataset*, write the outputs, then evaluate.

        Only the main process writes the instances. Each task in
        ``self.t_center`` that the dataset lists is evaluated on (gold,
        predicted) instances. Non-None results are aggregated with the task's
        ``eval_weight``, and the aggregated result is returned.
        """
        test_recorder = StatRecorder(timing=True)
        for ibatch in dataset.yield_batches(loop=False):
            with test_recorder.go():
                one_res = self.model.predict_on_batch(ibatch)
                test_recorder.record(one_res)
        # --
        # write output
        if self.is_main_process:  # note: do saving only with main process
            dataset.write_insts(None)  # use dataset's conf
        # --
        # eval
        x = test_recorder.summary()
        zlog(f"Test-Info: {OtherHelper.printd_str(x, sep=' ')}")
        aggr = ResultAggregator()
        for task in self.t_center.tasks.values():
            if task.name not in dataset.tasks:
                continue
            tn_res: ResultRecord = task.eval_insts(dataset.gold_insts, dataset.insts, quite=False)
            if tn_res is None:
                continue
            aggr.add(task.name, tn_res, task.conf.eval_weight)
        ret = aggr.get_res()
        return ret

    # --
    # main runs

    def do_train(self):
        """Train until ``train_finished()``, validating every ``conf.valid_ufreq`` updates.

        Each update draws one yielder from ``d_center.yield_train_yielder()`` and
        pulls ``ddp_world_size * accu_batch`` batches from it. This rank runs fb
        (loss factor ``1/accu_batch``) only on the batches whose outer index
        equals ``ddp_rank``. The update is then tagged with the dataset name of
        the last batch. The effective lrate is ``lrate.value * uidx / warmup_steps``
        during warmup, and ``lrate.value * warmup_factor * uidx ** lrate_decrease_alpha``
        afterwards.
        """
        conf, d_center = self.conf, self.d_center
        last_dev_uidx = 0
        # --
        if conf.valid_first:  # valid before training
            self.do_dev()
        # --
        _lrate_warmup_factor, _lrate_warmup_steps = self.lrate_warmup_factor, self.lrate_warmup_steps
        _accu_batch = conf.accu_batch
        self.adjust_scheduled_values()  # once before train
        train_yielder = d_center.yield_train_yielder()
        while not self.train_finished():  # loop over and over
            # sample batch
            cur_yielder = next(train_yielder)
            # fb for accu_batch/ddp steps (use the same dataset/yielder)
            cur_dname = None
            for _i0 in range(self.ddp_world_size):
                for _i1 in range(_accu_batch):
                    # every rank pulls every batch so that ranks stay at the same yielder position
                    cur_ibatch = next(cur_yielder)
                    cur_dname = cur_ibatch.dataset.name
                    if _i0 == self.ddp_rank:  # only for current rank!!
                        self.fb_batch(cur_ibatch, 1. / _accu_batch)
                        self.tp.update_iidx(len(cur_ibatch), cur_dname)
            # update
            self.tp.update_uidx(1, cur_dname)
            cur_uidx = self.tp.uidx
            # get the effective lrate and update
            act_lrate = float(self.lrate.value)  # start with the lrate.value
            if cur_uidx < _lrate_warmup_steps:  # linear increase
                act_lrate *= (cur_uidx / _lrate_warmup_steps)
            else:  # decrease
                act_lrate *= _lrate_warmup_factor * (cur_uidx**conf.lrate_decrease_alpha)
            with self.train_recorder.go('update'):  # also record this!
                self.model.update(act_lrate, 1.)
            # valid?
            if (cur_uidx - last_dev_uidx) >= conf.valid_ufreq:
                last_dev_uidx = cur_uidx
                self.do_dev()
                # todo(+N): do we need to adjust sv at a finer grained?
                self.adjust_scheduled_values()  # adjust after validation
        # --
        zlog(
            f"zzzzzfinal: After training, the best point is: {self.tp.info_best()}.",
            func="report")

    def do_dev(self):
        """Report training stats, evaluate on the dev split, update checkpoints and save.

        Training stats are reported and reset if an update has been made.
        Otherwise a nil record stands in for them. Nothing more happens before
        ``conf.valid_start_uidx``. Without dev datasets, the training result
        stands in for the dev result. The "curr" model is always saved and the
        "best" model when the record says so. An extra save is made every
        ``save_special_cfreq`` checkpoints from ``save_special_start_cidx`` on
        (using the cidx from before this update).
        """
        conf = self.conf
        # report & reset training stat
        if self.tp.uidx > 0:
            train_result = self.run_train_report()  # first report training stat
            self.train_recorder.reset()  # reset training stat
        else:  # for validate_first
            train_result = ResultRecord.get_nil()
        # dev
        ss, cur_cidx = self.current_name(), self.tp.cidx
        zlog("", func="plain")  # empty line
        with Timer(info=f"Valid {ss}", print_date=True):
            # no validation if specified
            if self.tp.uidx < conf.valid_start_uidx:
                zlog("No validation since not the time yet!\n", func="plain")
                return
            # validate
            if len(self.d_center.get_datasets(wset="dev")) == 0:  # simply use train if there are no dev
                zlog(
                    "Use training results for dev since there are no dev set provided!",
                    func="warn")
                dev_result = train_result
            else:
                dev_result = self.do_test("dev")
            # record
            cur_record_best = (self.tp.cidx >= conf.record_best_start_cidx)
            if_overall_best, if_best, if_anneal = self.tp.update_checkpoint(
                train_result, dev_result, record_best=cur_record_best)
            # save curr & best
            self.save(conf.model_save_prefix + conf.model_save_suffix_curr)
            if if_overall_best:
                zlog("Curr is overall best " + str(self.tp.info_overall_best()), func="result")
            else:
                zlog("Curr not overall best, the overall best is " + str(self.tp.info_overall_best()),
                     func="result")
            if if_best:
                self.save(conf.model_save_prefix + conf.model_save_suffix_best)
                zlog("Curr is best: " + str(self.tp.info_best()), func="result")
            else:
                zlog("Curr not best, the best is " + str(self.tp.info_best()), func="result")
            if cur_cidx >= conf.save_special_start_cidx and cur_cidx % conf.save_special_cfreq == 0:
                self.save(conf.model_save_prefix + ss)  # special save
        # --
        zlog("", func="plain")  # empty line

    def do_test(self, wset="test"):
        """Evaluate every dataset of split *wset* and return the aggregated result.

        The datasets are re-prepared with ``t_center.prepare_datasets`` first.
        Each one is then run with ``run_test_dataset``, and the results are
        combined with weight ``dataset.conf.group_eval_weight``.
        """
        t_center, d_center = self.t_center, self.d_center
        # --
        to_test_datasets = d_center.get_datasets(wset=wset)
        t_center.prepare_datasets(to_test_datasets)  # re-prepare!!
        aggr = ResultAggregator()
        for one_ii, one_dataset in enumerate(to_test_datasets):
            with Timer(
                    info=f"Test({one_ii+1}/{len(to_test_datasets)}): {one_dataset}",
                    print_date=True):
                one_res = self.run_test_dataset(one_dataset)
                aggr.add(one_dataset.name, one_res, one_dataset.conf.group_eval_weight)
        ret = aggr.get_res()
        return ret
예제 #9
0
파일: run_test.py 프로젝트: zzsfornlp/zmsp
 def __init__(self, model: ZModel, test_stream: Streamer):
     self.model = model
     self.test_recorder = StatRecorder(timing=True)
     self.test_stream = test_stream