Ejemplo n.º 1
0
 def _save_train_stats(self, suffix=None):
     """
     Dump the current training statistics as JSON next to the model file.

     The output path is ``opt['model_file']``, followed by ``suffix`` when
     one is given, followed by ``.trainstats``.

     :param suffix: optional string inserted between the model file name
         and the ``.trainstats`` extension.
     """
     stats_path = self.opt['model_file'] + (suffix if suffix else '') + '.trainstats'
     with open(stats_path, 'w') as stats_file:
         # Epochs done before this run plus the epochs of this run, where
         # the world's own count is scaled by the number of workers.
         stats = {
             'train_time': self.train_time.time(),
             'total_epochs': (
                 self._preempted_epochs
                 + num_workers() * self.world.get_total_epochs()
             ),
             'impatience': self.impatience,
             'valid_reports': self.valid_reports,
         }
         json.dump(stats, stats_file)
Ejemplo n.º 2
0
    def train(self):
        """
        Run the training loop, then evaluate on the valid and test data.

        Each iteration performs one ``world.parley()``, refreshes the epoch
        and example counters, and then checks, in order: the epoch limit, the
        training-time limit, the logging timer, the validation triggers and
        the checkpoint timer.

        Once the loop ends, the agent is saved (primary worker, if nothing was
        saved yet) or re-created from ``opt`` when a model file is set, and
        the final evaluations are run.

        :return: tuple ``(v_report, t_report)`` with the ``run_eval`` results
            for the valid and the test data.
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated."
            )
        opt = self.opt
        world = self.world
        with world:
            while True:
                # one example / batch of examples
                world.parley()
                self.parleys += 1

                # total training progress so far, expressed in epochs and
                # in examples
                self._total_epochs = (
                    self._preempted_epochs
                    + num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                rounded_exs = np.round(self._total_epochs * exs_per_epoch)
                self._total_exs = int(rounded_exs)

                # every worker uses the primary worker's timings
                local_timings = (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                )
                train_time, log_time, validate_time = sync_object(local_timings)

                # stopping criteria: epoch budget first, then time budget
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    epochs_msg = '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time)
                    print(epochs_msg)
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break

                if log_time > self.log_every_n_secs:
                    self.log()

                # validation is due either by elapsed time or by the number
                # of epochs since the last validation
                validation_due = validate_time > self.val_every_n_secs
                if not validation_due:
                    validation_due = (
                        self._total_epochs - self.last_valid_epoch
                        >= self.val_every_n_epochs
                    )
                if validation_due:
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break

                # periodic checkpoint, written by the primary worker only
                if self.save_time.time() > self.save_every_n_secs:
                    if opt.get('model_file') and is_primary_worker():
                        print("[ saving model checkpoint: {}.checkpoint".format(
                            opt['model_file']))
                        self.save_model('.checkpoint')
                        self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)

        # shut down whichever evaluation worlds were actually loaded
        for eval_world in (valid_world, test_world):
            if eval_world:
                eval_world.shutdown()

        return v_report, t_report
    def train(self):
        """
        Run the training loop, then evaluate on the valid and test data.

        Each iteration performs one ``world.parley()``, refreshes the epoch
        and example counters, and then checks, in order: the epoch limit, the
        training-time limit, the logging timer, the validation triggers and
        the checkpoint timer. Every validation is followed by an evaluation
        on the test data.

        Once the loop ends, the agent is saved (primary worker, if nothing was
        saved yet) or re-created from ``opt`` when a model file is set, and
        the final valid/test evaluations are run. If a model file is
        configured and ``<model_file>.checkpoint`` exists on disk, the
        evaluations are repeated with an agent created from that checkpoint.

        :return: tuple ``(v_report, t_report)`` with the ``run_eval`` results
            of the first final evaluation on the valid and the test data.
        """
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated."
            )
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs()
                )
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time()
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs or
                    self._total_epochs - self.last_valid_epoch >=
                    self.val_every_n_epochs
                ):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs

                    # --------------- change by hengyicai -------------------------
                    # run evaluation on the test data as well, with example
                    # display and periodic reporting switched off
                    test_opt = copy.deepcopy(self.opt)
                    test_opt['display_examples'] = False
                    test_opt['report_freq'] = 0
                    if self.test_world is None:
                        # we need to load the world now; it is kept on self
                        # and reused by later validations
                        self.test_world = _maybe_load_eval_world(
                            self.agent, test_opt, 'test')
                    run_eval(self.test_world, test_opt, 'test', -1, write_log=True)
                    # --------------- change by hengyicai -------------------------
                    if stop_training:
                        break
                if (
                    self.save_time.time() > self.save_every_n_secs and
                    opt.get('model_file') and
                    is_primary_worker()
                ):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']
                    ))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        # optionally cap the final evaluation at validation_max_exs examples
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = run_eval(valid_world, opt, 'valid', max_exs, write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', max_exs, write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        # --------------- change by hengyicai -------------------------
        # 'model_file' is optional (see the checks above), so only build the
        # checkpoint path when it is set; concatenating None with a string
        # would raise TypeError.
        model_file = opt.get('model_file')
        last_model = model_file + '.checkpoint' if model_file else None
        if last_model and os.path.isfile(last_model):
            print('[ Conducting evaluations on valid and test data using the last model. ]')
            last_model_opt = copy.deepcopy(opt)
            last_model_opt['model_file'] = last_model
            last_agent = create_agent(last_model_opt)
            valid_world = _maybe_load_eval_world(last_agent, last_model_opt, 'valid')
            max_exs = (
                last_model_opt['validation_max_exs']
                if last_model_opt.get('short_final_eval') else -1
            )
            run_eval(valid_world, last_model_opt, 'valid', max_exs, write_log=True)
            test_world = _maybe_load_eval_world(last_agent, last_model_opt, 'test')
            run_eval(test_world, last_model_opt, 'test', max_exs, write_log=True)
            if valid_world:
                valid_world.shutdown()
            if test_world:
                test_world.shutdown()
        # --------------- change by hengyicai -------------------------

        print_announcements(opt)

        return v_report, t_report