Example #1
def get_learner(nb_classes, all=False, args=ARGS):

    model = _get_model(nb_classes, args=args)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    from src.metrics import pytorch_metrics
    criterion = pytorch_metrics.RMSELoss()
    metrics = [pytorch_metrics.RMSELoss(round=True), ]

    if all:
        from src.pytorch.schedulers import StepLR
        scheduler = StepLR(optimizer, step_size=args.all_train_lr_step, gamma=0.1, min_lr=1e-5)
        early_stopper = None  # full-data run: no held-out split to monitor, so no early stopping
    else:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True,
                                      threshold=0.001, min_lr=1e-5)
        from src.pytorch.early_stoppers import EarlyStopping
        early_stopper = EarlyStopping(patience=4, verbose=True, delta=0.001, save_model_path=None, wait=1)

    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN
    pytorchmodel = PyTorchNN(model, optimizer, criterion, metrics, scheduler, early_stopper, device=DEVICE,
                             virtual_batch_size=args.virtual_batch_size)

    return pytorchmodel
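src.metrics.pytorch_metrics.RMSELoss is project-specific and its source is not shown here. A minimal sketch of what a metric with that interface might look like, assuming round=True rounds predictions before scoring (the class body below is a guess, not the repo's actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSELoss(nn.Module):
    """Root-mean-squared error. With round=True, predictions are rounded
    first, which turns the loss into an integer-target evaluation metric."""
    def __init__(self, round=False, eps=1e-8):
        super().__init__()
        self.round = round
        self.eps = eps  # keeps the sqrt differentiable when the error is zero

    def forward(self, y_pred, y_true):
        if self.round:
            y_pred = torch.round(y_pred)
        return torch.sqrt(F.mse_loss(y_pred, y_true) + self.eps)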
Example #2
def get_learner(nb_classes, args=ARGS):
    model = _get_model(nb_classes, args=args)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    criterion = nn.CrossEntropyLoss()

    from src.metrics import pytorch_metrics
    metrics = [
        pytorch_metrics.CategoricalAccuracy(),
        pytorch_metrics.TopKAccuracy(top_k=2)
    ]

    from src.pytorch.schedulers import StepLR
    scheduler = StepLR(optimizer, step_size=8, gamma=0.1, min_lr=1e-5)

    from src.pytorch.early_stoppers import EarlyStopping
    early_stopper = EarlyStopping(patience=6,
                                  verbose=True,
                                  delta=0.0001,
                                  save_model_path=None,
                                  wait=12)

    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN
    pytorchmodel = PyTorchNN(model,
                             optimizer,
                             criterion,
                             metrics,
                             scheduler,
                             early_stopper,
                             device=DEVICE,
                             virtual_batch_size=args.virtual_batch_size)

    return pytorchmodel
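CategoricalAccuracy and TopKAccuracy also come from the project's own metrics module. The top-k computation they presumably perform is standard and easy to sketch (a hedged stand-in, not the repo's implementation):

import torch

def top_k_accuracy(logits, targets, top_k=2):
    # a prediction counts as correct if the true class index appears
    # among the top_k highest-scoring classes for that sample
    _, top = logits.topk(top_k, dim=1)                 # (N, top_k) class indices
    hits = (top == targets.unsqueeze(1)).any(dim=1)    # (N,) boolean
    return hits.float().mean().item()

# e.g. top_k_accuracy(torch.randn(8, 5), torch.randint(0, 5, (8,)), top_k=2)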
Example #3
def get_learner(nb_classes, args=ARGS):
    model = _get_model(nb_classes, args=args)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    if args.outputs_as_onehotencoder:
        criterion = nn.BCEWithLogitsLoss()

        from src.metrics import pytorch_metrics
        metrics = [
            pytorch_metrics.CrossEntropyLoss(one_hot_encoding=True),
        ]
    else:
        criterion = nn.CrossEntropyLoss()

        from src.metrics import pytorch_metrics
        metrics = [
            pytorch_metrics.CategoricalAccuracy(),
            pytorch_metrics.TopKAccuracy(top_k=2)
        ]

    from torch.optim.lr_scheduler import ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=3,
                                  verbose=True,
                                  threshold=0.001,
                                  min_lr=1e-5)

    from src.pytorch.early_stoppers import EarlyStopping
    early_stopper = EarlyStopping(patience=6,
                                  verbose=True,
                                  delta=0.0001,
                                  save_model_path=None)

    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN
    pytorchmodel = PyTorchNN(model,
                             optimizer,
                             criterion,
                             metrics,
                             scheduler,
                             early_stopper,
                             device=DEVICE,
                             virtual_batch_size=args.virtual_batch_size,
                             mixup=args.mixup,
                             mixup_alpha=args.mixup_alpha,
                             mixup_method=args.mixup_method)

    return pytorchmodel
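This third variant forwards mixup, mixup_alpha and mixup_method to the wrapper. How PyTorchNN_vA applies them internally is not shown; the standard input-mixup recipe such flags typically control looks like this (a sketch, with names of my own choosing):

import torch

def mixup_batch(x, y, alpha=0.2):
    # Input mixup (Zhang et al., 2018): blend each sample with a randomly
    # permuted partner; the loss is later mixed with the same weight:
    #   loss = lam * criterion(out, y_a) + (1 - lam) * criterion(out, y_b)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1.0 - lam) * x[perm]
    return x_mixed, y, y[perm], lam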
Example #4
def load_model(filename, nb_classes=None, load_learner=False):
    print("-" * 80)
    print("LOAD MODEL")
    # Loading model
    print(f"Loading model from file: {filename}")
    checkpoint = torch.load(filename, map_location=DEVICE)  # remap tensors to the active device
    # Generate model
    model = _get_model(nb_classes=nb_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Load learner
    if load_learner:
        # assumes this module's get_learner variant accepts a prebuilt model
        pytorchmodel = get_learner(model=model)
    else:
        from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN
        pytorchmodel = PyTorchNN(model, device=DEVICE)  # bare wrapper, e.g. for inference

    return pytorchmodel
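load_model expects the checkpoint dictionary to carry a 'model_state_dict' key. A matching save helper (hypothetical; the repo's own saving code is not shown) would be:

import torch

def save_model(model, filename):
    # persist only the weights, under the key load_model reads back
    torch.save({'model_state_dict': model.state_dict()}, filename)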