Example #1
    def __call__(self, start_epoch, experiment, verbose=0):
        ''' Execute training '''
        with ExitStack() as stack:
            stack.enter_context(chunked_scattering())
            stack.enter_context(experiment.train())

            if start_epoch > 0 or experiment.curr_step > 0:
                # TODO: Hacky approach to decide if the metric store should be loaded. Revisit later
                self.metric_store = self.metric_store.load()

            epoch = start_epoch
            experiment.log_current_epoch(epoch)
            while not self.is_done(experiment, epoch):
                experiment.log_current_epoch(epoch)
                self.train_epoch(epoch, experiment, verbose)
                experiment.log_epoch_end(epoch)
                epoch += 1

            if self.stopped_early:
                print('Stopping early!')
            else:
                new_best = False
                if self.config.early_stopping:
                    new_best = self.evaluate(experiment, epoch, verbose)

                self.checkpoint(epoch, experiment.curr_step, new_best)
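
A minimal, self-contained sketch of the ExitStack pattern used above, in which a variable number of context managers are entered inside one with block and exited in reverse order. The dummy_context helper below is a placeholder, not the project's actual chunked_scattering() or experiment.train() implementation:

from contextlib import ExitStack, contextmanager

@contextmanager
def dummy_context(name):
    # Placeholder standing in for chunked_scattering() or experiment.train();
    # the real implementations perform project-specific setup and teardown.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

with ExitStack() as stack:
    # Contexts registered on the stack are exited in reverse order when the
    # with block ends, even if an exception is raised part-way through.
    stack.enter_context(dummy_context("chunked_scattering"))
    stack.enter_context(dummy_context("experiment.train"))
    print("the epoch loop would run here")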
Example #2
    def __call__(self, start_epoch, experiment, verbose=0):
        ''' Execute training '''
        with ExitStack() as stack:
            stack.enter_context(chunked_scattering())
            stack.enter_context(experiment.train())

            if start_epoch > 0 or experiment.curr_step > 0:
                self.metric_store = self.metric_store.load()

            epoch = start_epoch
            experiment.log_current_epoch(epoch)

            stats_filename = self.config.stats_filename or 'train_stats.pickle'
            stats_path = os.path.join(self.config.stats_directory,
                                      stats_filename)
            stats_file = stack.enter_context(open(stats_path, 'wb'))

            while not self.is_done(experiment, epoch):
                experiment.log_current_epoch(epoch)
                self.train_epoch(epoch, experiment, verbose)
                experiment.log_epoch_end(epoch)
                epoch += 1

            self.save_stats(stats_file)

            if self.stopped_early:
                print('Stopping early!')
            else:
                new_best = False
                if self.config.early_stopping:
                    new_best = self.evaluate(experiment, epoch, verbose)

                self.checkpoint(epoch, experiment.curr_step, new_best)
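
Example #2 additionally opens a binary stats file on the same ExitStack and hands the file object to self.save_stats after the epoch loop. That method is not shown; a plausible sketch, assuming it simply pickles a dictionary of accumulated training statistics to the already-open handle (save_stats and the stats dictionary below are hypothetical):

import pickle

def save_stats(stats_file, stats):
    # Hypothetical helper: serialize accumulated training statistics to the
    # binary file handle that was opened on the ExitStack.
    pickle.dump(stats, stats_file)

with open('train_stats.pickle', 'wb') as stats_file:
    save_stats(stats_file, {'epochs_completed': 0, 'loss_history': []})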
Example #3
    def __call__(self) -> float:
        """
        Run the evaluation!
        """
        dataloader = get_dataloader(self.args.data,
                                    self.dataset,
                                    num_devices=len(self.model.device_ids))

        def get_description():
            return f"Eval {self.metric_store}"

        batch_iterator = tqdm(
            dataloader,
            unit="batch",
            initial=1,
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            # pylint:enable=no-member

            for batch in batch_iterator:
                try:
                    self.eval_step(batch)
                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        batch_iterator.close()
                        raise

                batch_iterator.set_description_str(get_description())

            batch_iterator.close()

        return self.metric_store["nll"].average
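
Both this evaluation loop and the training loop in Example #4 recover from CUDA out-of-memory failures the same way: the RuntimeError message is inspected, OOM events are counted in the metric store, and any other error closes the progress bar and is re-raised. A stripped-down sketch of just that pattern, with a Counter and a simulated step standing in for the project's metric store and eval_step:

import logging
from collections import Counter

oom_counter = Counter()  # stand-in for self.metric_store["oom"]

def guarded_step(step_fn, batch):
    # Run one step; treat CUDA OOM as a recoverable, counted event and
    # propagate every other RuntimeError unchanged.
    try:
        step_fn(batch)
    except RuntimeError as rte:
        if "out of memory" in str(rte):
            oom_counter["oom"] += 1
            logging.warning(str(rte))
        else:
            raise

def simulated_step(batch):
    # Stand-in for eval_step(batch) that always hits an OOM.
    raise RuntimeError("CUDA out of memory (simulated)")

guarded_step(simulated_step, batch=None)
print(oom_counter["oom"])  # -> 1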
Example #4
    def __call__(self):
        """
        Run the training!
        """
        # Must be called first
        self.try_init_amp()

        model = self.modules["model"]
        optimizer = self.modules["optimizer"]
        scheduler = self.modules["scheduler"]

        if self.args.optim.use_gradient_checkpointing:
            model.enable_gradient_checkpointing()

        model = nn.DataParallel(model)
        dataloader = get_dataloader(
            self.args.data,
            self.dataset,
            num_devices=len(model.device_ids),
            shuffle=True,
        )

        def get_description():
            return f"Train {self.metric_store}"

        max_steps = self.args.optim.max_steps
        accumulation_steps = self.args.optim.gradient_accumulation_steps
        progress = tqdm(
            unit="step",
            initial=self.step,
            dynamic_ncols=True,
            desc=get_description(),
            total=max_steps,
            file=sys.stdout,  # needed to make tqdm_wrap_stdout work
        )

        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            stack.enter_context(self.experiment.train())
            # pylint:enable=no-member

            if self.args.optim.early_stopping:
                # If using early stopping, must evaluate regularly to determine
                # if training should stop early, so setup an Evaluator
                eval_args = copy.deepcopy(self.args)
                eval_args.data.batch_size = self.args.optim.eval_batch_size

                evaluator = Evaluator(eval_args)
                evaluator.model = model
                evaluator.load_dataset("validation")
                evaluator.initialize_experiment(experiment=self.experiment)

                # Make sure we are tracking validation nll
                self.metric_store.add(
                    metrics.Metric("vnll", "format_float", "g(m)"))

                # And store a local variable for easy access
                vnll_metric = self.metric_store["vnll"]

            loss = 0
            num_tokens = 0
            for step, batch in enumerate(cycle(dataloader), 1):
                try:
                    step_loss = self.compute_gradients_and_loss(
                        batch, model, optimizer)
                    run_optimizer = (step % accumulation_steps) == 0

                    if run_optimizer:
                        # Run an optimization step
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        model.zero_grad()

                    # Update loss and num tokens after running an optimization
                    # step, in case it results in an out of memory error
                    loss += step_loss
                    num_tokens += batch["num_tokens"]

                    if run_optimizer:
                        # Since we ran the optimizer, increment current step
                        self.step += 1
                        self.experiment.set_step(self.step)
                        progress.update()

                        # update our metrics as well
                        self.update_metrics(
                            loss / accumulation_steps,
                            num_tokens,
                            scheduler.get_lr()[0],
                        )
                        num_tokens = 0
                        loss = 0

                        # and finally check if we should save
                        if (self.args.save_steps > 0
                                and self.step % self.args.save_steps == 0):
                            # First save the current checkpoint
                            self.save()

                            # Then if we are implementing early stopping, see
                            # if we achieved a new best
                            if self.args.optim.early_stopping:
                                evaluator.reset_metrics()
                                with ExitStack() as eval_stack:
                                    # pylint:disable=no-member
                                    eval_stack.enter_context(
                                        tqdm_unwrap_stdout())
                                    eval_stack.enter_context(
                                        release_cuda_memory(
                                            collect_tensors(optimizer.state)))
                                    # pylint:enable=no-member

                                    vnll = evaluator()
                                    vnll_metric.update(vnll)

                                    # Save the updated metrics
                                    self.save_metrics()

                                    if vnll == vnll_metric.min:
                                        self.on_new_best()

                                # Try to combat OOM errors caused by doing evaluation
                                # in the same loop with training. This manifests in out
                                # of memory errors after the first or second evaluation
                                # run.
                                refresh_cuda_memory()

                            if not self.prune_checkpoints():
                                logging.info("Stopping early")
                                break

                            if self.step >= max_steps:
                                logging.info("Finished training")
                                break

                except RuntimeError as rte:
                    if "out of memory" in str(rte):
                        self.metric_store["oom"].update(1)
                        logging.warning(str(rte))
                    else:
                        progress.close()
                        raise

                progress.set_description_str(get_description())

            progress.close()
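
Example #4 interleaves gradient accumulation with optimization: gradients from accumulation_steps consecutive batches are summed before optimizer.step() and zero_grad() run, and the reported loss is averaged over that window. A minimal standalone sketch of the same scheduling, with a toy model and synthetic batches standing in for compute_gradients_and_loss:

import torch
from torch import nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

running_loss = 0.0
for step in range(1, 17):
    batch, target = torch.randn(2, 8), torch.randn(2, 1)

    # backward() accumulates gradients into .grad without clearing them,
    # so several batches contribute to a single optimizer update.
    loss = nn.functional.mse_loss(model(batch), target)
    loss.backward()
    running_loss += loss.item()

    if step % accumulation_steps == 0:
        # Apply the accumulated gradients once per window, then reset.
        optimizer.step()
        model.zero_grad()
        print(f"step {step}: mean loss {running_loss / accumulation_steps:.4f}")
        running_loss = 0.0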