Example #1
    def train(self):
        device = self.device
        print(f'Running on device: {device}, start training...')
        print(
            f'Setting - Epochs: {self.num_epochs}, Learning rate: {self.learning_rate} '
        )

        train_loader = self.train_loader
        valid_loader = self.valid_loader

        model = self.model.to(device)
        if self.optimizer == 0:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=1e-5)
        elif self.optimizer == 1:
            optimizer = torch.optim.AdamW(model.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=1e-5)
        elif self.optimizer == 2:
            optimizer = MADGRAD(model.parameters(),
                                lr=self.learning_rate,
                                weight_decay=1e-5)
        elif self.optimizer == 3:
            optimizer = AdamP(model.parameters(),
                              lr=self.learning_rate,
                              weight_decay=1e-5)
        else:
            raise ValueError(f'Unknown optimizer index: {self.optimizer}')
        criterion = torch.nn.CrossEntropyLoss().to(device)

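        # Optionally wrap the base optimizer with an SWA wrapper (torchcontrib-style API):
        # weights are averaged during training, and swap_swa_sgd() installs the average at the end.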
        if self.use_swa:
            optimizer = SWA(optimizer, swa_start=2, swa_freq=2, swa_lr=1e-5)

        # scheduler #
        scheduler_dct = {
            0: None,
            1: StepLR(optimizer, 10, gamma=0.5),
            2: ReduceLROnPlateau(optimizer, 'min', factor=0.4,
                                 patience=int(0.3 * self.early_stopping_patience)),
            3: CosineAnnealingLR(optimizer, T_max=5, eta_min=0.)
        }
        scheduler = scheduler_dct[self.scheduler]

        # early stopping
        early_stopping = EarlyStopping(patience=self.early_stopping_patience,
                                       verbose=True,
                                       path=f'checkpoint_{self.job}.pt')

        # training
        self.train_loss_lst = list()
        self.train_acc_lst = list()
        self.val_loss_lst = list()
        self.val_acc_lst = list()
        for epoch in range(1, self.num_epochs + 1):
            with tqdm(train_loader, unit='batch') as tepoch:
                avg_val_loss, avg_val_acc = None, None

                for idx, (img, label) in enumerate(tepoch):
                    tepoch.set_description(f"Epoch {epoch}")

                    model.train()
                    optimizer.zero_grad()

                    img = img.float().to(device)
                    label = label.long().to(device)

                    output = model(img)
                    loss = criterion(output, label)
                    predictions = output.argmax(dim=1, keepdim=True).squeeze()
                    correct = (predictions == label).sum().item()
                    accuracy = correct / len(img)

                    loss.backward()
                    optimizer.step()

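                    # Run validation once per epoch, after the final training batch.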
                    if idx == len(train_loader) - 1:

                        val_loss_lst, val_acc_lst = list(), list()

                        model.eval()
                        with torch.no_grad():
                            for val_img, val_label in valid_loader:
                                val_img = val_img.float().to(device)
                                val_label = val_label.long().to(device)

                                val_out = model(val_img)
                                val_loss = criterion(val_out, val_label)
                                val_pred = val_out.argmax(dim=1, keepdim=True).squeeze()
                                val_acc = (val_pred == val_label).sum().item() / len(val_img)

                                val_loss_lst.append(val_loss.item())
                                val_acc_lst.append(val_acc)

                        avg_val_loss = np.mean(val_loss_lst)
                        avg_val_acc = np.mean(val_acc_lst) * 100.

                        self.train_loss_lst.append(loss.item())
                        self.train_acc_lst.append(accuracy)
                        self.val_loss_lst.append(avg_val_loss)
                        self.val_acc_lst.append(avg_val_acc)

                    if scheduler is not None:
                        current_lr = optimizer.param_groups[0]['lr']
                    else:
                        current_lr = self.learning_rate

                    # log
                    tepoch.set_postfix(loss=loss.item(),
                                       accuracy=100. * accuracy,
                                       val_loss=avg_val_loss,
                                       val_acc=avg_val_acc,
                                       current_lr=current_lr)

                # early stopping check
                early_stopping(avg_val_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break

                # scheduler update
                if scheduler is not None:
                    if self.scheduler == 2:
                        scheduler.step(avg_val_loss)
                    else:
                        scheduler.step()
        if self.use_swa:
            optimizer.swap_swa_sgd()
        self.model.load_state_dict(torch.load(f'checkpoint_{self.job}.pt'))
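
Example #1 assumes an EarlyStopping helper with a (patience, verbose, path) constructor, a __call__(val_loss, model) that checkpoints the best weights, and an early_stop flag. The original helper is not shown; the following is a minimal sketch under those assumptions, not the project's actual implementation.

import numpy as np
import torch


class EarlyStopping:
    """Stop training when validation loss stops improving; keep the best checkpoint."""

    def __init__(self, patience=7, verbose=False, delta=0., path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            # Improvement: save the model and reset the patience counter.
            if self.verbose:
                print(f'Validation loss improved ({self.best_loss:.6f} -> {val_loss:.6f}); saving model.')
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} / {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
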
Example #2
def train(cfg, train_loader, val_loader, val_labels, k):
    # Set Config
    MODEL_ARC = cfg.values.model_arc
    OUTPUT_DIR = cfg.values.output_dir
    NUM_CLASSES = cfg.values.num_classes

    SAVE_PATH = os.path.join(OUTPUT_DIR, MODEL_ARC)

    best_score = 0.

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    os.makedirs(SAVE_PATH, exist_ok=True)

    if k > 0:
        os.makedirs(SAVE_PATH + f'/{k}_fold', exist_ok=True)

    num_epochs = cfg.values.train_args.num_epochs
    max_lr = cfg.values.train_args.max_lr
    min_lr = cfg.values.train_args.min_lr
    weight_decay = cfg.values.train_args.weight_decay
    log_intervals = cfg.values.train_args.log_intervals

    model = CNNModel(model_arc=MODEL_ARC, num_classes=NUM_CLASSES)
    model.to(device)

    # base_optimizer = SGDP
    # optimizer = SAM(model.parameters(), base_optimizer, lr=max_lr, momentum=momentum)
    optimizer = MADGRAD(model.parameters(),
                        lr=max_lr,
                        weight_decay=weight_decay)
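    # Two cosine warmup-restart cycles over the run: each cycle spans half of the total training steps.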
    first_cycle_steps = len(train_loader) * num_epochs // 2

    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=first_cycle_steps,
        cycle_mult=1.0,
        max_lr=max_lr,
        min_lr=min_lr,
        warmup_steps=int(first_cycle_steps * 0.2),
        gamma=0.5)

    criterion = nn.BCEWithLogitsLoss()

    wandb.watch(model)

    # Create the AMP gradient scaler once so its scale statistics persist across epochs.
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()

        loss_values = AverageMeter()

        for step, (images, labels) in enumerate(tqdm(train_loader)):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = labels.size(0)

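            # Mixed-precision forward pass under autocast; the scaler below guards against fp16 gradient underflow.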
            with autocast():
                logits = model(images)
                loss = criterion(logits.view(-1), labels)

            loss_values.update(loss.item(), batch_size)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            wandb.log({
                'Learning rate': get_learning_rate(optimizer)[0],
                'Train Loss': loss_values.val
            })

            if step % log_intervals == 0:
                tqdm.write(
                    f'Epoch : [{epoch + 1}/{num_epochs}][{step}/{len(train_loader)}] || '
                    f'LR : {get_learning_rate(optimizer)[0]:.6e} || '
                    f'Train Loss : {loss_values.val:.4f} ({loss_values.avg:.4f}) ||'
                )

        with torch.no_grad():
            model.eval()

            loss_values = AverageMeter()
            preds = []

            for step, (images, labels) in enumerate(tqdm(val_loader)):
                images = images.to(device)
                labels = labels.to(device)
                batch_size = labels.size(0)

                logits = model(images)
                loss = criterion(logits.view(-1), labels)

                preds.append(logits.sigmoid().to('cpu').numpy())

                loss_values.update(loss.item(), batch_size)

        predictions = np.concatenate(preds)

        # f1, roc_auc = get_score(val_labels, predictions)
        roc_auc = get_score(val_labels, predictions)
        is_best = roc_auc >= best_score
        best_score = max(roc_auc, best_score)

        if is_best:
            save_dir = os.path.join(SAVE_PATH, f'{k}_fold') if k > 0 else SAVE_PATH
            remove_all_file(save_dir)
            ckpt_path = os.path.join(
                save_dir, f'{epoch + 1}_epoch_{best_score * 100.0:.2f}%.pth')
            print(f'Save checkpoints {ckpt_path}...')
            torch.save(model.state_dict(), ckpt_path)

        wandb.log({
            'Validation Loss average': loss_values.avg,
            'ROC AUC Score': roc_auc,
            # 'F1 Score' : f1
        })

        tqdm.write(f'Epoch : [{epoch + 1}/{num_epochs}] || '
                   f'Val Loss : {loss_values.avg:.4f} || '
                   f'ROC AUC score : {roc_auc:.4f} ||')
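
Example #2 relies on a few small project utilities (AverageMeter, get_learning_rate, remove_all_file, get_score) that are defined elsewhere in the source repository. A minimal sketch of the first two, assuming the common conventions for such helpers; the real implementations may differ.

class AverageMeter:
    """Track the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.
        self.sum = 0.
        self.count = 0
        self.avg = 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_learning_rate(optimizer):
    """Return the current learning rate of every parameter group."""
    return [group['lr'] for group in optimizer.param_groups]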