Example #1
0
    def train(self):
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            if self.val_loader is not None:
                self.validate(split="val", epoch=epoch)

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )
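All of these examples call a save_checkpoint helper that is not shown here. A minimal sketch of what such a helper could look like, assuming it simply serializes the assembled state with torch.save (the exact signature and file layout in the source repository may differ):

import os
import torch

def save_checkpoint(state, checkpoint_dir="checkpoints/", checkpoint_file="checkpoint.pt"):
    # Serialize the assembled training state to <checkpoint_dir>/<checkpoint_file>.
    path = os.path.join(checkpoint_dir, checkpoint_file)
    torch.save(state, path)
    return path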
Example #2
0
 def save(
     self,
     metrics=None,
     checkpoint_file="checkpoint.pt",
     training_state=True,
 ):
     if not self.is_debug and distutils.is_master():
         if training_state:
             save_checkpoint(
                 {
                     "epoch": self.epoch,
                     "step": self.step,
                     "state_dict": self.model.state_dict(),
                     "optimizer": self.optimizer.state_dict(),
                     "scheduler": self.scheduler.scheduler.state_dict()
                     if self.scheduler.scheduler_type != "Null"
                     else None,
                     "normalizers": {
                         key: value.state_dict()
                         for key, value in self.normalizers.items()
                     },
                     "config": self.config,
                     "val_metrics": metrics,
                     "ema": self.ema.state_dict() if self.ema else None,
                     "amp": self.scaler.state_dict()
                     if self.scaler
                     else None,
                 },
                 checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                 checkpoint_file=checkpoint_file,
             )
         else:
             if self.ema:
                 self.ema.store()
                 self.ema.copy_to()
             save_checkpoint(
                 {
                     "state_dict": self.model.state_dict(),
                     "normalizers": {
                         key: value.state_dict()
                         for key, value in self.normalizers.items()
                     },
                     "config": self.config,
                     "val_metrics": metrics,
                     "amp": self.scaler.state_dict()
                     if self.scaler
                     else None,
                 },
                 checkpoint_dir=self.config["cmd"]["checkpoint_dir"],
                 checkpoint_file=checkpoint_file,
             )
             if self.ema:
                 self.ema.restore()
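The checkpoints written above bundle model, optimizer, scheduler, EMA, and AMP state. A minimal sketch of how such a checkpoint could be restored, assuming standard PyTorch state-dict APIs (the trainers' own load logic is not shown in these examples and may differ):

import torch

def load_checkpoint(path, model, optimizer=None, scaler=None, device="cpu"):
    # Load the serialized dict and map the saved tensors onto the target device.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None and checkpoint.get("optimizer") is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    if scaler is not None and checkpoint.get("amp") is not None:
        # Restore the GradScaler state so loss scaling resumes where it left off.
        scaler.load_state_dict(checkpoint["amp"])
    return checkpoint.get("epoch", 0), checkpoint.get("step", 0)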
Example #3
0
 def save(self, epoch, metrics):
     if not self.is_debug and distutils.is_master():
         save_checkpoint(
             {
                 "epoch": epoch,
                 "state_dict": self.model.state_dict(),
                 "optimizer": self.optimizer.state_dict(),
                 "normalizers": {
                     key: value.state_dict()
                     for key, value in self.normalizers.items()
                 },
                 "config": self.config,
                 "val_metrics": metrics,
                 "amp": self.scaler.state_dict() if self.scaler else None,
             },
             self.config["cmd"]["checkpoint_dir"],
         )
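Examples #2 and #3 guard saving with distutils.is_master() so that only one rank writes the checkpoint in distributed runs. A minimal sketch of such a guard, assuming it is built on torch.distributed (the repository's own distutils module may implement it differently):

import torch.distributed as dist

def is_master():
    # Treat single-process runs as "master"; otherwise only rank 0 qualifies.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0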
Example #4
0
    def train(self, max_epochs=None, return_metrics=False):
        # TODO(abhshkdz): Timers for dataloading and forward pass.
        num_epochs = (max_epochs if max_epochs is not None else
                      self.config["optim"]["max_epochs"])
        for epoch in range(num_epochs):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                batch = batch.to(self.device)

                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()

            with torch.no_grad():
                if self.val_loader is not None:
                    v_loss, v_mae = self.validate(split="val", epoch=epoch)

                if self.test_loader is not None:
                    test_loss, test_mae = self.validate(split="test",
                                                        epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                        "amp":
                        self.scaler.state_dict() if self.scaler else None,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )
        if return_metrics:
            return {
                "training_loss": float(self.meter.loss.global_avg),
                "training_mae": float(
                    self.meter.meters[
                        self.config["task"]["labels"][0]
                        + "/"
                        + self.config["task"]["metric"]
                    ].global_avg
                ),
                "validation_loss": v_loss,
                "validation_mae": v_mae,
                "test_loss": test_loss,
                "test_mae": test_mae,
            }
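A hypothetical call into the variant above, overriding the configured epoch budget and collecting the summary metrics it returns (the trainer name is a placeholder):

# Train for five epochs regardless of config["optim"]["max_epochs"]
# and inspect the aggregated train/val/test metrics.
results = trainer.train(max_epochs=5, return_metrics=True)
print(results["training_loss"], results["validation_mae"])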
Example #5
0
 def save(self, epoch, step, metrics):
     if not self.is_debug and distutils.is_master() and not self.is_hpo:
         save_checkpoint(
             self.save_state(epoch, step, metrics),
             self.config["cmd"]["checkpoint_dir"],
         )
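Example #5 delegates assembling the checkpoint dict to a save_state helper that is not shown. A minimal sketch of what it might return, inferred from the fields the other examples save (the real method may include additional entries such as scheduler or EMA state):

def save_state(self, epoch, step, metrics):
    # Bundle everything needed to resume training, mirroring the
    # dictionaries built inline in the other examples.
    return {
        "epoch": epoch,
        "step": step,
        "state_dict": self.model.state_dict(),
        "optimizer": self.optimizer.state_dict(),
        "normalizers": {
            key: value.state_dict()
            for key, value in self.normalizers.items()
        },
        "config": self.config,
        "val_metrics": metrics,
        "amp": self.scaler.state_dict() if self.scaler else None,
    }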
Example #6
0
    def train(self):
        self.best_val_mae = 1e9
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)})
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    val_metrics[self.evaluator.task_primary_metric["is2re"]]["metric"]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug:
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)
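Example #6 scales the loss with a torch.cuda.amp.GradScaler before calling self._backward(loss), and later divides the logged loss by scaler.get_scale(). A minimal sketch of what an AMP-aware backward step could look like under that convention (the actual _backward implementation is not shown in these examples):

def _backward(self, loss):
    # `loss` is already scaled by the GradScaler when AMP is enabled.
    self.optimizer.zero_grad()
    loss.backward()
    if self.scaler:
        # Unscales gradients, skips the step if they contain inf/NaN,
        # and updates the scale factor for the next iteration.
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        self.optimizer.step()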
Example #7
0
    def train(self):
        self.best_val_mae = 1e9
        eval_every = self.config["optim"].get("eval_every", -1)
        iters = 0
        self.metrics = {}
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out, batch, self.evaluator, self.metrics,
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item(), self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if eval_every != -1 and iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(split="val", epoch=epoch)
                        if (
                            val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            < self.best_val_mae
                        ):
                            self.best_val_mae = val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            if not self.is_debug:
                                save_checkpoint(
                                    {
                                        "epoch": epoch
                                        + (i + 1) / len(self.train_loader),
                                        "state_dict": self.model.state_dict(),
                                        "optimizer": self.optimizer.state_dict(),
                                        "normalizers": {
                                            key: value.state_dict()
                                            for key, value in self.normalizers.items()
                                        },
                                        "config": self.config,
                                        "val_metrics": val_metrics,
                                    },
                                    self.config["cmd"]["checkpoint_dir"],
                                )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if eval_every == -1:
                if self.val_loader is not None:
                    val_metrics = self.validate(split="val", epoch=epoch)
                    if (
                        val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        < self.best_val_mae
                    ):
                        self.best_val_mae = val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        if not self.is_debug:
                            save_checkpoint(
                                {
                                    "epoch": epoch + 1,
                                    "state_dict": self.model.state_dict(),
                                    "optimizer": self.optimizer.state_dict(),
                                    "normalizers": {
                                        key: value.state_dict()
                                        for key, value in self.normalizers.items()
                                    },
                                    "config": self.config,
                                    "val_metrics": val_metrics,
                                },
                                self.config["cmd"]["checkpoint_dir"],
                            )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if (
                "relaxation_dir" in self.config["task"]
                and self.config["task"].get("ml_relax", "end") == "train"
            ):
                self.validate_relaxation(
                    split="val", epoch=epoch,
                )

        if (
            "relaxation_dir" in self.config["task"]
            and self.config["task"].get("ml_relax", "end") == "end"
        ):
            self.validate_relaxation(
                split="val", epoch=epoch,
            )
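Examples #6 and #7 repeat the same pattern in two places: compare the primary validation metric against self.best_val_mae and checkpoint only on improvement. A small, hypothetical helper that factors that pattern out, reusing the same metric lookup and checkpoint call seen above:

def save_if_improved(self, val_metrics, epoch, task="s2ef"):
    # Checkpoint only when the primary validation metric improves.
    primary = self.evaluator.task_primary_metric[task]
    current = val_metrics[primary]["metric"]
    if current >= self.best_val_mae:
        return False
    self.best_val_mae = current
    if not self.is_debug:
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "normalizers": {
                    key: value.state_dict()
                    for key, value in self.normalizers.items()
                },
                "config": self.config,
                "val_metrics": val_metrics,
            },
            self.config["cmd"]["checkpoint_dir"],
        )
    return True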