Example no. 1
    def _train(self) -> Optional[float]:
        record_keeper, _, _ = logging_presets.get_record_keeper(
            "example_logs", "example_tensorboard")
        hooks = logging_presets.get_hook_container(record_keeper)
        dataset_dict = {"val": self.val_dataset}
        model_folder = "example_saved_models"

        def visualizer_hook(umapper, umap_embeddings, labels, split_name,
                            keyname, *args):
            logging.info("UMAP plot for the {} split and label set {}".format(
                split_name, keyname))
            label_set = np.unique(labels)
            num_classes = len(label_set)
            fig = plt.figure(figsize=(20, 15))
            plt.gca().set_prop_cycle(
                cycler("color", [
                    plt.cm.nipy_spectral(i)
                    for i in np.linspace(0, 0.9, num_classes)
                ]))
            for i in range(num_classes):
                idx = labels == label_set[i]
                plt.plot(umap_embeddings[idx, 0],
                         umap_embeddings[idx, 1],
                         ".",
                         markersize=1)
            # Save the UMAP scatter plot to disk instead of showing it interactively
            file_name = './plots/metric_{0}.png'.format(args[0])
            plt.savefig(file_name, dpi=300)

        # Create the tester
        tester = testers.GlobalEmbeddingSpaceTester(
            end_of_testing_hook=hooks.end_of_testing_hook,
            visualizer=umap.UMAP(),
            visualizer_hook=visualizer_hook,
            dataloader_num_workers=32)

        end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                    dataset_dict,
                                                    model_folder,
                                                    test_interval=1,
                                                    patience=200)

        trainer = trainers.MetricLossOnly(
            self.models_dict,
            self.optimizers,
            self._train_cfg.batch_per_gpu,
            self.loss_funcs,
            self.mining_funcs,
            #self._train_loader,
            self.train_set,
            sampler=self.sampler,
            dataloader_num_workers=self._train_cfg.workers - 1,
            end_of_iteration_hook=hooks.end_of_iteration_hook,
            end_of_epoch_hook=end_of_epoch_hook)

        #trainer.train(num_epochs=self._train_cfg.epochs)
        trainer.train(num_epochs=500)
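The examples above and below omit their imports. A minimal sketch of what they generally assume; exact module paths may differ between pytorch-metric-learning versions, and Example no. 5 additionally relies on facenet_pytorch:

import logging
import os

import numpy as np
import torch
import umap
from cycler import cycler
from matplotlib import pyplot as plt
from torchvision import datasets, models, transforms

from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils import common_functions, logging_presets
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator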
Example no. 2
def train_eval(args, train_data, dev_data):
    logger = logging.getLogger("main")
    # Create dataset & dataloader
    trans = [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    trans = transforms.Compose(trans)

    train_dataset, train_char_idx = \
        create_dataset(args.root, train_data, trans)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  collate_fn=collate_fn)
    # number of batches given to trainer
    n_batch = int(len(train_dataloader))

    eval_train_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split*3, train_data, trans)
    eval_dev_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split, dev_data, trans)

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk, model = create_models(args.emb_dim, args.dropout)
    trunk.to(device)
    model.to(device)

    if args.metric_loss == "triplet":
        loss_func = losses.TripletMarginLoss(
            margin=args.margin,
            normalize_embeddings=args.normalize,
            smooth_loss=args.smooth)
    elif args.metric_loss == "arcface":
        loss_func = losses.ArcFaceLoss(margin=args.margin,
                                       num_classes=len(train_data),
                                       embedding_size=args.emb_dim)
        loss_func.to(device)
    else:
        raise NotImplementedError

    if args.optimizer == "SGD":
        trunk_optimizer = torch.optim.SGD(trunk.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        model_optimizer = torch.optim.SGD(model.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        optimizers = {
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        }
        if args.metric_loss == "arcface":
            loss_optimizer = torch.optim.SGD(loss_func.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.decay)
            optimizers["loss_optimizer"] = loss_optimizer
    elif args.optimizer == "Adam":
        trunk_optimizer = torch.optim.Adam(trunk.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.decay)
        model_optimizer = torch.optim.Adam(model.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.decay)
        optimizers = {
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        }
        if args.metric_loss == "arcface":
            loss_optimizer = torch.optim.Adam(loss_func.parameters(),
                                              lr=args.lr,
                                              weight_decay=args.decay)
            optimizers["loss_optimizer"] = loss_optimizer
    else:
        raise NotImplementedError

    def lr_func(step):
        if step < args.warmup:
            return (step + 1) / args.warmup
        else:
            steps_decay = step // args.decay_freq
            return 1 / args.decay_factor**steps_decay

    trunk_scheduler = torch.optim.lr_scheduler.LambdaLR(
        trunk_optimizer, lr_func)
    model_scheduler = torch.optim.lr_scheduler.LambdaLR(
        model_optimizer, lr_func)
    schedulers = {
        "trunk_scheduler": trunk_scheduler,
        "model_scheduler": model_scheduler
    }

    if args.miner == "none":
        mining_funcs = {}
    elif args.miner == "batch-hard":
        mining_funcs = {
            "post_gradient_miner": miners.BatchHardMiner(use_similarity=True)
        }
    else:
        raise NotImplementedError

    best_dev_eer = 1.0
    i_epoch = 0

    def end_of_epoch_hook(trainer):
        nonlocal i_epoch, best_dev_eer

        logger.info(f"EPOCH\t{i_epoch}")

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args, trainer.models["trunk"],
                                                trainer.models["embedder"],
                                                eval_train_dataloaders)
            dev_eer, dev_eer_std = evaluate(args, trainer.models["trunk"],
                                            trainer.models["embedder"],
                                            eval_dev_dataloaders)
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

        i_epoch += 1

    def end_of_iteration_hook(trainer):
        for scheduler in schedulers.values():
            scheduler.step()

    trainer = trainers.MetricLossOnly(
        models={
            "trunk": trunk,
            "embedder": model
        },
        optimizers=optimizers,
        batch_size=None,
        loss_funcs={"metric_loss": loss_func},
        mining_funcs=mining_funcs,
        iterations_per_epoch=n_batch,
        dataset=train_dataset,
        data_device=None,
        loss_weights=None,
        sampler=train_sampler,
        collate_fn=collate_fn,
        lr_schedulers=None,
        end_of_epoch_hook=end_of_epoch_hook,
        end_of_iteration_hook=end_of_iteration_hook,
        dataloader_num_workers=1)

    trainer.train(num_epochs=args.epoch)

    if args.save_model:
        save_models = {
            "trunk": trainer.models["trunk"].state_dict(),
            "embedder": trainer.models["embedder"].state_dict(),
            "args": [args.emb_dim]
        }
        torch.save(save_models, f"model/{args.suffix}.mdl")

    return best_dev_eer
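lr_func above implements linear warm-up followed by stepwise decay of the learning-rate multiplier. A small illustration with made-up values for warmup, decay_freq and decay_factor:

# Illustration only; the values below are assumptions, not the script's defaults.
warmup, decay_freq, decay_factor = 100, 1000, 10

def lr_multiplier(step):
    if step < warmup:
        return (step + 1) / warmup                       # linear warm-up
    return 1 / decay_factor ** (step // decay_freq)      # decay every decay_freq steps

print([round(lr_multiplier(s), 3) for s in (0, 50, 99, 500, 1500, 2500)])
# -> [0.01, 0.51, 1.0, 1.0, 0.1, 0.01]

Example no. 3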
def train_eval(args, train_data, dev_data):
    logger = logging.getLogger("main")
    # Create dataset & dataloader
    trans = [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    trans = transforms.Compose(trans)

    train_dataset, train_char_idx = \
        create_dataset(args.root, train_data, trans)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  collate_fn=collate_fn)
    # number of batches given to trainer
    n_batch = int(len(train_dataloader))

    eval_train_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split*3, train_data, trans)
    eval_dev_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split, dev_data, trans)

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk = models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = Identity()
    trunk.to(device)

    model = nn.Sequential(nn.Linear(trunk_output_size, args.emb_dim),
                          Normalize())
    model.to(device)

    if args.optimizer == "SGD":
        trunk_optimizer = torch.optim.SGD(trunk.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        model_optimizer = torch.optim.SGD(model.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
    else:
        raise NotImplementedError

    loss_func = losses.TripletMarginLoss(margin=args.margin,
                                         normalize_embeddings=args.normalize)

    best_dev_eer = 1.0
    i_epoch = 0

    def end_of_epoch_hook(trainer):
        nonlocal i_epoch, best_dev_eer

        logger.info(f"EPOCH\t{i_epoch}")

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args, trainer.models["trunk"],
                                                trainer.models["embedder"],
                                                eval_train_dataloaders)
            dev_eer, dev_eer_std = evaluate(args, trainer.models["trunk"],
                                            trainer.models["embedder"],
                                            eval_dev_dataloaders)
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

        i_epoch += 1

    trainer = trainers.MetricLossOnly(
        models={
            "trunk": trunk,
            "embedder": model
        },
        optimizers={
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        },
        batch_size=None,
        loss_funcs={"metric_loss": loss_func},
        mining_funcs={},
        iterations_per_epoch=n_batch,
        dataset=train_dataset,
        data_device=None,
        loss_weights=None,
        sampler=train_sampler,
        collate_fn=collate_fn,
        lr_schedulers=None,  # TODO: use warm-up
        end_of_epoch_hook=end_of_epoch_hook,
        dataloader_num_workers=1)

    trainer.train(num_epochs=args.epoch)

    if args.save_model:
        torch.save(trainer.models, f"model/{args.suffix}.mdl")

    return best_dev_eer
Example no. 4
# tester
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook, dataloader_num_workers=32)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, {"val": dev_dataset},
                                            os.path.join(
                                                args.log_dir, f"model"),
                                            test_interval=1,
                                            patience=args.patience)

# train
if args.trainer == "MetricLossOnly":
    trainer = trainers.MetricLossOnly(
        batch_size=batch_size,
        mining_funcs={},
        dataset=train_dataset,
        sampler=train_sampler,
        dataloader_num_workers=32,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook,
        **trainer_kwargs)
elif args.trainer == "TrainWithClassifier":
    trainer = trainers.TrainWithClassifier(
        batch_size=batch_size,
        mining_funcs={},
        dataset=train_dataset,
        sampler=train_sampler,
        dataloader_num_workers=32,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook,
        **trainer_kwargs)
trainer.train(num_epochs=args.max_epoch)
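batch_size, train_sampler and trainer_kwargs are assembled elsewhere in this script. Judging from the other examples, trainer_kwargs presumably bundles at least the models, optimizers and loss functions; a hypothetical sketch:

# Hypothetical trainer_kwargs; the names mirror the other examples, not this script.
trainer_kwargs = {
    "models": {"trunk": trunk, "embedder": embedder},
    "optimizers": {
        "trunk_optimizer": trunk_optimizer,
        "embedder_optimizer": embedder_optimizer,
    },
    "loss_funcs": {"metric_loss": loss_func},
}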
Example no. 5
def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Running on device: {}'.format(device))

    # Data transformations
    trans_train = transforms.Compose([
        transforms.RandomApply(transforms=[
            transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
            # transforms.RandomPerspective(distortion_scale=0.6, p=1.0),
            transforms.RandomRotation(degrees=(0, 180)),
            transforms.RandomHorizontalFlip(),
        ]),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    trans_val = transforms.Compose([
        # transforms.CenterCrop(120),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(data_dir,
                                                      "train_aligned"),
                                         transform=trans_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val_aligned"),
                                       transform=trans_val)

    # Prepare the model
    model = InceptionResnetV1(classify=False,
                              pretrained="vggface2",
                              dropout_prob=0.5).to(device)

    # for param in list(model.parameters())[:-8]:
    #     param.requires_grad = False

    trunk_optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Set the loss function
    loss = losses.ArcFaceLoss(len(train_dataset.classes), 512)

    # Package the above stuff into dictionaries.
    models = {"trunk": model}
    optimizers = {"trunk_optimizer": trunk_optimizer}
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {}
    lr_scheduler = {
        "trunk_scheduler_by_plateau":
        torch.optim.lr_scheduler.ReduceLROnPlateau(trunk_optimizer)
    }

    # Create the tester
    record_keeper, _, _ = logging_presets.get_record_keeper(
        "logs", "tensorboard")
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": val_dataset, "train": train_dataset}
    model_folder = "training_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(8, 7))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        dataloader_num_workers=4,
        accuracy_calculator=AccuracyCalculator(
            include=['mean_average_precision_at_r'], k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                splits_to_eval=[('val',
                                                                 ['train'])])

    # Create the trainer
    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_dataset,
        lr_schedulers=lr_scheduler,
        dataloader_num_workers=8,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)
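ArcFaceLoss has trainable class weights, but this example does not give them an optimizer. Example no. 2 does; the equivalent addition here, made before the trainer is constructed, would look roughly like:

# Sketch, mirroring Example no. 2: optimize the ArcFace weights as well.
loss_optimizer = torch.optim.SGD(loss.parameters(), lr=LR)
optimizers["loss_optimizer"] = loss_optimizer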
Example no. 6
num_epochs = 2
iterations_per_epoch = 100

# Package the above stuff into dictionaries.
models = {"trunk": trunk, "embedder": embedder}
optimizers = {
    "trunk_optimizer": trunk_optimizer,
    "embedder_optimizer": embedder_optimizer
}
loss_funcs = {"metric_loss": loss}
mining_funcs = {"post_gradient_miner": miner}

trainer = trainers.MetricLossOnly(models,
                                  optimizers,
                                  batch_size,
                                  loss_funcs,
                                  mining_funcs,
                                  iterations_per_epoch,
                                  train_dataset,
                                  record_keeper=record_keeper)

trainer.train(num_epochs=num_epochs)

#############################
########## Testing ##########
#############################

# The testing module requires faiss and scikit-learn
# So if you don't have these, then this import will break
from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester(record_keeper=record_keeper)
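This example targets an older pytorch-metric-learning interface in which record_keeper is handed directly to the trainer and tester. In recent releases the same wiring goes through the logging_presets hooks, as in the other examples; a rough sketch, where val_dataset and the model folder name are assumptions:

# Sketch of the hook-based wiring used elsewhere in these examples.
record_keeper, _, _ = logging_presets.get_record_keeper("logs", "tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)

tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, {"val": val_dataset},
                                            "example_saved_models")

trainer = trainers.MetricLossOnly(
    models, optimizers, batch_size, loss_funcs, mining_funcs, train_dataset,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook)
trainer.train(num_epochs=num_epochs)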
Example no. 7
# Create the tester
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook, dataloader_num_workers=32)

end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                            dataset_dict,
                                            model_folder,
                                            test_interval=1,
                                            patience=1)

trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    mining_funcs,
    train_dataset,
    sampler=sampler,
    dataloader_num_workers=32,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook)

# Train the model
trainer.train(num_epochs=num_epochs)

PATH1 = './SentinelNaip_TripletMarginMiner_trunk.pth'
PATH2 = './SentinelNaip_TripletMarginMiner_embed.pth'
torch.save(trunk.state_dict(), PATH1)
torch.save(embedder.state_dict(), PATH2)
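A short sketch of loading these checkpoints back for inference; it assumes trunk and embedder are first re-instantiated with the same architectures:

# Sketch: restore the saved weights and switch both models to eval mode.
trunk.load_state_dict(torch.load(PATH1))
embedder.load_state_dict(torch.load(PATH2))
trunk.eval()
embedder.eval()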

# Get a dictionary mapping from loss names to lists
Example no. 8
def train(train_data, test_data, save_model, num_epochs, lr, embedding_size,
          batch_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set trunk model and replace the softmax layer with an identity function
    trunk = torchvision.models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = common_functions.Identity()
    trunk = torch.nn.DataParallel(trunk.to(device))

    # Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
    embedder = torch.nn.DataParallel(
        MLP([trunk_output_size, embedding_size]).to(device))

    # Set optimizers
    trunk_optimizer = torch.optim.Adam(trunk.parameters(),
                                       lr=lr / 10,
                                       weight_decay=0.0001)
    embedder_optimizer = torch.optim.Adam(embedder.parameters(),
                                          lr=lr,
                                          weight_decay=0.0001)

    # Set the loss function
    loss = losses.TripletMarginLoss(margin=0.1)

    # Set the mining function
    miner = miners.MultiSimilarityMiner(epsilon=0.1)

    # Set the dataloader sampler
    sampler = samplers.MPerClassSampler(train_data.targets,
                                        m=4,
                                        length_before_new_iter=len(train_data))

    save_dir = os.path.join(
        save_model, ''.join(str(lr).split('.')) + '_' + str(batch_size) + '_' +
        str(embedding_size))

    os.makedirs(save_dir, exist_ok=True)

    # Package the above stuff into dictionaries.
    models = {"trunk": trunk, "embedder": embedder}
    optimizers = {
        "trunk_optimizer": trunk_optimizer,
        "embedder_optimizer": embedder_optimizer
    }
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {"tuple_miner": miner}

    record_keeper, _, _ = logging_presets.get_record_keeper(
        os.path.join(save_dir, "example_logs"),
        os.path.join(save_dir, "example_tensorboard"))
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": test_data, "train": train_data}
    model_folder = "example_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(20, 15))
        plt.title(str(split_name) + '_' + str(len(umap_embeddings)))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    # Create the tester
    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        visualizer=umap.UMAP(),
        visualizer_hook=visualizer_hook,
        dataloader_num_workers=32,
        accuracy_calculator=AccuracyCalculator(k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                test_interval=1,
                                                patience=1)

    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_data,
        sampler=sampler,
        dataloader_num_workers=32,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)

    if save_model is not None:

        torch.save(models["trunk"].state_dict(),
                   os.path.join(save_dir, 'trunk.pth'))
        torch.save(models["embedder"].state_dict(),
                   os.path.join(save_dir, 'embedder.pth'))

        print('Model saved in ', save_dir)
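Example no. 9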
# Package the above stuff into dictionaries.
models = {"trunk": trunk, "embedder": embedder}
optimizers = {
    "trunk_optimizer": trunk_optimizer,
    "embedder_optimizer": embedder_optimizer
}
loss_funcs = {"metric_loss": loss}
mining_funcs = {"post_gradient_miner": miner}

record_keeper = get_record_keeper()

trainer = trainers.MetricLossOnly(models,
                                  optimizers,
                                  batch_size,
                                  loss_funcs,
                                  mining_funcs,
                                  iterations_per_epoch,
                                  train_dataset,
                                  sampler=sampler,
                                  record_keeper=record_keeper,
                                  dataset_labels=train_dataset.targets)

trainer.train(num_epochs=num_epochs)

#############################
########## Testing ##########
#############################

# The testing module requires faiss,
# so if you don't have it, this import will break
from pytorch_metric_learning import testers
Example no. 10
models = {"trunk": trunk, "embedder": embedder}
optimizers = {
    "trunk_optimizer": trunk_optimizer,
    "embedder_optimizer": embedder_optimizer
}
loss_funcs = {"metric_loss": loss}
mining_funcs = {"post_gradient_miner": miner}

record_keeper, _, _ = logging_presets.get_record_keeper(
    "example_logs", "example_tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)
dataset_dict = {"val": val_dataset}
model_folder = "example_saved_models"

# Create the tester
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)
trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    mining_funcs,
    iterations_per_epoch,
    train_dataset,
    sampler=sampler,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook)

trainer.train(num_epochs=num_epochs)
Example no. 11
def objective(trial):
    param_gen = ParameterGenerator(trial, CONF["_fix_params"], logger=logger)

    # Average results of multiple folds.
    print("New parameter.")
    metrics = []
    constructors = MODEL_DEF.get(CONF, trial, param_gen)
    for i_fold, (train_dataset, dev_dataset, train_sampler,
                 batch_size) in enumerate(constructors["fold_generator"]()):
        print(f"Fold {i_fold}")
        trainer_kwargs = constructors["modules"]()

        # logging
        record_keeper, _, _ = logging_presets.get_record_keeper(
            csv_folder=os.path.join(args.log_dir,
                                    f"trial_{trial.number}_{i_fold}_csv"),
            tensorboard_folder=os.path.join(
                args.log_dir, f"trial_{trial.number}_{i_fold}_tensorboard"))
        hooks = logging_presets.get_hook_container(record_keeper)

        # tester
        tester = testers.GlobalEmbeddingSpaceTester(
            end_of_testing_hook=hooks.end_of_testing_hook,
            dataloader_num_workers=args.n_test_loader)
        end_of_epoch_hook = hooks.end_of_epoch_hook(
            tester, {"val": dev_dataset},
            os.path.join(args.log_dir, f"trial_{trial.number}_{i_fold}_model"),
            test_interval=1,
            patience=args.patience)

        CHECKPOINT_FN = os.path.join(
            args.log_dir, f"trial_{trial.number}_{i_fold}_last.pth")

        def actual_end_of_epoch_hook(trainer):
            continue_training = end_of_epoch_hook(trainer)

            torch.save(
                ({k: m.state_dict()
                  for k, m in trainer.models.items()},
                 {k: m.state_dict()
                  for k, m in trainer.optimizers.items()},
                 {k: m.state_dict()
                  for k, m in trainer.loss_funcs.items()}, trainer.epoch),
                CHECKPOINT_FN)

            return continue_training

        # train
        if args.trainer == "MetricLossOnly":
            trainer = trainers.MetricLossOnly(
                batch_size=batch_size,
                mining_funcs={},
                dataset=train_dataset,
                sampler=train_sampler,
                dataloader_num_workers=args.n_train_loader,
                end_of_iteration_hook=hooks.end_of_iteration_hook,
                end_of_epoch_hook=actual_end_of_epoch_hook,
                **trainer_kwargs)
        elif args.trainer == "TrainWithClassifier":
            trainer = trainers.TrainWithClassifier(
                batch_size=batch_size,
                mining_funcs={},
                dataset=train_dataset,
                sampler=train_sampler,
                dataloader_num_workers=args.n_train_loader,
                end_of_iteration_hook=hooks.end_of_iteration_hook,
                end_of_epoch_hook=actual_end_of_epoch_hook,
                **trainer_kwargs)

        while True:
            start_epoch = 1
            if os.path.exists(CHECKPOINT_FN):
                model_dicts, optimizer_dicts, loss_dicts, last_epoch = \
                    torch.load(CHECKPOINT_FN)
                for k, d in model_dicts.items():
                    trainer.models[k].load_state_dict(d)
                for k, d in optimizer_dicts.items():
                    trainer.optimizers[k].load_state_dict(d)
                for k, d in loss_dicts.items():
                    trainer.loss_funcs[k].load_state_dict(d)
                start_epoch = last_epoch + 1

                logger.critical(f"Start from old epoch: {last_epoch + 1}")
            try:
                trainer.train(num_epochs=args.max_epoch,
                              start_epoch=start_epoch)
            except Exception as err:
                logger.critical(f"Error: {err}")
                if not args.ignore_error:
                    break
                else:
                    raise err
            else:
                break

        rslt = hooks.get_accuracy_history(
            tester, "val", metrics=["mean_average_precision_at_r"])

        metrics.append(max(rslt["mean_average_precision_at_r_level0"]))
    return np.mean(metrics)
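objective returns the mean of the best mean_average_precision_at_r across folds, so it is meant to be maximized. A sketch of handing it to Optuna, where the study settings are arbitrary assumptions:

# Sketch: run the hyperparameter search.
import optuna

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params, study.best_value)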