Example #1
# Outer epoch loop; this excerpt assumes that max_epochs, validate(), model,
# the AdamW/SGD optimizers and their ReduceLROnPlateau schedulers were set up
# earlier (e.g. as in TorchANI's nnp_training example).
for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint the best model so far; here `nn` is the trained network
    # module (the example's own variable name), not the torch.nn package
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best,
                           AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate,
                           AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(enumerate(training),
                                   total=len(training),
                                   desc="epoch {}".format(
                                       AdamW_scheduler.last_epoch)):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        # The excerpt is truncated here. The lines below are an assumed,
        # conventional completion (they presume the AdamW and SGD optimizer
        # instances and mse = torch.nn.MSELoss(reduction='none') exist):
        # per-molecule squared error scaled by 1/sqrt(num_atoms), averaged,
        # then one step of both optimizers.
        loss = (mse(predicted_energies, true_energies)
                / num_atoms.sqrt()).mean()
        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()
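
The `validate()` helper called at the top of each epoch is not part of this
excerpt. A minimal sketch of what it might look like, assuming a `validation`
iterable shaped like `training` and an RMSE in the same units as the energies:

import math
import torch

def validate():
    # RMSE over the validation set; assumes `validation`, `model` and
    # `device` are defined in the surrounding setup.
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    with torch.no_grad():
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    return math.sqrt(total_mse / count)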
Example #2

    def train(
        self,
        *,
        num_epochs: int,
        batch_size: int,
        lr: float,
        loss_lambda: float,
        embedding_dim: int,
        decoder_dim: int,
        attention_dim: int,
        dropout: float,
        maxlen: int,
        patience: int,
        checkpoint_path: Optional[str] = None,
    ) -> float:
        checkpoint = (torch.load(checkpoint_path)
                      if checkpoint_path is not None else None)
        best_bleu = 0.0
        epochs_without_improvement = 0

        if checkpoint is not None:
            self.coco_train.shuffle(subset_len=1000)
            best_bleu = checkpoint["bleu_4"]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # get_device_name() raises on CPU-only machines, so guard the lookup
        device_name = (torch.cuda.get_device_name(device.index)
                       if device.type == "cuda" else "CPU")
        logging.info(f"Training on device {device_name}")

        with ImageCaptioningTrainer.tensorboard(
                comment=(f"_batch={batch_size}_lr={lr}"
                         f"_lambda={loss_lambda}_dropout={dropout}")
        ) as tb:
            data_loader = dl.CocoLoader(self.coco_train, batch_size,
                                        math.ceil(os.cpu_count() / 2))

            decoder = model.LSTMDecoder(self.num_embeddings, embedding_dim,
                                        self.encoder_dim, decoder_dim,
                                        attention_dim, dropout)
            if checkpoint is not None:
                decoder.load_state_dict(checkpoint["decoder"])

            self.encoder.to(device)
            decoder.to(device)

            optimizer = torch.optim.Adam(params=decoder.parameters(), lr=lr)
            if checkpoint is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])

            criterion = DoublyStochasticAttentionLoss(
                hyperparameter_lambda=loss_lambda,
                ignore_index=(self.coco_train.target_transform
                              .vocabulary.word2idx("<PAD>")),
            ).to(device)

            # Training loop
            start_epoch = 1 if checkpoint is None else checkpoint["epoch"] + 1
            for epoch in range(start_epoch, start_epoch + num_epochs):
                decoder.train()
                cost = 0.0
                running_loss = 0.0

                # Train steps
                for step, batch in enumerate(data_loader, 1):
                    images, captions = batch[0].to(device), batch[1].to(device)

                    optimizer.zero_grad()

                    predictions, attentions = decoder(*self.encoder(images),
                                                      captions)

                    # One fewer timestep than the caption length, since the
                    # first token is dropped from the targets below
                    num_samples = captions.shape[0]
                    num_timesteps = captions.shape[1] - 1

                    # Decoder output is assumed time-major (timesteps, batch,
                    # vocab); permute to batch-major before flattening so the
                    # predictions line up with the flattened targets
                    loss = criterion(
                        predictions.permute(1, 0, 2).reshape(
                            num_timesteps * num_samples, self.num_embeddings),
                        captions[:, 1:].reshape(num_timesteps * num_samples),
                        attentions,
                    )

                    loss.backward()
                    optimizer.step()

                    cost += loss.item()
                    running_loss += loss.item()

                    # Log the average loss every `every_step` training steps
                    every_step = 100
                    if step % every_step == 0:
                        avg_loss = running_loss / every_step
                        running_loss = 0.0
                        logging.info(
                            f"Epoch {epoch} Step {step}/{len(data_loader)} => {avg_loss: .4f}"
                        )
                        tb.add_scalar(f"loss_lambda={loss_lambda}", avg_loss,
                                      step + (epoch - 1) * len(data_loader))

                current_bleu4 = self.validator.validate(
                    self.encoder, decoder, device)
                logging.info(f"After Epoch {epoch} BLEU-4 => {current_bleu4}")

                self._save_checkpoint(epoch, decoder.state_dict(),
                                      optimizer.state_dict(), lr, dropout,
                                      loss_lambda, current_bleu4)
                self._save_run(epoch, cost / len(data_loader), current_bleu4,
                               loss_lambda, decoder, tb)

                # Early stopping on BLEU-4 metric
                if current_bleu4 > best_bleu:
                    best_bleu = current_bleu4
                    epochs_without_improvement = 0

                    self._save_best_model(embedding_dim, decoder_dim,
                                          attention_dim, decoder.state_dict(),
                                          optimizer.state_dict(), epoch,
                                          best_bleu, lr)
                else:
                    epochs_without_improvement += 1
                    if epochs_without_improvement == patience:
                        break

                # re-shuffle the training subset before the next epoch
                self.coco_train.shuffle(subset_len=1000)

        return best_bleu
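
`DoublyStochasticAttentionLoss` is used above but not defined in this example.
A minimal sketch of what it might compute, assuming the doubly stochastic
attention regularizer from Show, Attend and Tell and an `attentions` tensor of
shape `(timesteps, batch, num_pixels)`:

import torch.nn as nn

class DoublyStochasticAttentionLoss(nn.Module):
    # Sketch only: token-level cross-entropy plus
    # lambda * mean_i (1 - sum_t alpha_ti)^2, which pushes every spatial
    # location to receive roughly one unit of attention over all timesteps.
    def __init__(self, hyperparameter_lambda: float, ignore_index: int):
        super().__init__()
        self.hyperparameter_lambda = hyperparameter_lambda
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, predictions, targets, attentions):
        ce = self.cross_entropy(predictions, targets)
        penalty = ((1.0 - attentions.sum(dim=0)) ** 2).mean()
        return ce + self.hyperparameter_lambda * penalty

In the loop above, `predictions` and `targets` arrive pre-flattened to
`(timesteps * batch, vocab)` and `(timesteps * batch,)`, so the cross-entropy
term is computed per token while `<PAD>` positions are ignored.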