Esempio n. 1
0
    def refresh_best_n_epoch_models(
        self,
        directory,
        filename,
        metric_name,
        n,
        bigger_is_better,
        current_epoch_idx,
        current_epoch_state,
    ):
        """Keep on disk only the checkpoints of the best ``n`` epochs.

        While fewer than ``n`` epochs are tracked, the current epoch is
        unconditionally ranked (by insertion order) and checkpointed. Once
        ``n`` epochs are tracked, the top-``n`` ranking for ``metric_name``
        is recomputed; if it changed, ranks are rebuilt, checkpoints that
        dropped out of the top ``n`` are deleted, and the current epoch's
        state is saved.

        Args:
            directory: Folder holding the checkpoint files.
            filename: Base name used by ``save_checkpoint`` for this run.
            metric_name: Metric key used to rank epochs.
            n: Number of best-epoch checkpoints to retain.
            bigger_is_better: True if larger metric values rank higher.
            current_epoch_idx: Index of the epoch that just finished.
            current_epoch_state: State object to checkpoint for this epoch.
        """
        if len(self.metrics["epochs_to_rank"]) < n:
            # Still filling the top-n pool: rank equals insertion order.
            self.metrics["epochs_to_rank"][current_epoch_idx] = len(
                self.metrics["epochs_to_rank"]
            )

            save_checkpoint(
                state=current_epoch_state,
                is_best=True,
                directory=directory,
                filename=filename,
                epoch_idx=current_epoch_idx,
            )
        else:
            previous_top_n = list(self.metrics["epochs_to_rank"].keys())
            current_top_n = self.get_best_n_epochs_for_metric(
                metric_name=metric_name, n=n, bigger_is_better=bigger_is_better
            )

            # List comparison is order-sensitive, so a pure re-ordering of
            # the same epochs also rebuilds the ranks (nothing gets removed
            # in that case since the set of epochs is unchanged).
            if current_top_n != previous_top_n:
                self.metrics["epochs_to_rank"] = {
                    epoch_idx: rank_idx
                    for rank_idx, epoch_idx in enumerate(current_top_n)
                }

                epoch_idx_models_to_remove = [
                    idx for idx in previous_top_n if idx not in current_top_n
                ]

                for idx_to_remove in epoch_idx_models_to_remove:
                    # BUGFIX: this path previously hard-coded the literal
                    # "(unknown)" instead of the checkpoint base filename,
                    # so evicted checkpoints were never actually deleted.
                    # NOTE(review): assumes save_checkpoint names files
                    # "epoch_<idx>_model_<filename>.ckpt" — confirm against
                    # its implementation.
                    stale_path = (
                        f"{directory}/epoch_{idx_to_remove}_model_{filename}.ckpt"
                    )
                    # Guard removal, consistent with the isfile checks used
                    # elsewhere in this file before os.remove.
                    if os.path.isfile(stale_path):
                        os.remove(stale_path)

                save_checkpoint(
                    state=current_epoch_state,
                    is_best=True,
                    directory=directory,
                    filename=filename,
                    epoch_idx=current_epoch_idx,
                )
Esempio n. 2
0
def train(encoder, decoder, data, val_data, args):
    """Train the encoder/decoder pair on ``data``.

    Optimizes the decoder plus the encoder's projection heads with Adam.
    Per step: optionally accumulates a cross-entropy loss, backpropagates,
    and steps the optimizer; periodically logs, checkpoints, and scores on
    ``val_data``. A checkpoint is also saved at the end of every epoch.

    Args:
        encoder: Model producing global/spatial image features; only its
            ``project_global`` / ``project_spatial`` parameters are trained.
        decoder: Caption decoder; all of its parameters are trained.
        data: Training loader yielding
            (images, padded_inputs, padded_targets, lengths, refs).
        val_data: Validation data passed to the scorer.
        args: Namespace with learning_rate, base_epoch, num_epochs,
            is_xe_loss, log_step, save_step, metric_step, save_dir.
    """
    encoder = encoder.train()
    decoder = decoder.train()
    params = (list(decoder.parameters()) +
              list(encoder.project_global.parameters()) +
              list(encoder.project_spatial.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    total_step = len(data)
    total_epochs = args.base_epoch + args.num_epochs
    scorer = metrics.Scorer(use_bleu=True, use_cider=True, use_meteor=True)
    for i in range(args.num_epochs):
        epoch = args.base_epoch + i
        for step, (images, padded_inputs, padded_targets, lengths,
                   refs) in enumerate(data):
            images = images.to(settings.device)
            padded_inputs = padded_inputs.to(settings.device)
            padded_targets = padded_targets.to(settings.device)
            loss = 0
            if args.is_xe_loss:
                xe_loss = calculate_xe_loss(encoder, decoder, images,
                                            padded_inputs, padded_targets,
                                            lengths)
                loss += xe_loss
                # (dead store removed: perplexity was computed here and
                # immediately discarded; it is recomputed at log time below)
            decoder.zero_grad()
            encoder.zero_grad()
            # BUGFIX: when no loss term is enabled, `loss` stays a plain
            # int and `.backward()` would raise AttributeError — only
            # backprop/step when a tensor loss was actually accumulated.
            if torch.is_tensor(loss):
                loss.backward()
                optimizer.step()
            if (step + 1) % args.log_step == 0:
                update = (f'Epoch [{epoch}/{total_epochs}],'
                          f' Step [{step+1}/{total_step}],')
                if args.is_xe_loss:
                    perplexity = np.exp(xe_loss.item())
                    update += (f' XE Loss [{xe_loss.item():.4f}],'
                               f' Perplexity [{perplexity:5.4f}]')
                print(update)
            if (step + 1) % args.save_step == 0:
                storage.save_checkpoint(args.save_dir, epoch, 0, encoder,
                                        decoder)
            if (step + 1) % args.metric_step == 0:
                scorer.score(encoder, decoder, val_data)
        storage.save_checkpoint(args.save_dir, epoch, 0, encoder, decoder)
                    best_epoch = epoch
                    is_best = True
                state = {
                    'epoch': epoch,
                    'net': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                epoch_pbar.set_description(
                    'Saving at {}/{}_checkpoint.pth.tar'.format(
                        saved_models_filepath, epoch))
                filename = '{}_checkpoint.pth.tar'.format(epoch)

                previous_save = '{}/{}_checkpoint.pth.tar'.format(
                    saved_models_filepath, epoch - 1)
                if os.path.isfile(previous_save):
                    os.remove(previous_save)

                previous_best_save = '{}/best_{}_checkpoint.pth.tar'.format(
                    saved_models_filepath, previous_best_epoch)
                if os.path.isfile(previous_best_save) and is_best:
                    os.remove(previous_best_save)

                save_checkpoint(state=state,
                                directory=saved_models_filepath,
                                filename=filename,
                                is_best=is_best)
            ############################################################################################################

            epoch_pbar.set_description('')
            epoch_pbar.update(1)
                    'Reconstruction':
                    train_loss_reconstruction / n_batches_since_visualization,
                    'Envmap':
                    train_loss_envmap / n_batches_since_visualization,
                }, train_step, 'Train')

            train_loss_reconstruction, train_loss_envmap = 0.0, 0.0
            n_batches_since_visualization = 0
            train_step += 1

    # Clean up memory (see: https://repl.it/@nickangtc/python-del-multiple-variables-in-single-line)
    del x, x_envmap, target, target_envmap, groundtruth, relit, pred_image_envmap, pred_target_envmap

    # Saving checkpoint
    if epoch % CHECKPOINT_EVERY == 0:
        save_checkpoint(model.state_dict(), optimizer.state_dict(),
                        NAME + '_' + str(epoch))

    # Evaluate
    with no_grad():
        model.eval()
        test_loss_reconstruction = 0.0
        test_loss_image_envmap, test_loss_target_envmap = 0.0, 0.0
        test_psnr = 0.0
        random_batch_id = randint(0, TEST_BATCHES, (1, ))
        for test_batch_idx, test_batch in enumerate(test_dataloader):
            test_x = test_batch[0][0]['image'].to(device)
            test_x_envmap = test_batch[0][1].to(device)
            test_target = test_batch[1][0]['image'].to(device)
            test_target_envmap = test_batch[1][1].to(device)
            test_groundtruth = test_batch[2]['image'].to(device)
Esempio n. 5
0
                scheduler.state_dict(),
            }

            metric_tracker_val.refresh_best_n_epoch_models(
                directory=args.saved_models_filepath,
                filename=args.experiment_name,
                metric_name="accuracy",
                n=args.save_top_n_val_models,
                bigger_is_better=True,
                current_epoch_idx=epoch,
                current_epoch_state=state,
            )

            save_checkpoint(
                state=state,
                directory=args.saved_models_filepath,
                filename=args.experiment_name,
                is_best=False,
            )

        #############################################TESTING############################

        if args.test:
            if args.val_set_percentage >= 0.0:
                top_n_model_idx = metric_tracker_val.get_best_n_epochs_for_metric(
                    metric_name="accuracy", n=1, bigger_is_better=True)[0]
                resume_epoch = restore_model(
                    restore_fields,
                    filename=args.experiment_name,
                    directory=args.saved_models_filepath,
                    epoch_idx=top_n_model_idx,
                )