import logging

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# Project-local imports. The module paths below are assumptions inferred from
# the names used in this file; adjust them to the actual repository layout.
import architecture
import config
import lr_scheduler
import transforms
import util
from config import cfg
from dataset import COVIDxFolder
from evaluate import validate
from radam import RAdam

log = logging.getLogger(__name__)


def main():
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU training requested in the config, but CUDA "
                         "is not available on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(
        config.train_imgs, config.train_labels,
        transforms.train_transforms(config.width, config.height))
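    # drop_last keeps every batch full (stable batch statistics);
    # pin_memory speeds up host-to-GPU copies when training on CUDA.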
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(
        config.val_imgs, config.val_labels,
        transforms.val_transforms(config.width, config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        # Resume from a checkpoint; map_location='cpu' lets the checkpoint
        # load regardless of the device it was saved from.
        state = torch.load(config.weights, map_location="cpu")
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDEfficientnet(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        model = torch.nn.DataParallel(model)
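    # Only pass trainable parameters to the optimizer; frozen layers
    # (requires_grad=False), if any, are skipped.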
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # optimizer and lr scheduler
    optimizer = RAdam(optim_layers,
                      lr=config.lr,
                      weight_decay=config.weight_decay)
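    # Shrink the LR when the validation metric plateaus; mode='max' because
    # the score passed to scheduler.step() is higher-is-better.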
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)

    # Resume from the last global_step stored in the checkpoint, if one
    # was loaded
    global_step = 0 if state is None else state['global_step'] + 1

    class_weights = util.to_device(torch.FloatTensor(config.loss_weights),
                                   gpu=use_gpu)
    # Per-class weights compensate for class imbalance in the training set
    loss_fn = CrossEntropyLoss(weight=class_weights)

    # Best validation score observed so far; -1 guarantees the first
    # evaluation becomes the new best
    best_score = -1

    # Training
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            logits = model(imgs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                # The model is wrapped in DataParallel only on GPU, so
                # unwrap it conditionally before calling probability()
                net = model.module if isinstance(model, nn.DataParallel) else model
                probs = net.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)

                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(
                             global_step, loss.item(), f1, acc, lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                scheduler.step(best_score)
            global_step += 1


class Optimizer(nn.Module):
    """Builds a torch optimizer and an optional LR scheduler from the
    global `cfg`, exposing a single zero_grad/step/scheduler_step
    interface to the training loop."""

    def __init__(self, model):
        super(Optimizer, self).__init__()
        self.setup_optimizer(model)

    def setup_optimizer(self, model):
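        # One param group per tensor so biases can receive their own LR
        # scale and weight decay, a common detection-codebase convention.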
        params = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if "bias" in key:
                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            params += [{
                "params": [value],
                "lr": lr,
                "weight_decay": weight_decay
            }]

        if cfg.SOLVER.TYPE == 'SGD':
            self.optimizer = torch.optim.SGD(params,
                                             lr=cfg.SOLVER.BASE_LR,
                                             momentum=cfg.SOLVER.SGD.MOMENTUM)
        elif cfg.SOLVER.TYPE == 'ADAM':
            self.optimizer = torch.optim.Adam(params,
                                              lr=cfg.SOLVER.BASE_LR,
                                              betas=cfg.SOLVER.ADAM.BETAS,
                                              eps=cfg.SOLVER.ADAM.EPS)
        elif cfg.SOLVER.TYPE == 'ADAMAX':
            self.optimizer = torch.optim.Adamax(params,
                                                lr=cfg.SOLVER.BASE_LR,
                                                betas=cfg.SOLVER.ADAM.BETAS,
                                                eps=cfg.SOLVER.ADAM.EPS)
        elif cfg.SOLVER.TYPE == 'ADAGRAD':
            self.optimizer = torch.optim.Adagrad(params, lr=cfg.SOLVER.BASE_LR)
        elif cfg.SOLVER.TYPE == 'RMSPROP':
            self.optimizer = torch.optim.RMSprop(params, lr=cfg.SOLVER.BASE_LR)
        elif cfg.SOLVER.TYPE == 'RADAM':
            self.optimizer = RAdam(params,
                                   lr=cfg.SOLVER.BASE_LR,
                                   betas=cfg.SOLVER.ADAM.BETAS,
                                   eps=cfg.SOLVER.ADAM.EPS)
        else:
            raise NotImplementedError

        # 'Fix' keeps a constant LR; every other policy is stepped from
        # scheduler_step(), either per epoch or per iteration
        if cfg.SOLVER.LR_POLICY.TYPE == 'Fix':
            self.scheduler = None
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=cfg.SOLVER.LR_POLICY.STEP_SIZE,
                gamma=cfg.SOLVER.LR_POLICY.GAMMA)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Plateau':
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                factor=cfg.SOLVER.LR_POLICY.PLATEAU_FACTOR,
                patience=cfg.SOLVER.LR_POLICY.PLATEAU_PATIENCE)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'Noam':
            self.scheduler = lr_scheduler.create(
                'Noam',
                self.optimizer,
                model_size=cfg.SOLVER.LR_POLICY.MODEL_SIZE,
                factor=cfg.SOLVER.LR_POLICY.FACTOR,
                warmup=cfg.SOLVER.LR_POLICY.WARMUP)
        elif cfg.SOLVER.LR_POLICY.TYPE == 'MultiStep':
            self.scheduler = lr_scheduler.create(
                'MultiStep',
                self.optimizer,
                milestones=cfg.SOLVER.LR_POLICY.STEPS,
                gamma=cfg.SOLVER.LR_POLICY.GAMMA)
        else:
            raise NotImplementedError

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()

    def scheduler_step(self, lrs_type, val=None):
        if self.scheduler is None:
            return

        if cfg.SOLVER.LR_POLICY.TYPE != 'Plateau':
            val = None

        # STEP_TYPE says whether this policy steps per epoch or per iteration
        if lrs_type == cfg.SOLVER.LR_POLICY.STEP_TYPE:
            self.scheduler.step(val)

    def get_lr(self):
        # Distinct learning rates across param groups, in ascending order
        lr = sorted({group['lr'] for group in self.optimizer.param_groups})
        return lr
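

# Minimal usage sketch for the Optimizer wrapper (illustrative only: it
# assumes `cfg` is populated and STEP_TYPE is one of 'Epoch'/'Iter';
# `loader`, `loss_fn` and `val_score` are placeholder names):
#
#     optim = Optimizer(model)
#     for imgs, labels in loader:
#         optim.zero_grad()
#         loss = loss_fn(model(imgs), labels)
#         loss.backward()
#         optim.step()
#         optim.scheduler_step('Iter')
#     optim.scheduler_step('Epoch', val=val_score)


if __name__ == "__main__":
    main()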