def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.fpn = smp.FPN(encoder_name=hparams.encoder_name)

        self.iou = smp.utils.metrics.IoU(activation='sigmoid')
        self.mixed_loss = L.JointLoss(L.BinaryFocalLoss(),
                                      L.BinaryLovaszLoss(), 0.7, 0.3)
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        net = getattr(smp, self.hparams.architecture)

        self.net = net(encoder_name=self.hparams.encoder_name, classes=1)

        self.iou = smp.utils.metrics.IoU(activation='sigmoid')

        self.loss = L.JointLoss(L.BinaryFocalLoss(), L.BinaryLovaszLoss(), 0.7,
                                0.3)
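
For reference, JointLoss is used throughout these examples as a weighted sum of its two criteria. A minimal sketch of that behaviour on dummy data (an assumption based on how JointLoss is called here, not part of the original snippets):

import torch
from pytorch_toolbelt import losses as L

# toy binary-segmentation batch: raw logits and 0/1 targets
logits = torch.randn(4, 1, 64, 64)
targets = (torch.rand(4, 1, 64, 64) > 0.5).float()

mixed_loss = L.JointLoss(L.BinaryFocalLoss(), L.BinaryLovaszLoss(), 0.7, 0.3)

# JointLoss(first, second, w1, w2)(x, y) is expected to equal w1 * first(x, y) + w2 * second(x, y)
total = mixed_loss(logits, targets)
manual = 0.7 * L.BinaryFocalLoss()(logits, targets) + 0.3 * L.BinaryLovaszLoss()(logits, targets)
print(total.item(), manual.item())  # should agree up to floating-point error
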
def main():
    losses = {
        "bce": BCEWithLogitsLoss(),
        # "focal": L.BinaryFocalLoss(),
        # "jaccard": L.BinaryJaccardLoss(),
        # "jaccard_log": L.BinaryJaccardLogLoss(),
        # "dice": L.BinaryDiceLoss(),
        # "dice_log": L.BinaryDiceLogLoss(),
        # "sdice": L.BinarySymmetricDiceLoss(),
        # "sdice_log": L.BinarySymmetricDiceLoss(log_loss=True),

        "bce+lovasz": L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss()),
        # "lovasz": L.BinaryLovaszLoss(),
        # "bce+jaccard": L.JointLoss(BCEWithLogitsLoss(),
        #                            L.BinaryJaccardLoss(), 1, 0.5),

        # "bce+log_jaccard": L.JointLoss(BCEWithLogitsLoss(),
        #                            L.BinaryJaccardLogLoss(), 1, 0.5),

        # "bce+log_dice": L.JointLoss(BCEWithLogitsLoss(),
        #                                L.BinaryDiceLogLoss(), 1, 0.5)

        # "reduced_focal": L.BinaryFocalLoss(reduced=True)
    }

    dx = 0.01
    x_vec = torch.arange(-5, 5, dx).view(-1, 1).expand((-1, 100))

    f, ax = plt.subplots(3, figsize=(16, 16))

    for name, loss in losses.items():
        x_arr = []
        y_arr = []
        target = torch.tensor(1.0).view(1).expand((100))

        for x in x_vec:
            y = loss(x, target).item()

            x_arr.append(float(x[0]))
            y_arr.append(float(y))

        ax[0].plot(x_arr, y_arr, label=name)
        ax[1].plot(x_arr, np.gradient(y_arr, dx))
        ax[2].plot(x_arr, np.gradient(np.gradient(y_arr, dx), dx))

    f.legend()
    f.show()
Example #4
def train_main(cfg):
    '''
    Main training routine.
    :param cfg: configuration object
    :return:
    '''

    # config
    train_cfg = cfg.train_cfg
    dataset_cfg = cfg.dataset_cfg
    model_cfg = cfg.model_cfg
    is_parallel = cfg.setdefault(key='is_parallel', default=False)
    device = cfg.device
    is_online_train = cfg.setdefault(key='is_online_train', default=False)

    # configure the logger
    logging.basicConfig(filename=cfg.logfile,
                        filemode='a',
                        level=logging.INFO,
                        format='%(asctime)s\n%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger()

    #
    # build the datasets
    train_dataset = LandDataset(DIR_list=dataset_cfg.train_dir_list,
                                mode='train',
                                input_channel=dataset_cfg.input_channel,
                                transform=dataset_cfg.train_transform)
    split_val_from_train_ratio = dataset_cfg.setdefault(
        key='split_val_from_train_ratio', default=None)
    if split_val_from_train_ratio is None:
        val_dataset = LandDataset(DIR_list=dataset_cfg.val_dir_list,
                                  mode='val',
                                  input_channel=dataset_cfg.input_channel,
                                  transform=dataset_cfg.val_transform)
    else:
        val_size = int(len(train_dataset) * split_val_from_train_ratio)
        train_size = len(train_dataset) - val_size
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size],
            generator=torch.manual_seed(cfg.random_seed))
        # val_dataset.dataset.transform = dataset_cfg.val_transform  # the val transform still needs to be set here
        print(f"Splitting {split_val_from_train_ratio} of the training set off for validation...")

    # build the dataloaders
    def _init_fn(worker_id):
        np.random.seed(cfg.random_seed)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_cfg.batch_size,
                                  shuffle=True,
                                  num_workers=train_cfg.num_workers,
                                  drop_last=True,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=train_cfg.batch_size,
                                num_workers=train_cfg.num_workers,
                                shuffle=False,
                                drop_last=True,
                                worker_init_fn=_init_fn)

    # build the model
    if train_cfg.is_swa:
        # pass device through map_location; otherwise the checkpoint is first loaded onto cuda:0 and only then moved
        model = torch.load(train_cfg.check_point_file,
                           map_location=device).to(device)
        swa_model = torch.load(train_cfg.check_point_file,
                               map_location=device).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
            swa_model = torch.nn.DataParallel(swa_model)
        swa_n = 0
        parameters = swa_model.parameters()
    else:
        model = build_model(model_cfg).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
        parameters = model.parameters()

    # define the optimizer
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_scheduler_cfg = train_cfg.lr_scheduler_cfg
    if optimizer_cfg.type == 'adam':
        optimizer = optim.Adam(params=parameters,
                               lr=optimizer_cfg.lr,
                               weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'adamw':
        optimizer = optim.AdamW(params=parameters,
                                lr=optimizer_cfg.lr,
                                weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'sgd':
        optimizer = optim.SGD(params=parameters,
                              lr=optimizer_cfg.lr,
                              momentum=optimizer_cfg.momentum,
                              weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'RMS':
        optimizer = optim.RMSprop(params=parameters,
                                  lr=optimizer_cfg.lr,
                                  weight_decay=optimizer_cfg.weight_decay)
    else:
        raise Exception('Unsupported optimizer type!')

    if not lr_scheduler_cfg:
        lr_scheduler = None
    elif lr_scheduler_cfg.policy == 'cos':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            lr_scheduler_cfg.T_0,
            lr_scheduler_cfg.T_mult,
            lr_scheduler_cfg.eta_min,
            last_epoch=lr_scheduler_cfg.last_epoch)
    elif lr_scheduler_cfg.policy == 'LambdaLR':
        import math
        lf = lambda x: (((1 + math.cos(x * math.pi / train_cfg.num_epochs)) / 2
                         )**1.0) * 0.95 + 0.05  # cosine
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lf)
        lr_scheduler.last_epoch = 0
    else:
        lr_scheduler = None

    # define the loss function
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    loss_func = L.JointLoss(first=DiceLoss_fn,
                            second=SoftCrossEntropy_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()
    # loss_cls_func = torch.nn.BCEWithLogitsLoss()

    # create the directory for saving checkpoints
    check_point_dir = '/'.join(model_cfg.check_point_file.split('/')[:-1])
    if not os.path.exists(check_point_dir):  # create it if it does not exist
        os.mkdir(check_point_dir)

    # start training
    auto_save_epoch_list = train_cfg.setdefault(
        key='auto_save_epoch_list',
        default=[5])  # list of epochs at which to save an extra snapshot; defaults to epoch 5
    train_loss_list = []
    val_loss_list = []
    val_loss_min = 999999
    best_epoch = 0
    best_miou = 0
    train_loss = 10  # placeholder initial value
    logger.info('Start training the {} model on {}...'.format(model_cfg.type, device))
    logger.info('Extra info: {}\n'.format(cfg.setdefault(key='info', default='None')))
    for epoch in range(train_cfg.num_epochs):
        print()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        start_time = time.time()
        print(f"正在进行第{epoch}轮训练...")
        logger.info('*' * 10 + f"第{epoch}轮" + '*' * 10)
        #
        # 训练一轮
        if train_cfg.is_swa:  # swa训练方式
            train_loss = train_epoch(swa_model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)
            moving_average(model, swa_model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_dataloader, model, device)
        else:
            train_loss = train_epoch(model, optimizer, lr_scheduler, loss_func,
                                     train_dataloader, epoch, device)
            # train_loss = train_unet3p_epoch(model, optimizer, lr_scheduler, loss_func, train_dataloader, epoch, device)

        #
        # evaluate the model on the validation set
        # val_loss, val_miou = evaluate_unet3p_model(model, val_dataset, loss_func, device,
        #                                     cfg.num_classes, train_cfg.num_workers, batch_size=train_cfg.batch_size)
        if not is_online_train:  # the model only needs to be evaluated when training offline
            val_loss, val_miou = evaluate_model(model, val_dataloader,
                                                loss_func, device,
                                                cfg.num_classes)
        else:
            val_loss = 0
            val_miou = 0

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # save the model
        if not is_online_train:  # the best model only needs to be saved when training offline
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                best_epoch = epoch
                best_miou = val_miou
                if is_parallel:
                    torch.save(model.module, model_cfg.check_point_file)
                else:
                    torch.save(model, model_cfg.check_point_file)

        if epoch in auto_save_epoch_list:  # save an extra snapshot at the designated epochs
            model_file = model_cfg.check_point_file.split(
                '.pth')[0] + '-epoch{}.pth'.format(epoch)
            if is_parallel:
                torch.save(model.module, model_file)
            else:
                torch.save(model, model_file)

        # print intermediate results
        end_time = time.time()
        run_time = int(end_time - start_time)
        m, s = divmod(run_time, 60)
        time_str = "{:02d}m{:02d}s".format(m, s)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        out_str = "Epoch {} finished in {},\ttrain loss={:.6f};\tval loss={:.4f}, mIoU={:.6f}\tbest epoch so far: {}, mIoU={:.6f}" \
            .format(epoch, time_str, train_loss, val_loss, val_miou, best_epoch, best_miou)
        # out_str = "Epoch {} finished in {},\ntrain segm_loss={:.6f}, cls_loss={:.6f}\nval segm_loss={:.4f}, cls_loss={:.4f}, mIoU={:.6f}\nbest epoch so far: {}, mIoU={:.6f}" \
        #     .format(epoch, time_str, train_loss, train_cls_loss, val_loss, val_cls_loss, val_miou, best_epoch,
        #             best_miou)
        print(out_str)
        logger.info(out_str + '\n')
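
The moving_average and bn_update helpers used in the SWA branch of train_main are not part of this snippet. A rough sketch of what they typically look like (an assumption based on their call sites and on common SWA implementations, not the author's code):

import torch

def moving_average(net1, net2, alpha=1.0):
    # net1 accumulates the running average of net2's weights (alpha = 1 / (n_averaged + 1))
    for p1, p2 in zip(net1.parameters(), net2.parameters()):
        p1.data *= (1.0 - alpha)
        p1.data += p2.data * alpha

def bn_update(loader, model, device):
    # simplified: refresh BatchNorm running statistics of the averaged model with one pass over the data
    model.train()
    with torch.no_grad():
        for batch in loader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(images.to(device))
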
Example #5
model = smp.UnetPlusPlus(encoder_name="efficientnet-b6",
                         encoder_weights='imagenet',
                         in_channels=3,
                         classes=10)
model.train()
model.to(DEVICE)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

# loss
# DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss
DiceLoss_fn = DiceLoss(mode='multiclass')
SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
loss_fn = L.JointLoss(first=DiceLoss_fn,
                      second=SoftCrossEntropy_fn,
                      first_weight=0.5,
                      second_weight=0.5).to(DEVICE)

best_iou = 0

for epoch in range(1, EPOCHES + 1):
    losses = []
    start_time = time.time()

    for image, target in tqdm(loader):
        image, target = image.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
def train_net(param, model, train_data, valid_data, plot=False, device='cuda'):
    # unpack parameters
    model_name      = param['model_name']
    epochs          = param['epochs']
    batch_size      = param['batch_size']
    lr              = param['lr']
    gamma           = param['gamma']
    step_size       = param['step_size']
    momentum        = param['momentum']
    weight_decay    = param['weight_decay']

    disp_inter      = param['disp_inter']
    save_inter      = param['save_inter']
    min_inter       = param['min_inter']
    iter_inter      = param['iter_inter']

    save_log_dir    = param['save_log_dir']
    save_ckpt_dir   = param['save_ckpt_dir']
    load_ckpt_dir   = param['load_ckpt_dir']

    # automatic mixed precision
    scaler = GradScaler()

    # network / data parameters
    train_data_size = len(train_data)
    valid_data_size = len(valid_data)
    c, y, x = train_data[0]['image'].shape
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=1)
    valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False, num_workers=1)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=weight_decay)  # note: hard-coded lr; param['lr'] is unused here
    #optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=momentum, weight_decay=weight_decay)
    #scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-5, last_epoch=-1)
    #criterion = nn.CrossEntropyLoss(reduction='mean').to(device)
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    criterion = L.JointLoss(first=DiceLoss_fn, second=SoftCrossEntropy_fn,
                            first_weight=0.5, second_weight=0.5).cuda()
    logger = inial_logger(os.path.join(save_log_dir, time.strftime("%m-%d %H:%M:%S", time.localtime()) +'_'+model_name+ '.log'))

    # main loop
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    train_loader_size = len(train_loader)
    valid_loader_size = len(valid_loader)
    best_iou = 0
    best_epoch = 0
    best_mode = copy.deepcopy(model)
    epoch_start = 0
    if load_ckpt_dir is not None:
        ckpt = torch.load(load_ckpt_dir)
        epoch_start = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    logger.info('Total Epoch:{} Image_size:({}, {}) Training num:{}  Validation num:{}'.format(epochs, x, y, train_data_size, valid_data_size))
    #
    for epoch in range(epoch_start, epochs):
        epoch_start = time.time()
        # training phase
        model.train()
        train_epoch_loss = AverageMeter()
        train_iter_loss = AverageMeter()
        for batch_idx, batch_samples in enumerate(train_loader):
            data, target = batch_samples['image'], batch_samples['label']
            data, target = Variable(data.to(device)), Variable(target.to(device))
            with autocast(): #need pytorch>1.6
                pred = model(data)
                loss = criterion(pred, target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            scheduler.step(epoch + batch_idx / train_loader_size) 
            image_loss = loss.item()
            train_epoch_loss.update(image_loss)
            train_iter_loss.update(image_loss)
            if batch_idx % iter_inter == 0:
                spend_time = time.time() - epoch_start
                logger.info('[train] epoch:{} iter:{}/{} {:.2f}% lr:{:.6f} loss:{:.6f} ETA:{}min'.format(
                    epoch, batch_idx, train_loader_size, batch_idx/train_loader_size*100,
                    optimizer.param_groups[-1]['lr'],
                    train_iter_loss.avg, spend_time / (batch_idx + 1) * train_loader_size // 60 - spend_time // 60))
                train_iter_loss.reset()

        # validation phase
        model.eval()
        valid_epoch_loss = AverageMeter()
        valid_iter_loss = AverageMeter()
        iou = IOUMetric(10)
        with torch.no_grad():
            for batch_idx, batch_samples in enumerate(valid_loader):
                data, target = batch_samples['image'], batch_samples['label']
                data, target = Variable(data.to(device)), Variable(target.to(device))
                pred = model(data)
                loss = criterion(pred, target)
                pred = pred.cpu().data.numpy()
                pred = np.argmax(pred, axis=1)
                iou.add_batch(pred, target.cpu().data.numpy())
                #
                image_loss = loss.item()
                valid_epoch_loss.update(image_loss)
                valid_iter_loss.update(image_loss)
                # if batch_idx % iter_inter == 0:
                #     logger.info('[val] epoch:{} iter:{}/{} {:.2f}% loss:{:.6f}'.format(
                #         epoch, batch_idx, valid_loader_size, batch_idx / valid_loader_size * 100, valid_iter_loss.avg))
            val_loss = valid_iter_loss.avg
            acc, acc_cls, iu, mean_iu, fwavacc = iou.evaluate()
            logger.info('[val] epoch:{} miou:{:.2f}'.format(epoch, mean_iu))
                

        # record loss and lr
        train_loss_total_epochs.append(train_epoch_loss.avg)
        valid_loss_total_epochs.append(valid_epoch_loss.avg)
        epoch_lr.append(optimizer.param_groups[0]['lr'])
        # save a checkpoint
        if epoch % save_inter == 0 and epoch > min_inter:
            state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            filename = os.path.join(save_ckpt_dir, 'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)  # PyTorch >= 1.6 saves in a zipfile format that older versions cannot load
        # save the best model
        if mean_iu > best_iou:  # train_loss_per_epoch valid_loss_per_epoch
            state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            filename = os.path.join(save_ckpt_dir, 'checkpoint-best.pth')
            torch.save(state, filename)
            best_iou = mean_iu
            best_mode = copy.deepcopy(model)
            logger.info('[save] Best Model saved at epoch:{} ============================='.format(epoch))
        # scheduler.step()
        # display loss
    # training loss curves
    if plot:
        x = [i for i in range(epochs)]
        fig = plt.figure(figsize=(12, 4))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(x, smooth(train_loss_total_epochs, 0.6), label='train loss')
        ax.plot(x, smooth(valid_loss_total_epochs, 0.6), label='val loss')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('CrossEntropy', fontsize=15)
        ax.set_title('train curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        ax = fig.add_subplot(1, 2, 2)
        ax.plot(x, epoch_lr,  label='Learning Rate')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('Learning Rate', fontsize=15)
        ax.set_title('lr curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        plt.show()
            
    return best_mode, model
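
The smooth function used when plotting the loss curves above is not defined in this snippet either; a plausible exponential-moving-average version (an assumption, not the author's code):

def smooth(values, weight=0.6):
    # exponential moving average; weight in [0, 1), larger = smoother curve
    smoothed, last = [], values[0]
    for v in values:
        last = last * weight + (1 - weight) * v
        smoothed.append(last)
    return smoothed
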
Example #7
    file_path = "dataset/train_data.pkl"
    with open(file_path, 'rb') as fr:
        train_data = pickle.load(fr)
    train_all = train_data['data']
    kind = train_data['label']
    n_splits = 5
    sk = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=20)
    fold = 0
    for train_idx, test_idx in sk.split(train_all, kind):
        # re-initialize a fresh model for every fold
        #
        #weights_class = torch.FloatTensor([0.1, 0.9])
        # criterion = torch.nn.CrossEntropyLoss(weight=weights_class)
        if use_multi_loss:
            criterion = L.JointLoss(first=torch.nn.CrossEntropyLoss().cuda(),
                                    second=LabelSmoothSoftmaxCE().cuda(),
                                    first_weight=0.5,
                                    second_weight=0.5)
        #
        tokenizer = BertTokenizer.from_pretrained(pretrained)
        model = BertModel.from_pretrained(pretrained)
        config = BertConfig.from_pretrained(pretrained)

        albertBertClassifier = AlbertClassfier(model, config, 20)
        device = torch.device("cuda") if torch.cuda.is_available() else 'cpu'
        albertBertClassifier = albertBertClassifier.to(device)
        if n_gpu > 1:
            albertBertClassifier = torch.nn.DataParallel(albertBertClassifier)
        if use_optm == 'sgd':
            optimizer = torch.optim.SGD(albertBertClassifier.parameters(),
                                        lr=init_lr,
                                        momentum=0.9,
Example #8
def train(config, model, train_data, valid_data, plot=False, device='cuda'):
    """

    :config config:
    :config model:
    :config train_data:
    :config valid_data:
    :config plot:
    :config device:
    :return:
    """
    # unpack parameters
    model_name = config['model_name']
    epochs = config['epochs']
    batch_size = config['batch_size']

    class_weights = config['class_weights']
    disp_inter = config['disp_inter']
    save_inter = config['save_inter']
    min_inter = config['min_inter']
    iter_inter = config['iter_inter']

    save_log_dir = config['save_log_dir']
    save_ckpt_dir = config['save_ckpt_dir']
    load_ckpt_dir = config['load_ckpt_dir']
    accumulation_steps = config['accumulation_steps']
    # automatic mixed precision
    scaler = GradScaler()
    # network / data parameters
    train_data_size = len(train_data)
    valid_data_size = len(valid_data)
    c, y, x = train_data[0]['image'].shape
    train_loader = DataLoader(dataset=train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    valid_loader = DataLoader(dataset=valid_data,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=1)
    #
    if config['optimizer'].lower() == 'adamw':
        optimizer = optim.AdamW(model.parameters(),
                                lr=config['lr'],
                                weight_decay=config['weight_decay'])
    elif config['optimizer'].lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=config['lr'],
                              momentum=config['momentum'],
                              weight_decay=config['weight_decay'])
    # SWA
    if config['swa']:
        swa_opt = SWA(optimizer,
                      swa_start=config['swa_start'],
                      swa_freq=config['swa_freq'],
                      swa_lr=config['swa_lr'])

    # warm_up_with_multistep_lr = lambda \
    #     epoch: epoch / warmup_epochs if epoch <= warmup_epochs else gamma ** len(
    #     [m for m in milestones if m <= epoch])
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=3, T_mult=2, eta_min=1e-5, last_epoch=-1)
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=config['factor'],
            patience=config['patience'],
            verbose=1,
            min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'NoamLR':
        scheduler = NoamLR(optimizer, warmup_steps=config['warmup_steps'])
    else:
        raise NotImplementedError

    # criterion = nn.CrossEntropyLoss(reduction='mean').to(device)
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    CrossEntropyLoss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    Lovasz_fn = LovaszLoss(mode='multiclass')
    criterion = L.JointLoss(first=DiceLoss_fn,
                            second=CrossEntropyLoss_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()

    logger = initial_logger(
        os.path.join(
            save_log_dir,
            time.strftime("%m-%d %H:%M:%S", time.localtime()) + '_' +
            model_name + '.log'))

    # main loop
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    train_loader_size = len(train_loader)
    valid_loader_size = len(valid_loader)
    best_iou = 0
    best_epoch = 0
    best_mode = copy.deepcopy(model)
    start_epoch = 0
    if load_ckpt_dir is not None:
        ckpt = torch.load(load_ckpt_dir)
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    logger.info(
        'Total Epoch:{} Image_size:({}, {}) Training num:{}  Validation num:{}'
        .format(epochs, x, y, train_data_size, valid_data_size))
    # execute train
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        # training phase
        model.train()
        train_epoch_loss = AverageMeter()
        train_iter_loss = AverageMeter()
        for batch_idx, batch_samples in enumerate(train_loader):
            data, target = batch_samples['image'], batch_samples['label']
            data, target = Variable(data.to(device)), Variable(
                target.to(device))
            with autocast():  # need pytorch>1.6
                pred = model(data)
                loss = criterion(pred, target)
                # 2.1 loss regularization
                regular_loss = loss / accumulation_steps
                # 2.2 back propagation
                scaler.scale(regular_loss).backward()
                # 2.3 update parameters of net
                if (batch_idx + 1) % accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    scheduler.step(epoch + batch_idx / train_loader_size)
            #     scaler.scale(loss).backward()
            #     scaler.step(optimizer)
            #     scaler.update()
            #     optimizer.zero_grad()
            # scheduler.step(epoch + batch_idx / train_loader_size)
            image_loss = loss.item()
            train_epoch_loss.update(image_loss)
            train_iter_loss.update(image_loss)
            if batch_idx % iter_inter == 0:
                spend_time = time.time() - start_time
                logger.info(
                    '[train] epoch:{} iter:{}/{} {:.2f}% lr:{:.6f} loss:{:.6f} ETA:{}min'
                    .format(
                        epoch, batch_idx, train_loader_size,
                        batch_idx / train_loader_size * 100,
                        optimizer.param_groups[-1]['lr'], train_iter_loss.avg,
                        spend_time /
                        (batch_idx + 1) * train_loader_size // 60 -
                        spend_time // 60))
                train_iter_loss.reset()

        # validation
        valid_epoch_loss, mean_iou = eval(model, valid_loader, criterion,
                                          epoch, logger)
        # save loss and lr
        train_loss_total_epochs.append(train_epoch_loss.avg)
        valid_loss_total_epochs.append(valid_epoch_loss.avg)
        epoch_lr.append(optimizer.param_groups[0]['lr'])

        # save checkpoint
        if (epoch + 1) % save_inter == 0 and epoch > min_inter:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            filename = os.path.join(save_ckpt_dir,
                                    'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)  # PyTorch >= 1.6 saves in a zipfile format that older versions cannot load

        # save best model
        if mean_iou > best_iou:  # train_loss_per_epoch valid_loss_per_epoch
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            filename = os.path.join(save_ckpt_dir, 'checkpoint-best.pth')
            torch.save(state, filename)
            best_iou = mean_iou
            best_mode = copy.deepcopy(model)
            logger.info(
                '[save] Best Model saved at epoch:{} ============================='
                .format(epoch))
        # scheduler.step()

    # show loss curve
    if plot:
        x = [i for i in range(epochs)]
        fig = plt.figure(figsize=(12, 4))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(x, smooth(train_loss_total_epochs, 0.6), label='train loss')
        ax.plot(x, smooth(valid_loss_total_epochs, 0.6), label='val loss')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('CrossEntropy', fontsize=15)
        ax.set_title('train curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        ax = fig.add_subplot(1, 2, 2)
        ax.plot(x, epoch_lr, label='Learning Rate')
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('Learning Rate', fontsize=15)
        ax.set_title('lr curve', fontsize=15)
        ax.grid(True)
        plt.legend(loc='upper right', fontsize=15)
        plt.show()

    return best_mode, model
def joint(first, second, first_weight=1.0, second_weight=1.0):
    return L.JointLoss(first, second, first_weight, second_weight)
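
A usage sketch for this helper, mirroring the Dice + soft cross-entropy combination used in the other examples (it assumes DiceLoss and SoftCrossEntropyLoss are available from pytorch_toolbelt.losses; the 0.5/0.5 weights and tensor shapes are only illustrative):

import torch
from pytorch_toolbelt import losses as L

criterion = joint(L.DiceLoss(mode='multiclass'),
                  L.SoftCrossEntropyLoss(smooth_factor=0.1),
                  first_weight=0.5, second_weight=0.5)

logits = torch.randn(2, 10, 64, 64)          # (B, C, H, W) raw scores
targets = torch.randint(0, 10, (2, 64, 64))  # (B, H, W) class indices
print(criterion(logits, targets).item())
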
Example #10
def mask_rcnn_loss(pred_mask_logits, pred_boundary_logits, instances,
                   pred_mask_bo_logits, pred_boundary_logits_bo):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1
            correspondence with the pred_mask_logits. The ground-truth labels (class, box, mask,
            ...) associated with each instance are stored in fields.

    Returns:
        mask_loss, bo_mask_loss, bound_loss, bo_bound_loss (Tensors): scalar loss terms for the
        predicted masks, the bo masks, the mask boundaries, and the bo boundaries, respectively.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(
        3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    gt_bo_masks = []
    gt_boundary_bo = []
    gt_boundary = []

    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(
                dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        #print('mask_head.py L59 instances_per_image.gt_masks:', instances_per_image.gt_masks)
        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor,
            mask_side_len).to(device=pred_mask_logits.device)
        # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
        gt_masks.append(gt_masks_per_image)

        boundary_ls = []
        for mask in gt_masks_per_image:
            mask_b = mask.data.cpu().numpy()
            boundary, inside_mask, weight = get_instances_contour_interior(
                mask_b)
            boundary = torch.from_numpy(boundary).to(
                device=mask.device).unsqueeze(0)

            boundary_ls.append(boundary)

        gt_boundary.append(cat(boundary_ls, dim=0))

        gt_bo_masks_per_image = instances_per_image.gt_bo_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor,
            mask_side_len).to(device=pred_mask_logits.device)
        # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
        gt_bo_masks.append(gt_bo_masks_per_image)

        boundary_ls_bo = []
        for mask_bo in gt_bo_masks_per_image:
            mask_b_bo = mask_bo.data.cpu().numpy()
            boundary_bo, inside_mask_bo, weight_bo = get_instances_contour_interior(
                mask_b_bo)
            boundary_bo = torch.from_numpy(boundary_bo).to(
                device=mask_bo.device).unsqueeze(0)

            boundary_ls_bo.append(boundary_bo)

        gt_boundary_bo.append(cat(boundary_ls_bo, dim=0))

    if len(gt_masks) == 0:
        # no foreground instances: return zero losses with the same arity as the normal return
        return (pred_mask_logits.sum() * 0, pred_mask_bo_logits.sum() * 0,
                pred_boundary_logits.sum() * 0, pred_boundary_logits_bo.sum() * 0)

    gt_masks = cat(gt_masks, dim=0)
    gt_bo_masks = cat(gt_bo_masks, dim=0)

    gt_boundary_bo = cat(gt_boundary_bo, dim=0)
    gt_boundary = cat(gt_boundary, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits_gt = pred_mask_logits[:, 0]
        pred_bo_mask_logits = pred_mask_bo_logits[:, 0]
        pred_boundary_logits_bo = pred_boundary_logits_bo[:, 0]
        pred_boundary_logits = pred_boundary_logits[:, 0]
    else:
        indices = torch.arange(total_num_masks)
        gt_classes = cat(gt_classes, dim=0)
        pred_mask_logits_gt = pred_mask_logits[indices, gt_classes]
        # note: the bo-mask and boundary logits would need analogous per-class indexing here

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        gt_masks_bool = gt_masks > 0.5

    mask_incorrect = (pred_mask_logits_gt > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() /
                         max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0)
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(
        num_positive, 1.0)

    indexs2 = torch.nonzero(
        torch.sum(gt_bo_masks.to(dtype=torch.float32), (1, 2)))

    new_gt_bo_masks1 = gt_bo_masks[indexs2, :, :].squeeze()
    new_gt_bo_masks2 = gt_bo_masks[:indexs2.shape[0]]
    if new_gt_bo_masks1.shape != new_gt_bo_masks2.shape:
        new_gt_bo_masks1 = new_gt_bo_masks1.unsqueeze(0)

    new_gt_bo_masks = torch.cat((new_gt_bo_masks1, new_gt_bo_masks2), 0)

    pred_bo_mask_logits1 = pred_bo_mask_logits[indexs2, :, :].squeeze()
    pred_bo_mask_logits2 = pred_bo_mask_logits[:indexs2.shape[0]]
    if pred_bo_mask_logits1.shape != pred_bo_mask_logits2.shape:
        pred_bo_mask_logits1 = pred_bo_mask_logits1.unsqueeze(0)

    new_pred_bo_mask_logits = torch.cat(
        (pred_bo_mask_logits1, pred_bo_mask_logits2), 0)

    new_gt_bo_bounds1 = gt_boundary_bo[indexs2, :, :].squeeze()
    new_gt_bo_bounds2 = gt_boundary_bo[:indexs2.shape[0]]
    if new_gt_bo_bounds1.shape != new_gt_bo_bounds2.shape:
        new_gt_bo_bounds1 = new_gt_bo_bounds1.unsqueeze(0)

    new_gt_bo_bounds = torch.cat((new_gt_bo_bounds1, new_gt_bo_bounds2), 0)

    pred_bo_bounds_logits1 = pred_boundary_logits_bo[indexs2, :, :].squeeze()
    pred_bo_bounds_logits2 = pred_boundary_logits_bo[:indexs2.shape[0]]
    if pred_bo_bounds_logits1.shape != pred_bo_bounds_logits2.shape:
        pred_bo_bounds_logits1 = pred_bo_bounds_logits1.unsqueeze(0)

    new_pred_bo_bounds_logits = torch.cat(
        (pred_bo_bounds_logits1, pred_bo_bounds_logits2), 0)

    mask_loss = F.binary_cross_entropy_with_logits(
        pred_mask_logits_gt,
        gt_masks.to(dtype=torch.float32),
        reduction="mean")

    bound_loss = L.JointLoss(L.BceLoss(),
                             L.BceLoss())(pred_boundary_logits.unsqueeze(1),
                                          gt_boundary.to(dtype=torch.float32))

    if new_gt_bo_masks.shape[0] > 0:
        bo_mask_loss = F.binary_cross_entropy_with_logits(
            new_pred_bo_mask_logits,
            new_gt_bo_masks.to(dtype=torch.float32),
            reduction="mean")
    else:
        bo_mask_loss = torch.tensor(0.0).cuda(mask_loss.get_device())

    if new_gt_bo_bounds.shape[0] > 0:
        bo_bound_loss = L.JointLoss(L.BceLoss(), L.BceLoss())(
            new_pred_bo_bounds_logits.unsqueeze(1),
            new_gt_bo_bounds.to(dtype=torch.float32))
    else:
        bo_bound_loss = torch.tensor(0.0).cuda(mask_loss.get_device())

    return mask_loss, bo_mask_loss, bound_loss, bo_bound_loss
Example #11
def train(EPOCHES, BATCH_SIZE, train_image_paths, train_label_paths,
          val_image_paths, val_label_paths, channels, optimizer_name,
          model_path, swa_model_path, addNDVI, loss, early_stop):

    train_loader = get_dataloader(train_image_paths,
                                  train_label_paths,
                                  "train",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8)
    valid_loader = get_dataloader(val_image_paths,
                                  val_label_paths,
                                  "val",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8)

    # define the model, optimizer, and loss functions
    # model = smp.UnetPlusPlus(
    #         encoder_name="efficientnet-b7",
    #         encoder_weights="imagenet",
    #         in_channels=channels,
    #         classes=10,
    # )
    model = smp.UnetPlusPlus(
        encoder_name="timm-resnest101e",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=10,
    )
    # model = seg_hrnet_ocr.get_seg_model()
    model.to(DEVICE)
    # model.load_state_dict(torch.load(model_path))
    # SGD optimizer
    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4,
                                    weight_decay=1e-3,
                                    momentum=0.9)
    # AdamW optimizer
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-4,
                                      weight_decay=1e-3)
    # cosine annealing with warm restarts
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,  # number of epochs before the first restart
        T_mult=2,  # restart-period multiplier: after each restart, T_0 = T_0 * T_mult
        eta_min=1e-5  # minimum learning rate
    )
    # # epoch at which to start SWA
    # swa_start = 80
    # # stochastic weight averaging (SWA): better generalization at almost no extra cost
    # swa_model = AveragedModel(model).to(DEVICE)
    # # SWA learning-rate schedule
    # swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

    if (loss == "SoftCE_dice"):
        # 损失函数采用SoftCrossEntropyLoss+DiceLoss
        # diceloss在一定程度上可以缓解类别不平衡,但是训练容易不稳定
        DiceLoss_fn = DiceLoss(mode='multiclass')
        # 软交叉熵,即使用了标签平滑的交叉熵,会增加泛化性
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
        loss_fn = L.JointLoss(first=DiceLoss_fn,
                              second=SoftCrossEntropy_fn,
                              first_weight=0.5,
                              second_weight=0.5).cuda()
    else:
        # 损失函数采用SoftCrossEntropyLoss+LovaszLoss
        # LovaszLoss是对基于子模块损失凸Lovasz扩展的mIoU损失的直接优化
        LovaszLoss_fn = LovaszLoss(mode='multiclass')
        # 软交叉熵,即使用了标签平滑的交叉熵,会增加泛化性
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
        loss_fn = L.JointLoss(first=LovaszLoss_fn,
                              second=SoftCrossEntropy_fn,
                              first_weight=0.5,
                              second_weight=0.5).cuda()

    header = r'Epoch/EpochNum | TrainLoss | ValidmIoU | Time(m)'
    raw_line = r'{:5d}/{:8d} | {:9.3f} | {:9.3f} | {:9.2f}'
    print(header)

    #    # instantiate a GradScaler before training starts; only needed together with autocast
    #    scaler = GradScaler()

    # track the best validation mIoU so far, to decide whether to save the current model
    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    # start training
    for epoch in range(1, EPOCHES + 1):
        # print("Start training the {}st epoch...".format(epoch))
        # store the loss of every training batch
        losses = []
        start_time = time.time()
        model.train()
        model.to(DEVICE)
        for batch_index, (image, target) in enumerate(train_loader):
            image, target = image.to(DEVICE), target.to(DEVICE)
            # gradients must be zeroed manually before backpropagation
            optimizer.zero_grad()
            #            # use autocast mixed precision to speed up training; wrap the forward pass (model + loss) in autocast
            #            with autocast():  # needs pytorch>1.6
            # forward pass
            output = model(image)
            # compute this batch's loss
            loss = loss_fn(output, target)
            #                scaler.scale(loss).backward()
            #                scaler.step(optimizer)
            #                scaler.update()
            # backpropagate to compute gradients
            loss.backward()
            # update the weights
            optimizer.step()
            losses.append(loss.item())
        # if epoch > swa_start:
        #     swa_model.update_parameters(model)
        #     swa_scheduler.step()
        # else:
        # step the cosine-annealing scheduler
        scheduler.step()
        # compute the validation IoU
        val_iou = cal_val_iou(model, valid_loader)
        # print the per-class validation IoU
        # print('\t'.join(np.stack(val_iou).mean(0).round(3).astype(str)))
        # record this epoch's train_loss, val_mIoU and lr
        train_loss_epochs.append(np.array(losses).mean())
        val_mIoU_epochs.append(np.mean(val_iou))
        lr_epochs.append(optimizer.param_groups[0]['lr'])
        # report progress
        print(raw_line.format(epoch, EPOCHES,
                              np.array(losses).mean(), np.mean(val_iou),
                              (time.time() - start_time) / 60),
              end="")
        if best_miou < np.stack(val_iou).mean(0).mean():
            best_miou = np.stack(val_iou).mean(0).mean()
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break
    # # finally update the BN statistics of the SWA model
    # torch.optim.swa_utils.update_bn(train_loader, swa_model, device=DEVICE)
    # # compute the validation IoU
    # val_iou = cal_val_iou(model, valid_loader)
    # print("swa_model's mIoU is {}".format(np.mean(val_iou)))
    # torch.save(swa_model.state_dict(), swa_model_path)
    return train_loss_epochs, val_mIoU_epochs, lr_epochs
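
cal_val_iou is not included in this snippet. A rough sketch of such a per-class IoU evaluation, matching how its result is aggregated above with np.stack(...).mean(0) (an assumption, not the author's implementation; it reuses the module-level DEVICE):

import numpy as np
import torch

@torch.no_grad()
def cal_val_iou(model, loader, num_classes=10):
    # returns one per-class IoU array per validation batch
    model.eval()
    ious = []
    for image, target in loader:
        image, target = image.to(DEVICE), target.to(DEVICE)
        pred = model(image).argmax(dim=1)
        batch_iou = []
        for c in range(num_classes):
            inter = ((pred == c) & (target == c)).sum().item()
            union = ((pred == c) | (target == c)).sum().item()
            batch_iou.append(inter / union if union > 0 else 1.0)
        ious.append(np.array(batch_iou))
    return ious
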