Example #1
    def train(self,
              dataflow: DataLoader,
              *,
              num_epochs: int = 9999999,
              callbacks: Optional[List[Callback]] = None) -> None:
        self.dataflow = dataflow
        self.steps_per_epoch = len(self.dataflow)
        self.num_epochs = num_epochs

        if callbacks is None:
            callbacks = []
        self.callbacks = Callbacks(callbacks)
        self.summary = Summary()

        try:
            self.callbacks.set_trainer(self)
            self.summary.set_trainer(self)

            self.epoch_num = 0
            self.global_step = 0

            train_time = time.perf_counter()
            self.before_train()

            while self.epoch_num < self.num_epochs:
                self.epoch_num += 1
                self.local_step = 0

                logger.info('Epoch {}/{} started.'.format(
                    self.epoch_num, self.num_epochs))
                epoch_time = time.perf_counter()
                self.before_epoch()

                for feed_dict in self.dataflow:
                    self.local_step += 1
                    self.global_step += 1

                    self.before_step(feed_dict)
                    output_dict = self.run_step(feed_dict)
                    self.after_step(output_dict)

                    self.trigger_step()

                self.after_epoch()
                logger.info('Training finished in {}.'.format(
                    humanize.naturaldelta(time.perf_counter() - epoch_time)))

                # Fire the epoch-level callback triggers (for example, periodic
                # inference callbacks like those in Examples #2 and #4) before
                # the epoch is considered finished.
                self.trigger_epoch()
                logger.info('Epoch finished in {}.'.format(
                    humanize.naturaldelta(time.perf_counter() - epoch_time)))

            logger.success('{} epochs of training finished in {}.'.format(
                self.num_epochs,
                humanize.naturaldelta(time.perf_counter() - train_time)))
        except StopTraining as e:
            logger.info('Training was stopped by {}.'.format(str(e)))
        finally:
            self.after_train()
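
Stripped of the callback and summary plumbing, Example #1 is an epoch/step loop with hook points and wall-clock logging via `humanize.naturaldelta`. The sketch below shows that bare loop structure in plain PyTorch; the toy model, optimizer, and random data are illustrative stand-ins, not part of the original example.

    import time

    import humanize
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # Illustrative stand-ins for the model and dataflow managed by the trainer.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    dataflow = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)),
                          batch_size=8)

    num_epochs = 3
    global_step = 0
    train_time = time.perf_counter()

    for epoch_num in range(1, num_epochs + 1):          # "epoch" loop
        epoch_time = time.perf_counter()
        for inputs, targets in dataflow:                # one "step" per batch
            global_step += 1
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
        print('Epoch {}/{} finished in {}.'.format(
            epoch_num, num_epochs,
            humanize.naturaldelta(time.perf_counter() - epoch_time)))

    print('{} epochs of training finished in {}.'.format(
        num_epochs, humanize.naturaldelta(time.perf_counter() - train_time)))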
Example #2
    def _trigger(self) -> None:
        start_time = time.perf_counter()
        self.callbacks.before_epoch()

        with torch.no_grad():
            # Save the original noise settings (read from the first node) so
            # they can be restored after this noisy inference pass.
            noise_model_tq = self.trainer.model.nodes[0].noise_model_tq
            orig_is_add_noise = noise_model_tq.is_add_noise
            orig_noise_total_prob = noise_model_tq.noise_total_prob
            orig_mode = noise_model_tq.mode

            # Temporarily force every node into noisy 'train'-mode settings.
            for node in self.trainer.model.nodes:
                node.noise_model_tq.noise_total_prob = self.noise_total_prob
                node.noise_model_tq.is_add_noise = True
                node.noise_model_tq.mode = 'train'

            for feed_dict in tqdm.tqdm(self.dataflow, ncols=0):
                self.callbacks.before_step(feed_dict)
                output_dict = self.trainer.run_step(feed_dict)
                self.callbacks.after_step(output_dict)

            for node in self.trainer.model.nodes:
                node.noise_model_tq.is_add_noise = orig_is_add_noise
                node.noise_model_tq.noise_total_prob = orig_noise_total_prob
                node.noise_model_tq.mode = orig_mode

        self.callbacks.after_epoch()
        logger.info('Inference finished in {}.'.format(
            humanize.naturaldelta(time.perf_counter() - start_time)))
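
Example #2 saves the noise settings of `noise_model_tq`, forces every node into noisy `'train'`-mode behaviour for one evaluation pass under `torch.no_grad()`, and then restores the originals. The same save/override/restore pattern can be packaged as a context manager; the sketch below is a generic illustration with a hypothetical `temporary_attrs` helper and dummy `NoiseConfig` object, not torchquantum API.

    from contextlib import contextmanager

    @contextmanager
    def temporary_attrs(obj, **overrides):
        # Save current values, apply the overrides, and always restore the
        # originals afterwards (the same shape as the save/restore in Example #2).
        originals = {name: getattr(obj, name) for name in overrides}
        try:
            for name, value in overrides.items():
                setattr(obj, name, value)
            yield obj
        finally:
            for name, value in originals.items():
                setattr(obj, name, value)

    class NoiseConfig:
        is_add_noise = False
        noise_total_prob = 0.0
        mode = 'eval'

    cfg = NoiseConfig()
    with temporary_attrs(cfg, is_add_noise=True, noise_total_prob=0.5, mode='train'):
        print(cfg.is_add_noise, cfg.noise_total_prob, cfg.mode)  # True 0.5 train
    print(cfg.is_add_noise, cfg.noise_total_prob, cfg.mode)      # False 0.0 eval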
Example #3
    def _trigger_epoch(self) -> None:
        if self.trainer.epoch_num < self.trainer.num_epochs:
            self.times.append(time.perf_counter() - self.last_time)
            self.last_time = time.perf_counter()

            estimated_time = (self.trainer.num_epochs -
                              self.trainer.epoch_num) * np.mean(self.times)
            logger.info('Estimated time left: {}.'.format(
                humanize.naturaldelta(estimated_time)))
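
Example #3 estimates the remaining training time as the mean duration of the epochs finished so far multiplied by the number of epochs left. The arithmetic in isolation, with an illustrative helper and made-up timings:

    import humanize
    import numpy as np

    def estimated_time_left(epoch_times, epoch_num, num_epochs):
        # Mean duration of the epochs seen so far, scaled by the number of
        # epochs still to run.
        return (num_epochs - epoch_num) * np.mean(epoch_times)

    times = [62.0, 58.5, 60.1]   # seconds per finished epoch (made-up values)
    eta = estimated_time_left(times, epoch_num=3, num_epochs=10)
    print('Estimated time left: {}.'.format(humanize.naturaldelta(eta)))
    # -> Estimated time left: 7 minutes.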
Example #4
    def _trigger(self) -> None:
        start_time = time.perf_counter()
        self.callbacks.before_epoch()

        with torch.no_grad():
            for feed_dict in tqdm.tqdm(self.dataflow, ncols=0):
                self.callbacks.before_step(feed_dict)
                output_dict = self.trainer.run_step(feed_dict)
                self.callbacks.after_step(output_dict)

        self.callbacks.after_epoch()
        logger.info('Inference finished in {}.'.format(
            humanize.naturaldelta(time.perf_counter() - start_time)))
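
Example #4 is the noise-free counterpart of Example #2: a single pass over the dataflow under `torch.no_grad()`, with the step callbacks accumulating whatever metrics they track. A self-contained sketch of the same pattern in plain PyTorch (the toy model, random data, and accuracy metric are illustrative):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Linear(4, 2)
    dataflow = DataLoader(
        TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,))),
        batch_size=4)

    model.eval()
    correct = total = 0
    with torch.no_grad():      # no autograd bookkeeping during inference
        for inputs, targets in dataflow:
            outputs = model(inputs)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.numel()
    print('accuracy: {:.2f}'.format(correct / total))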
Example #5
    def train(self,
              dataflow: DataLoader,
              *,
              num_epochs: int = 9999999,
              eval_interval: Optional[int] = None,
              splits: Optional[List[str]] = None,
              callbacks: Optional[List[Callback]] = None) -> None:
        self.dataflow = dataflow
        self.steps_per_epoch = len(self.dataflow)
        self.num_epochs = num_epochs

        if callbacks is None:
            callbacks = []
        self.callbacks = Callbacks(callbacks)
        # Keep one Summary per data split, falling back to a single default
        # split when none are given.
        if splits is None:
            self.summary = {"0": Summary()}
        else:
            self.summary = {s: Summary(split=s) for s in splits}

        try:
            self.callbacks.set_trainer(self)
            for s in self.summary.values():
                s.set_trainer(self)

            self.epoch_num = 0
            self.global_step = 0

            train_time = time.perf_counter()
            self.before_train()

            while self.epoch_num < self.num_epochs:
                self.epoch_num += 1
                self.local_step = 0

                logger.info("Epoch {}/{} started.".format(
                    self.epoch_num, self.num_epochs))
                epoch_time = time.perf_counter()
                self.before_epoch()

                for feed_dict in self.dataflow:
                    self.local_step += 1
                    self.global_step += 1

                    self.before_step(feed_dict)
                    output_dict = self.run_step(feed_dict)
                    self.after_step(output_dict)

                    self.trigger_step()

                self.after_epoch()
                logger.info("Training finished in {}.".format(
                    humanize.naturaldelta(time.perf_counter() - epoch_time)))

                # Fire epoch-level callback triggers only every `eval_interval`
                # epochs; with no interval given, fire them after every epoch.
                if eval_interval is not None:
                    if self.epoch_num % eval_interval == 0:
                        self.trigger_epoch()
                else:
                    self.trigger_epoch()
                logger.info("Epoch finished in {}.".format(
                    humanize.naturaldelta(time.perf_counter() - epoch_time)))

            logger.success("{} epochs of training finished in {}.".format(
                self.num_epochs,
                humanize.naturaldelta(time.perf_counter() - train_time),
            ))
        except StopTraining as e:
            logger.info("Training was stopped by {}.".format(str(e)))
        finally:
            self.after_train()
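
Compared with Example #1, Example #5 adds one `Summary` per data split and an `eval_interval` that gates how often `trigger_epoch()` runs. The gating condition on its own, as a hypothetical helper with a couple of sanity checks:

    from typing import Optional

    def should_trigger(epoch_num: int, eval_interval: Optional[int]) -> bool:
        # No interval: trigger every epoch; otherwise only on multiples of it.
        return eval_interval is None or epoch_num % eval_interval == 0

    assert [e for e in range(1, 7) if should_trigger(e, None)] == [1, 2, 3, 4, 5, 6]
    assert [e for e in range(1, 7) if should_trigger(e, 3)] == [3, 6]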