Example #1
def __init__(self, dice_weight=0.5, focal_weight=0.5):
    super(FocalDiceLoss, self).__init__()
    self.dice_weight = dice_weight
    self.focal_weight = focal_weight
    self.dice = DiceLoss(mode='multiclass')
    # self.focal = FocalLoss(mode="multiclass")
    self.focal = SoftCrossEntropyLoss(smooth_factor=0.1)
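
# Hypothetical sketch (not part of the original snippet): the forward method
# presumably combines the two losses with the configured weights.
def forward(self, y_pred, y_true):
    return (self.dice_weight * self.dice(y_pred, y_true)
            + self.focal_weight * self.focal(y_pred, y_true))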
Example #2
def test_soft_ce_loss():
    criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100)

    y_pred = torch.tensor([[+9, -9, -9, -9], [-9, +9, -9, -9],
                           [-9, -9, +9, -9], [-9, -9, -9, +9]]).float()
    y_true = torch.tensor([0, 1, -100, 3]).long()

    loss = criterion(y_pred, y_true)
    assert float(loss) == pytest.approx(1.0125, abs=0.0001)
Example #3
def test_soft_ce_loss():
    criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100)

    # Ideal case
    y_pred = torch.tensor([[+9, -9, -9, -9], [-9, +9, -9, -9], [-9, -9, +9, -9], [-9, -9, -9, +9]]).float()
    y_true = torch.tensor([0, 1, -100, 3]).long()

    loss = criterion(y_pred, y_true)
    print(loss)
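
Where the expected value 1.0125 comes from: a hand computation, assuming the fairseq-style label-smoothed NLL that pytorch-toolbelt's SoftCrossEntropyLoss is based on (the smoothing mass is spread uniformly over the classes, and ignored positions are zeroed but still counted in the mean):

import torch
import torch.nn.functional as F

eps, C = 0.1, 4
y_pred = torch.tensor([[+9., -9., -9., -9.], [-9., +9., -9., -9.],
                       [-9., -9., +9., -9.], [-9., -9., -9., +9.]])
y_true = torch.tensor([0, 1, -100, 3])

lprobs = F.log_softmax(y_pred, dim=-1)
nll = -lprobs[torch.arange(4), y_true.clamp(min=0)]  # per-sample NLL
smooth = -lprobs.sum(dim=-1)                         # uniform-smoothing term
per_sample = (1 - eps) * nll + (eps / C) * smooth
per_sample[y_true == -100] = 0.0   # ignored row contributes 0 but stays in the mean
print(per_sample.mean())           # tensor(1.0125)

Each valid row contributes (1 - 0.1) * 0 + (0.1 / 4) * 54 = 1.35, and averaging over all four rows (one of them zeroed) gives 3 * 1.35 / 4 = 1.0125.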
Example #4
def train(num_epochs, model, data_loader, val_loader, val_every, device, file_name):
    learning_rate = 0.0001
    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP

    criterion = [SoftCrossEntropyLoss(smooth_factor=0.1), JaccardLoss('multiclass', classes=12)]
    optimizer = AdamP(params=model.parameters(), lr=learning_rate, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=learning_rate)
    swa_model = AveragedModel(model)
    look = Lookahead(optimizer, la_alpha=0.5)

    print('Start training..')
    best_miou = 0
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        model.train()
        for step, (images, masks, _) in enumerate(data_loader):
            loss = 0
            images = torch.stack(images)       # (batch, channel, height, width)
            masks = torch.stack(masks).long()  # (batch, height, width)

            # move tensors to the device for GPU computation
            images, masks = images.to(device), masks.to(device)

            # inference
            outputs = model(images)
            for each in criterion:
                loss += each(outputs, masks)
            # total loss: SoftCrossEntropy + Jaccard

            look.zero_grad()
            loss.backward()
            look.step()

            outputs = torch.argmax(outputs.squeeze(), dim=1).detach().cpu().numpy()
            hist = add_hist(hist, masks.detach().cpu().numpy(), outputs, n_class=12)
            acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
            # print loss and mIoU at the given step interval
            if (step + 1) % 25 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU: {:.4f}'.format(
                    epoch + 1, num_epochs, step + 1, len(data_loader), loss.item(), mIoU))

        # print validation loss at each validation interval and save the best model
        if (epoch + 1) % val_every == 0:
            avrg_loss, val_miou = validation(epoch + 1, model, val_loader, criterion, device)
            if val_miou > best_miou:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_miou = val_miou
                save_model(model, file_name=file_name)

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
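    # Note: before swa_model is evaluated, its BatchNorm statistics should be
    # recomputed with torch.optim.swa_utils.update_bn(data_loader, swa_model);
    # this snippet never does that.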
Example #5

model = smp.UnetPlusPlus(encoder_name="efficientnet-b6",
                         encoder_weights='imagenet',
                         in_channels=3,
                         classes=10)
model.train()
model.to(DEVICE)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

# loss
# DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss
DiceLoss_fn = DiceLoss(mode='multiclass')
SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
loss_fn = L.JointLoss(first=DiceLoss_fn,
                      second=SoftCrossEntropy_fn,
                      first_weight=0.5,
                      second_weight=0.5).to(DEVICE)
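# JointLoss from pytorch-toolbelt simply forms the weighted sum, so loss_fn
# above is equivalent to
#   0.5 * DiceLoss_fn(output, target) + 0.5 * SoftCrossEntropy_fn(output, target)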

best_iou = 0

for epoch in range(1, EPOCHES + 1):
    losses = []
    start_time = time.time()

    for image, target in tqdm(loader):
        image, target = image.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
Example #6
def train_net(param, model, train_data, valid_data, plot=False, device='cuda'):
    # initialize parameters
    model_name      = param['model_name']
    epochs          = param['epochs']
    batch_size      = param['batch_size']
    lr              = param['lr']
    gamma           = param['gamma']
    step_size       = param['step_size']
    momentum        = param['momentum']
    weight_decay    = param['weight_decay']

    disp_inter      = param['disp_inter']
    save_inter      = param['save_inter']
    min_inter       = param['min_inter']
    iter_inter      = param['iter_inter']

    save_log_dir    = param['save_log_dir']
    save_ckpt_dir   = param['save_ckpt_dir']
    load_ckpt_dir   = param['load_ckpt_dir']

    # automatic mixed precision
    scaler = GradScaler()

    # network parameters
    train_data_size = len(train_data)
    valid_data_size = len(valid_data)
    c, y, x = train_data[0]['image'].shape
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=1)
    valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False, num_workers=1)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=weight_decay)
    #optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=momentum, weight_decay=weight_decay)
    #scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-5, last_epoch=-1)
    #criterion = nn.CrossEntropyLoss(reduction='mean').to(device)
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    criterion = L.JointLoss(first=DiceLoss_fn, second=SoftCrossEntropy_fn,
                            first_weight=0.5, second_weight=0.5).cuda()
    logger = inial_logger(os.path.join(save_log_dir, time.strftime("%m-%d %H:%M:%S", time.localtime()) +'_'+model_name+ '.log'))

    # main loop
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    train_loader_size = train_loader.__len__()
    valid_loader_size = valid_loader.__len__()
    best_iou = 0
    best_epoch = 0
    best_model = copy.deepcopy(model)
    epoch_start = 0
    if load_ckpt_dir is not None:
        ckpt = torch.load(load_ckpt_dir)
        epoch_start = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
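        # note: the scheduler state is not checkpointed here, so a resumed run
        # restarts the cosine schedule from scratch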

    logger.info('Total Epoch:{} Image_size:({}, {}) Training num:{}  Validation num:{}'.format(epochs, x, y, train_data_size, valid_data_size))
    #
    for epoch in range(epoch_start, epochs):
        start_time = time.time()
        # training phase
        model.train()
        train_epoch_loss = AverageMeter()
        train_iter_loss = AverageMeter()
        for batch_idx, batch_samples in enumerate(train_loader):
            data, target = batch_samples['image'], batch_samples['label']
            data, target = data.to(device), target.to(device)
            with autocast():  # autocast should wrap only the forward pass (needs PyTorch >= 1.6)
                pred = model(data)
                loss = criterion(pred, target)
            # backward pass and optimizer step run outside autocast
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # fractional-epoch step for CosineAnnealingWarmRestarts
            scheduler.step(epoch + batch_idx / train_loader_size)
            image_loss = loss.item()
            train_epoch_loss.update(image_loss)
            train_iter_loss.update(image_loss)
            if batch_idx % iter_inter == 0:
                spend_time = time.time() - start_time
                logger.info('[train] epoch:{} iter:{}/{} {:.2f}% lr:{:.6f} loss:{:.6f} ETA:{}min'.format(
                    epoch, batch_idx, train_loader_size, batch_idx/train_loader_size*100,
                    optimizer.param_groups[-1]['lr'],
                    train_iter_loss.avg, spend_time / (batch_idx + 1) * train_loader_size // 60 - spend_time // 60))
                train_iter_loss.reset()

        # validation phase
        model.eval()
        valid_epoch_loss = AverageMeter()
        valid_iter_loss = AverageMeter()
        iou = IOUMetric(10)
        with torch.no_grad():
            for batch_idx, batch_samples in enumerate(valid_loader):
                data, target = batch_samples['image'], batch_samples['label']
                data, target = data.to(device), target.to(device)
                pred = model(data)
                loss = criterion(pred, target)
                pred = pred.cpu().data.numpy()
                pred = np.argmax(pred, axis=1)
                iou.add_batch(pred, target.cpu().data.numpy())
                #
                image_loss = loss.item()
                valid_epoch_loss.update(image_loss)
                valid_iter_loss.update(image_loss)
                # if batch_idx % iter_inter == 0:
                #     logger.info('[val] epoch:{} iter:{}/{} {:.2f}% loss:{:.6f}'.format(
                #         epoch, batch_idx, valid_loader_size, batch_idx / valid_loader_size * 100, valid_iter_loss.avg))
            val_loss = valid_iter_loss.avg
            acc, acc_cls, iu, mean_iu, fwavacc = iou.evaluate()
            logger.info('[val] epoch:{} miou:{:.2f}'.format(epoch, mean_iu))

        # save loss and lr
        train_loss_total_epochs.append(train_epoch_loss.avg)
        valid_loss_total_epochs.append(valid_epoch_loss.avg)
        epoch_lr.append(optimizer.param_groups[0]['lr'])
        # save checkpoint
        if epoch % save_inter == 0 and epoch > min_inter:
            state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            filename = os.path.join(save_ckpt_dir, 'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)  # PyTorch >= 1.6 saves zipfile-compressed checkpoints that older versions cannot load
        # save best model
        if mean_iu > best_iou:
            state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            filename = os.path.join(save_ckpt_dir, 'checkpoint-best.pth')
            torch.save(state, filename)
            best_iou = mean_iu
            best_model = copy.deepcopy(model)
            logger.info('[save] Best Model saved at epoch:{} ============================='.format(epoch))
        # scheduler.step()
        # show loss
    # training loss curve
    if plot:
        x = [i for i in range(epochs)]
        fig = plt.figure(figsize=(12, 4))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(x, smooth(train_loss_total_epochs, 0.6), label='train loss')
        ax.plot(x, smooth(valid_loss_total_epochs, 0.6), label='val loss')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('CrossEntropy', fontsize=15)
        ax.set_title('train curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        ax = fig.add_subplot(1, 2, 2)
        ax.plot(x, epoch_lr,  label='Learning Rate')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('Learning Rate', fontsize=15)
        ax.set_title('lr curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        plt.show()
    return best_model, model
Example #7
def pseudo_labeling(num_epochs, model, data_loader, val_loader,
                    unlabeled_loader, device, val_every, file_name):
    # Instead of using current epoch we use a "step" variable to calculate alpha_weight
    # This helps the model converge faster
    from torch.optim.swa_utils import AveragedModel, SWALR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP
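
    # `alpha_weight` (used below) is not defined in this snippet; a common
    # pseudo-labeling ramp-up looks like the hypothetical sketch below, where
    # the constants T1, T2 and alpha_f are assumptions, not from the source:
    def alpha_weight(step, T1=100, T2=700, alpha_f=3.0):
        # weight of the unlabeled loss: 0 before T1, linear ramp to alpha_f at T2
        if step < T1:
            return 0.0
        if step > T2:
            return alpha_f
        return (step - T1) / (T2 - T1) * alpha_f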

    criterion = [
        SoftCrossEntropyLoss(smooth_factor=0.1),
        JaccardLoss('multiclass', classes=12)
    ]
    optimizer = AdamP(params=model.parameters(), lr=0.0001, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=0.0001)
    swa_model = AveragedModel(model)
    optimizer = Lookahead(optimizer, la_alpha=0.5)

    step = 100
    size = 256
    best_mIoU = 0
    model.train()
    print('Start Pseudo-Labeling..')
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        for batch_idx, (imgs, image_infos) in enumerate(unlabeled_loader):

            # Forward pass to get the pseudo labels
            # --------------------------------------------- pass the unlabeled batch through the model
            model.eval()
            with torch.no_grad():
                outs = model(torch.stack(imgs).to(device))
            # hard pseudo labels: argmax already yields a long tensor on `device`
            oms = torch.argmax(outs.squeeze(), dim=1)

            # --------------------------------------------- training

            model.train()
            # Now calculate the unlabeled loss using the pseudo label
            imgs = torch.stack(imgs)
            imgs = imgs.to(device)
            # preds_array = preds_array.to(device)

            output = model(imgs)
            loss = 0
            for each in criterion:
                loss += each(output, oms)

            unlabeled_loss = alpha_weight(step) * loss

            # Backpropagate
            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()
            output = torch.argmax(output.squeeze(),
                                  dim=1).detach().cpu().numpy()
            hist = add_hist(hist,
                            oms.detach().cpu().numpy(),
                            output,
                            n_class=12)

            if (batch_idx + 1) % 25 == 0:
                acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU:{:.4f}'.
                      format(epoch + 1, num_epochs, batch_idx + 1,
                             len(unlabeled_loader), unlabeled_loss.item(),
                             mIoU))
            # every 50 batches, train one epoch on labeled data
            if batch_idx % 50 == 0:

                # Normal training procedure
                # separate index so the outer batch_idx is not clobbered
                for labeled_idx, (images, masks, _) in enumerate(data_loader):
                    labeled_loss = 0
                    images = torch.stack(images)
                    # (batch, channel, height, width)
                    masks = torch.stack(masks).long()

                    # move tensors to the device for GPU computation
                    images, masks = images.to(device), masks.to(device)

                    output = model(images)

                    for each in criterion:
                        labeled_loss += each(output, masks)

                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()

                # Now we increment step by 1
                step += 1

        if (epoch + 1) % val_every == 0:
            avrg_loss, val_mIoU = validation(epoch + 1, model, val_loader,
                                             criterion, device)
            if val_mIoU > best_mIoU:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_mIoU = val_mIoU
                save_model(model, file_name=file_name)

        model.train()

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
Example #8
def train(config, model, train_data, valid_data, plot=False, device='cuda'):
    """

    :config config:
    :config model:
    :config train_data:
    :config valid_data:
    :config plot:
    :config device:
    :return:
    """
    # initialize parameters
    model_name = config['model_name']
    epochs = config['epochs']
    batch_size = config['batch_size']

    class_weights = config['class_weights']
    disp_inter = config['disp_inter']
    save_inter = config['save_inter']
    min_inter = config['min_inter']
    iter_inter = config['iter_inter']

    save_log_dir = config['save_log_dir']
    save_ckpt_dir = config['save_ckpt_dir']
    load_ckpt_dir = config['load_ckpt_dir']
    accumulation_steps = config['accumulation_steps']
    # automatic mixed precision
    scaler = GradScaler()
    # network parameters
    train_data_size = len(train_data)
    valid_data_size = len(valid_data)
    c, y, x = train_data[0]['image'].shape
    train_loader = DataLoader(dataset=train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    valid_loader = DataLoader(dataset=valid_data,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=1)
    #
    if config['optimizer'].lower() == 'adamw':
        optimizer = optim.AdamW(model.parameters(),
                                lr=config['lr'],
                                weight_decay=config['weight_decay'])
    elif config['optimizer'].lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=config['lr'],
                              momentum=config['momentum'],
                              weight_decay=config['weight_decay'])
    # SWA
    if config['swa']:
        swa_opt = SWA(optimizer,
                      swa_start=config['swa_start'],
                      swa_freq=config['swa_freq'],
                      swa_lr=config['swa_lr'])

    # warm_up_with_multistep_lr = lambda \
    #     epoch: epoch / warmup_epochs if epoch <= warmup_epochs else gamma ** len(
    #     [m for m in milestones if m <= epoch])
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=3, T_mult=2, eta_min=1e-5, last_epoch=-1)
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=config['factor'],
            patience=config['patience'],
            verbose=1,
            min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'NoamLR':
        scheduler = NoamLR(optimizer, warmup_steps=config['warmup_steps'])
    else:
        raise NotImplementedError

    # criterion = nn.CrossEntropyLoss(reduction='mean').to(device)
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    CrossEntropyLoss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    Lovasz_fn = LovaszLoss(mode='multiclass')
    criterion = L.JointLoss(first=DiceLoss_fn,
                            second=CrossEntropyLoss_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()
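    # Note: SoftCrossEntropy_fn and Lovasz_fn are constructed above but never
    # used; the active criterion combines only Dice and weighted CrossEntropy.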

    logger = initial_logger(
        os.path.join(
            save_log_dir,
            time.strftime("%m-%d %H:%M:%S", time.localtime()) + '_' +
            model_name + '.log'))

    # main loop
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    train_loader_size = len(train_loader)
    valid_loader_size = len(valid_loader)
    best_iou = 0
    best_epoch = 0
    best_model = copy.deepcopy(model)
    start_epoch = 0
    if load_ckpt_dir is not None:
        ckpt = torch.load(load_ckpt_dir)
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    logger.info(
        'Total Epoch:{} Image_size:({}, {}) Training num:{}  Validation num:{}'
        .format(epochs, x, y, train_data_size, valid_data_size))
    # execute train
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        # training phase
        model.train()
        train_epoch_loss = AverageMeter()
        train_iter_loss = AverageMeter()
        for batch_idx, batch_samples in enumerate(train_loader):
            data, target = batch_samples['image'], batch_samples['label']
            data, target = data.to(device), target.to(device)
            with autocast():  # autocast should wrap only the forward pass (needs PyTorch >= 1.6)
                pred = model(data)
                loss = criterion(pred, target)
            # 2.1 scale the loss down for gradient accumulation
            regular_loss = loss / accumulation_steps
            # 2.2 back propagation (outside autocast)
            scaler.scale(regular_loss).backward()
            # 2.3 update the parameters every `accumulation_steps` batches
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step(epoch + batch_idx / train_loader_size)
            image_loss = loss.item()
            train_epoch_loss.update(image_loss)
            train_iter_loss.update(image_loss)
            if batch_idx % iter_inter == 0:
                spend_time = time.time() - start_time
                logger.info(
                    '[train] epoch:{} iter:{}/{} {:.2f}% lr:{:.6f} loss:{:.6f} ETA:{}min'
                    .format(
                        epoch, batch_idx, train_loader_size,
                        batch_idx / train_loader_size * 100,
                        optimizer.param_groups[-1]['lr'], train_iter_loss.avg,
                        spend_time /
                        (batch_idx + 1) * train_loader_size // 60 -
                        spend_time // 60))
                train_iter_loss.reset()

        # validation
        valid_epoch_loss, mean_iou = eval(model, valid_loader, criterion,
                                          epoch, logger)
        # save loss and lr
        train_loss_total_epochs.append(train_epoch_loss.avg)
        valid_loss_total_epochs.append(valid_epoch_loss.avg)
        epoch_lr.append(optimizer.param_groups[0]['lr'])

        # save checkpoint
        if (epoch + 1) % save_inter == 0 and epoch > min_inter:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            filename = os.path.join(save_ckpt_dir,
                                    'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)  # PyTorch >= 1.6 saves zipfile-compressed checkpoints that older versions cannot load

        # save best model
        if mean_iou > best_iou:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            filename = os.path.join(save_ckpt_dir, 'checkpoint-best.pth')
            torch.save(state, filename)
            best_iou = mean_iou
            best_model = copy.deepcopy(model)
            logger.info(
                '[save] Best Model saved at epoch:{} ============================='
                .format(epoch))
        # scheduler.step()

    # show loss curve
    if plot:
        x = [i for i in range(epochs)]
        fig = plt.figure(figsize=(12, 4))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(x, smooth(train_loss_total_epochs, 0.6), label='train loss')
        ax.plot(x, smooth(valid_loss_total_epochs, 0.6), label='val loss')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('CrossEntropy', fontsize=15)
        ax.set_title('train curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        ax = fig.add_subplot(1, 2, 2)
        ax.plot(x, epoch_lr, label='Learning Rate')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('Learning Rate', fontsize=15)
        ax.set_title('lr curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        plt.show()

    return best_model, model
Example #9
def train(EPOCHES, BATCH_SIZE, train_image_paths, train_label_paths,
          val_image_paths, val_label_paths, channels, optimizer_name,
          model_path, swa_model_path, addNDVI, loss, early_stop):

    train_loader = get_dataloader(train_image_paths,
                                  train_label_paths,
                                  "train",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8)
    valid_loader = get_dataloader(val_image_paths,
                                  val_label_paths,
                                  "val",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8)

    # define the model, optimizer, and loss function
    # model = smp.UnetPlusPlus(
    #         encoder_name="efficientnet-b7",
    #         encoder_weights="imagenet",
    #         in_channels=channels,
    #         classes=10,
    # )
    model = smp.UnetPlusPlus(
        encoder_name="timm-resnest101e",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=10,
    )
    # model = seg_hrnet_ocr.get_seg_model()
    model.to(DEVICE)
    # model.load_state_dict(torch.load(model_path))
    # SGD optimizer
    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4,
                                    weight_decay=1e-3,
                                    momentum=0.9)
    # otherwise the AdamW optimizer
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-4,
                                      weight_decay=1e-3)
    # cosine annealing with warm restarts
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,  # number of epochs before the first restart
        T_mult=2,  # restart factor: after each restart, T_0 = T_0 * T_mult
        eta_min=1e-5  # minimum learning rate
    )
    # # epoch at which SWA starts
    # swa_start = 80
    # # Stochastic Weight Averaging: better generalization at almost no extra cost
    # swa_model = AveragedModel(model).to(DEVICE)
    # # SWA learning-rate schedule
    # swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

    if loss == "SoftCE_dice":
        # loss: SoftCrossEntropyLoss + DiceLoss
        # Dice loss can ease class imbalance to some extent, but training may be unstable
        DiceLoss_fn = DiceLoss(mode='multiclass')
        # soft cross entropy, i.e. cross entropy with label smoothing, improves generalization
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
        loss_fn = L.JointLoss(first=DiceLoss_fn,
                              second=SoftCrossEntropy_fn,
                              first_weight=0.5,
                              second_weight=0.5).cuda()
    else:
        # loss: SoftCrossEntropyLoss + LovaszLoss
        # Lovasz loss directly optimizes mIoU via the convex Lovasz extension of submodular losses
        LovaszLoss_fn = LovaszLoss(mode='multiclass')
        # soft cross entropy, i.e. cross entropy with label smoothing, improves generalization
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
        loss_fn = L.JointLoss(first=LovaszLoss_fn,
                              second=SoftCrossEntropy_fn,
                              first_weight=0.5,
                              second_weight=0.5).cuda()

    header = r'Epoch/EpochNum | TrainLoss | ValidmIoU | Time(m)'
    raw_line = r'{:5d}/{:8d} | {:9.3f} | {:9.3f} | {:9.2f}'
    print(header)

    # # instantiate a GradScaler before training starts; only needed with autocast
    # scaler = GradScaler()

    # track the best validation mIoU to decide whether to save the current model
    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    # start training
    for epoch in range(1, EPOCHES + 1):
        # print("Start training the {}st epoch...".format(epoch))
        # store the loss of every training batch
        losses = []
        start_time = time.time()
        model.train()
        model.to(DEVICE)
        for batch_index, (image, target) in enumerate(train_loader):
            image, target = image.to(DEVICE), target.to(DEVICE)
            # manually zero the gradients before backpropagation
            optimizer.zero_grad()
            # # use autocast mixed precision to speed up training; wrap the
            # # forward pass (model + loss) in autocast (needs PyTorch >= 1.6)
            # with autocast():
            # forward pass
            output = model(image)
            # compute the loss for this batch
            loss = loss_fn(output, target)
            #     scaler.scale(loss).backward()
            #     scaler.step(optimizer)
            #     scaler.update()
            # backpropagate to compute gradients
            loss.backward()
            # update the weights
            optimizer.step()
            losses.append(loss.item())
        # if epoch > swa_start:
        #     swa_model.update_parameters(model)
        #     swa_scheduler.step()
        # else:
        # cosine annealing learning-rate step
        scheduler.step()
        # compute validation IoU
        val_iou = cal_val_iou(model, valid_loader)
        # print per-class validation IoU
        # print('\t'.join(np.stack(val_iou).mean(0).round(3).astype(str)))
        # record this epoch's train loss, val mIoU and lr
        train_loss_epochs.append(np.array(losses).mean())
        val_mIoU_epochs.append(np.mean(val_iou))
        lr_epochs.append(optimizer.param_groups[0]['lr'])
        # print progress
        print(raw_line.format(epoch, EPOCHES,
                              np.array(losses).mean(), np.mean(val_iou),
                              (time.time() - start_time) / 60),
              end="")
        if best_miou < np.stack(val_iou).mean(0).mean():
            best_miou = np.stack(val_iou).mean(0).mean()
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break
    # # finally, update the BatchNorm statistics of the SWA model
    # torch.optim.swa_utils.update_bn(train_loader, swa_model, device=DEVICE)
    # # compute validation IoU
    # val_iou = cal_val_iou(model, valid_loader)
    # print("swa_model mIoU is {}".format(np.mean(val_iou)))
    # torch.save(swa_model.state_dict(), swa_model_path)
    return train_loss_epochs, val_mIoU_epochs, lr_epochs