def get_learner(nb_classes, all=False, args=ARGS):
    """Assemble the ``PyTorchNN_vA`` training wrapper with an RMSE objective.

    The network comes from ``_get_model`` and is optimised with Adam
    (lr=0.0001). ``pytorch_metrics.RMSELoss()`` is the criterion, and
    ``pytorch_metrics.RMSELoss(round=True)`` is the only tracked metric.

    Args:
        nb_classes: Forwarded unchanged to ``_get_model``.
        all: When truthy, a ``StepLR`` schedule driven by
            ``args.all_train_lr_step`` is used and no early stopper is
            attached. Otherwise ``ReduceLROnPlateau`` is paired with
            ``EarlyStopping``.
            NOTE(review): presumably this flag means "train on all data
            without a validation split" -- confirm against callers.
        args: Configuration object. This function reads
            ``args.virtual_batch_size`` and, when ``all`` is truthy,
            ``args.all_train_lr_step``; it is also passed to ``_get_model``.

    Returns:
        The constructed ``PyTorchNN_vA`` instance.
    """
    # NOTE(review): the ``all`` parameter shadows the builtin inside this
    # function; the name is part of the public signature, so it is kept.
    from src.metrics import pytorch_metrics
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    from src.pytorch.early_stoppers import EarlyStopping
    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN

    net = _get_model(nb_classes, args=args)
    adam = torch.optim.Adam(net.parameters(), lr=0.0001)

    loss_fn = pytorch_metrics.RMSELoss()
    tracked_metrics = [pytorch_metrics.RMSELoss(round=True)]

    if all:
        # Fixed step decay; there is no early stopper in this mode.
        from src.pytorch.schedulers import StepLR

        lr_schedule = StepLR(
            adam,
            step_size=args.all_train_lr_step,
            gamma=0.1,
            min_lr=1e-5,
        )
        stopper = None
    else:
        lr_schedule = ReduceLROnPlateau(
            adam,
            mode='min',
            factor=0.1,
            patience=2,
            verbose=True,
            threshold=0.001,
            min_lr=1e-5,
        )
        stopper = EarlyStopping(
            patience=4,
            verbose=True,
            delta=0.001,
            save_model_path=None,
            wait=1,
        )

    return PyTorchNN(
        net,
        adam,
        loss_fn,
        tracked_metrics,
        lr_schedule,
        stopper,
        device=DEVICE,
        virtual_batch_size=args.virtual_batch_size,
    )
def get_learner(nb_classes, args=ARGS):
    """Assemble the ``PyTorchNN_vA`` training wrapper for cross-entropy training.

    The network comes from ``_get_model`` and is optimised with Adam
    (lr=0.0001) against ``nn.CrossEntropyLoss``. Categorical accuracy and
    top-2 accuracy are tracked. The learning rate follows the project's
    ``StepLR`` (step_size=8, gamma=0.1, min_lr=1e-5), and an
    ``EarlyStopping`` instance (patience=6, delta=0.0001, wait=12) is
    attached.

    Args:
        nb_classes: Forwarded unchanged to ``_get_model``.
        args: Configuration object. This function reads
            ``args.virtual_batch_size``; it is also passed to ``_get_model``.

    Returns:
        The constructed ``PyTorchNN_vA`` instance.
    """
    from src.metrics import pytorch_metrics
    from src.pytorch.schedulers import StepLR
    from src.pytorch.early_stoppers import EarlyStopping
    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN

    net = _get_model(nb_classes, args=args)
    adam = torch.optim.Adam(net.parameters(), lr=0.0001)
    loss_fn = nn.CrossEntropyLoss()

    tracked_metrics = [
        pytorch_metrics.CategoricalAccuracy(),
        pytorch_metrics.TopKAccuracy(top_k=2),
    ]

    lr_schedule = StepLR(adam, step_size=8, gamma=0.1, min_lr=1e-5)
    stopper = EarlyStopping(
        patience=6,
        verbose=True,
        delta=0.0001,
        save_model_path=None,
        wait=12,
    )

    return PyTorchNN(
        net,
        adam,
        loss_fn,
        tracked_metrics,
        lr_schedule,
        stopper,
        device=DEVICE,
        virtual_batch_size=args.virtual_batch_size,
    )
def get_learner(nb_classes, args=ARGS):
    """Assemble the ``PyTorchNN_vA`` training wrapper with mixup options.

    The network comes from ``_get_model`` and is optimised with Adam
    (lr=0.0001). The loss and metrics depend on
    ``args.outputs_as_onehotencoder``:

    * truthy: ``nn.BCEWithLogitsLoss`` with a single
      ``pytorch_metrics.CrossEntropyLoss(one_hot_encoding=True)`` metric;
    * falsy: ``nn.CrossEntropyLoss`` with categorical accuracy and top-2
      accuracy.

    The learning rate is controlled by ``ReduceLROnPlateau`` (mode 'min',
    factor=0.1, patience=3), and an ``EarlyStopping`` instance (patience=6,
    delta=0.0001) is attached.

    Args:
        nb_classes: Forwarded unchanged to ``_get_model``.
        args: Configuration object. This function reads
            ``outputs_as_onehotencoder``, ``virtual_batch_size``, ``mixup``,
            ``mixup_alpha`` and ``mixup_method``; it is also passed to
            ``_get_model``.

    Returns:
        The constructed ``PyTorchNN_vA`` instance.
    """
    from src.metrics import pytorch_metrics
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    from src.pytorch.early_stoppers import EarlyStopping
    from src.pytorch.wrappers import PyTorchNN_vA as PyTorchNN

    net = _get_model(nb_classes, args=args)
    adam = torch.optim.Adam(net.parameters(), lr=0.0001)

    if not args.outputs_as_onehotencoder:
        loss_fn = nn.CrossEntropyLoss()
        tracked_metrics = [
            pytorch_metrics.CategoricalAccuracy(),
            pytorch_metrics.TopKAccuracy(top_k=2),
        ]
    else:
        loss_fn = nn.BCEWithLogitsLoss()
        tracked_metrics = [
            pytorch_metrics.CrossEntropyLoss(one_hot_encoding=True),
        ]

    lr_schedule = ReduceLROnPlateau(
        adam,
        mode='min',
        factor=0.1,
        patience=3,
        verbose=True,
        threshold=0.001,
        min_lr=1e-5,
    )
    stopper = EarlyStopping(
        patience=6,
        verbose=True,
        delta=0.0001,
        save_model_path=None,
    )

    return PyTorchNN(
        net,
        adam,
        loss_fn,
        tracked_metrics,
        lr_schedule,
        stopper,
        device=DEVICE,
        virtual_batch_size=args.virtual_batch_size,
        mixup=args.mixup,
        mixup_alpha=args.mixup_alpha,
        mixup_method=args.mixup_method,
    )