def log(model, batch_idx):
  # Log entropy every `log_interval` mini batches. `params`, `epoch`, and
  # `self` are free variables captured from the enclosing method's scope.
  if batch_idx % params["log_interval"] == 0:
    entropy = model.entropy()
    print("logging: {} learning iterations, entropy: {} / {}".format(
      model.getLearningIterations(), float(entropy), model.maxEntropy()))

    if params["create_plots"]:
      # Save a duty-cycle plot tagged with the epoch and iteration count
      plotDutyCycles(model.dutyCycle,
                     self.resultsDir + "/figure_" + str(epoch) + "_" +
                     str(model.getLearningIterations()))
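
For orientation, a minimal sketch (not from the source) of how a nested helper like the one above binds its free variables: `params`, `epoch`, and `self` are resolved through the enclosing method's scope at call time, which is why they can appear without being parameters of `log`. The `FakeModel` stub and `train_epoch` wrapper are illustrative assumptions.

# Hypothetical stub standing in for the project's model class
class FakeModel(object):
  def entropy(self): return 1.5
  def maxEntropy(self): return 4.0
  def getLearningIterations(self): return 300

def train_epoch(params, epoch, model):
  def log(model, batch_idx):
    # `params` and `epoch` resolve through the enclosing scope
    if batch_idx % params["log_interval"] == 0:
      print("epoch {}: {} learning iterations, entropy: {} / {}".format(
        epoch, model.getLearningIterations(),
        float(model.entropy()), model.maxEntropy()))

  for batch_idx in range(3):  # stand-in for the mini-batch loop
    log(model, batch_idx)

train_epoch({"log_interval": 1}, epoch=0, model=FakeModel())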
Example #2
    def train(self, params, epoch, repetition):
        """
        Train one epoch of this model by iterating through mini batches. An
        epoch ends after one pass through the training set, or if the number
        of mini batches exceeds the parameter "batches_in_epoch".
        """
        # Check for pre-trained model
        modelCheckpoint = os.path.join(
            params["path"], params["name"],
            "model_{}_{}.pt".format(repetition, epoch))
        if os.path.exists(modelCheckpoint):
            self.model = torch.load(modelCheckpoint, map_location=self.device)
            return

        self.model.train()
        for batch_idx, (batch, target) in enumerate(self.train_loader):
            data = batch["input"]
            if params["model_type"] in ["resnet9", "cnn"]:
                # Convolutional models expect an explicit channel dimension
                data = torch.unsqueeze(data, 1)
            data, target = data.to(self.device), target.to(self.device)

            # Standard forward / backward / parameter-update step
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            self.optimizer.step()

            # Log info every log_interval mini batches
            if batch_idx % params["log_interval"] == 0:
                entropy = self.model.entropy()
                print("logging: {} learning iterations, elapsedTime: {:.1f}s, "
                      "entropy: {} / {}, loss: {}".format(
                          self.model.getLearningIterations(),
                          time.time() - self.startTime, float(entropy),
                          self.model.maxEntropy(), loss.item()))
                if params["create_plots"]:
                    plotDutyCycles(
                        self.model.dutyCycle,
                        self.resultsDir + "/figure_" + str(epoch) + "_" +
                        str(self.model.getLearningIterations()))

            if batch_idx >= params["batches_in_epoch"]:
                break

        # Model-specific end-of-epoch hook (e.g. duty-cycle / boost updates)
        self.model.postEpoch()

        # Save model checkpoint on every epoch
        if params.get("save_every_epoch", False):
            torch.save(self.model, modelCheckpoint)
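
The checkpoint logic above can be exercised in isolation. A self-contained sketch of the same save-then-skip pattern, with all names illustrative: a pre-existing checkpoint file makes the call load the saved model and return early, so re-running a driver script resumes past completed epochs.

import os

import torch
import torch.nn as nn

def train_or_resume(model, checkpoint, device="cpu"):
    # Mirrors train(): an existing checkpoint short-circuits training.
    # Newer PyTorch releases may need weights_only=False in torch.load()
    # to unpickle a full nn.Module.
    if os.path.exists(checkpoint):
        return torch.load(checkpoint, map_location=device)
    # ... one epoch of training would run here ...
    torch.save(model, checkpoint)
    return model

model = nn.Linear(4, 2)
ckpt = "model_0_0.pt"  # same naming scheme: model_<repetition>_<epoch>.pt
model = train_or_resume(model, ckpt)  # first run: trains and saves
model = train_or_resume(model, ckpt)  # later runs: loads and returns early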
Example #4
  def train(self, params, epoch):
    """
    Train one epoch of this model by iterating through mini batches. An epoch
    ends after one pass through the training set, or if the number of mini
    batches exceeds the parameter "batches_in_epoch".
    """
    self.model.train()
    for batch_idx, batch in enumerate(self.train_loader):
      data = batch["input"]
      if params["model_type"] in ["resnet9", "cnn"]:
        # Convolutional models expect an explicit channel dimension
        data = torch.unsqueeze(data, 1)
      target = batch["target"]
      data, target = data.to(self.device), target.to(self.device)
      self.optimizer.zero_grad()
      output = self.model(data)
      loss = F.nll_loss(output, target)
      loss.backward()
      self.optimizer.step()

      # Log info every log_interval mini batches
      if batch_idx % params["log_interval"] == 0:
        entropy = self.model.entropy()
        print("logging: {} learning iterations, elapsedTime: {:.1f}s, "
              "entropy: {} / {}, loss: {}".format(
                self.model.getLearningIterations(),
                time.time() - self.startTime, float(entropy),
                self.model.maxEntropy(), loss.item()))
        if params["create_plots"]:
          plotDutyCycles(self.model.dutyCycle,
                         self.resultsDir + "/figure_" + str(epoch) + "_" +
                         str(self.model.getLearningIterations()))

      if batch_idx >= params["batches_in_epoch"]:
        break

    # Model-specific end-of-epoch hook (e.g. duty-cycle / boost updates)
    self.model.postEpoch()
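
This variant reads both tensors out of a batch mapping (batch["input"], batch["target"]), so its train_loader must yield dict batches. A minimal sketch of a compatible dataset, with names and shapes purely illustrative:

import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
  """Yields samples shaped like the batches this train() expects."""

  def __init__(self, inputs, targets):
    self.inputs, self.targets = inputs, targets

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, idx):
    return {"input": self.inputs[idx], "target": self.targets[idx]}

# DataLoader's default collate turns a list of dicts into one dict of
# stacked tensors, so batch["input"] has shape (batch_size, 28, 28) here.
loader = DataLoader(
  DictDataset(torch.randn(8, 28, 28), torch.randint(0, 10, (8,))),
  batch_size=4)
for batch in loader:
  print(batch["input"].shape, batch["target"].shape)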