def main():
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU not supported or enabled on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(
        config.train_imgs, config.train_labels,
        transforms.train_transforms(config.width, config.height))
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(
        config.val_imgs, config.val_labels,
        transforms.val_transforms(config.width, config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        # Resume from the checkpoint pointed to by config.weights
        state = torch.load(config.weights)
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDEfficientnet(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        model = torch.nn.DataParallel(model)
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # Optimizer and LR scheduler
    optimizer = RAdam(optim_layers,
                      lr=config.lr,
                      weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)

    # Load the last global_step from the checkpoint, if one exists
    global_step = 0 if state is None else state['global_step'] + 1

    # Weight the loss with the per-class weights from the config
    class_weights = util.to_device(torch.FloatTensor(config.loss_weights),
                                   gpu=use_gpu)
    loss_fn = CrossEntropyLoss(weight=class_weights)

    # Reset the best metric score
    best_score = -1

    # Training
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            logits = model(imgs)
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                # The model is wrapped in DataParallel only when running on GPU
                net = model.module if isinstance(model, torch.nn.DataParallel) else model
                probs = net.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)
                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(
                             global_step, loss.item(), f1, acc, lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                scheduler.step(best_score)

            global_step += 1
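
# A minimal sketch of how main() might be invoked from the command line. The
# config.random_seed attribute and the seeding calls below are assumptions for
# illustration; they are not taken from the original script.
if __name__ == '__main__':
    seed = getattr(config, 'random_seed', None)  # hypothetical config field
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
    main()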
class Optimizer(nn.Module):
    def __init__(self, model):
        super(Optimizer, self).__init__()
        self.setup_optimizer(model)

    def setup_optimizer(self, model):
        # Build per-parameter groups so biases can use their own learning
        # rate and weight decay
        params = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if "bias" in key:
                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            params += [{
                "params": [value],
                "lr": lr,
                "weight_decay": weight_decay
            }]

        if cfg.SOLVER.TYPE == 'SGD':
            self.optimizer = torch.optim.SGD(
                params,
                lr=cfg.SOLVER.BASE_LR,
                momentum=cfg.SOLVER.SGD.MOMENTUM)
        elif cfg.SOLVER.TYPE == 'ADAM':
            self.optimizer = torch.optim.Adam(
                params,
                lr=cfg.SOLVER.BASE_LR,
                betas=cfg.SOLVER.ADAM.BETAS,
                eps=cfg.SOLVER.ADAM.EPS)
        elif cfg.SOLVER.TYPE == 'ADAMAX':
            self.optimizer = torch.optim.Adamax(
                params,
                lr=cfg.SOLVER.BASE_LR,
                betas=cfg.SOLVER.ADAM.BETAS,
                eps=cfg.SOLVER.ADAM.EPS)
        elif cfg.SOLVER.TYPE == 'ADAGRAD':
            self.optimizer = torch.optim.Adagrad(params, lr=cfg.SOLVER.BASE_LR)
        elif cfg.SOLVER.TYPE == 'RMSPROP':
            self.optimizer = torch.optim.RMSprop(params, lr=cfg.SOLVER.BASE_LR)
        elif cfg.SOLVER.TYPE == 'RADAM':
            self.optimizer = RAdam(
                params,
                lr=cfg.SOLVER.BASE_LR,
                betas=cfg.SOLVER.ADAM.BETAS,
                eps=cfg.SOLVER.ADAM.EPS)
        else:
            raise NotImplementedError

        if cfg.SOLVER.LR_POLICY.TYPE == 'Fix':
            self.scheduler = None
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=cfg.SOLVER.LR_POLICY.STEP_SIZE,
                gamma=cfg.SOLVER.LR_POLICY.GAMMA)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Plateau':
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                factor=cfg.SOLVER.LR_POLICY.PLATEAU_FACTOR,
                patience=cfg.SOLVER.LR_POLICY.PLATEAU_PATIENCE)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Noam':
            self.scheduler = lr_scheduler.create(
                'Noam',
                self.optimizer,
                model_size=cfg.SOLVER.LR_POLICY.MODEL_SIZE,
                factor=cfg.SOLVER.LR_POLICY.FACTOR,
                warmup=cfg.SOLVER.LR_POLICY.WARMUP)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'MultiStep':
            self.scheduler = lr_scheduler.create(
                'MultiStep',
                self.optimizer,
                milestones=cfg.SOLVER.LR_POLICY.STEPS,
                gamma=cfg.SOLVER.LR_POLICY.GAMMA)
        else:
            raise NotImplementedError

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()

    def scheduler_step(self, lrs_type, val=None):
        if self.scheduler is None:
            return
        # Only the Plateau scheduler consumes a metric value
        if cfg.SOLVER.LR_POLICY.TYPE != 'Plateau':
            val = None
        # SETP_TYPE mirrors the key name defined in the config
        if lrs_type == cfg.SOLVER.LR_POLICY.SETP_TYPE:
            self.scheduler.step(val)

    def get_lr(self):
        lr = []
        for param_group in self.optimizer.param_groups:
            lr.append(param_group['lr'])
        lr = sorted(list(set(lr)))
        return lr
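
# A hedged usage sketch of the Optimizer wrapper above. The model, dataloader,
# and criterion arguments are placeholders, and the 'Iter'/'Epoch' strings are
# assumed values of cfg.SOLVER.LR_POLICY.SETP_TYPE; scheduler_step only advances
# the scheduler when its argument matches that config key, so one of the two
# calls below is a no-op depending on the configured policy.
def train_one_epoch(model, optim, dataloader, criterion):
    for batch, targets in dataloader:
        optim.zero_grad()
        loss = criterion(model(batch), targets)
        loss.backward()
        optim.step()
        # Per-iteration policies (e.g. Noam warmup) advance here
        optim.scheduler_step('Iter')
    # Per-epoch policies (Step/MultiStep/Plateau) advance once per epoch;
    # only Plateau actually consumes the metric passed via val
    optim.scheduler_step('Epoch', val=loss.item())
    return optim.get_lr()

# The wrapper would be constructed once, before the epoch loop:
#   optim = Optimizer(model)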