Example #1
def train(train_loader, model, optimizer, epoch):
    ## step.1 set up the metric meters, updated throughout training
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)

    ## step.2 inner loop over batches within the epoch
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(),
                                   size=(h, w),
                                   mode='bilinear',
                                   align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        output, main_loss, aux_loss = model(input, target)  # forward pass: prediction and losses
        if not args.multiprocessing_distributed:
            main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
        loss = main_loss + args.aux_weight * aux_loss

        ## step.3 backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = input.size(0)  # batch size on this GPU
        if args.multiprocessing_distributed:
            # not considering ignore pixels
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        ## step.4 update the evaluation metrics
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        ## step.5 adjust the learning rate (poly schedule)
        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr  # backbone groups keep the base rate
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10  # prediction-head groups get 10x
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))  # estimate remaining time

        ## step.6 periodic logging
        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            main_loss_meter=main_loss_meter,
                            aux_loss_meter=aux_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              current_iter)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              current_iter)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (target + 1e-10)),
                              current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
            .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
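
All of these examples lean on an AverageMeter helper to track running statistics. A minimal sketch, assuming the conventional implementation (current value, weighted sum, count, running average) popularized by the PyTorch ImageNet example — it also works elementwise when fed numpy arrays, as the intersection/union meters require:

class AverageMeter(object):
    """Tracks the latest value and the running (weighted) average of a metric."""
    def __init__(self):
        self.val = 0    # most recent value
        self.sum = 0    # weighted sum of all updates
        self.count = 0  # total weight, e.g. number of samples seen
        self.avg = 0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

With meter.update(loss.item(), n) the weight n is the batch size, so meter.avg is a per-sample average rather than a per-batch one.
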
Example #2
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    fscore_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # data
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # loss
        logits, output, loss = model(input, target)
        if len(args.train_gpu) > 1:
            loss = torch.mean(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n = input.size(0)
        loss_meter.update(loss.item(), n)

        # metric
        accuracy, precision, recall, f_score = accuracy_metrics(
            output.detach().cpu().numpy(),
            target.detach().cpu().numpy(),
            threshold=args.binary_threshold,
            training=True)
        accuracy_meter.update(accuracy)
        fscore_meter.update(f_score)

        batch_time.update(time.time() - end)
        end = time.time()

        # learning rate
        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        writer.add_scalar('learning_rate', current_lr, current_iter)
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (i + 1) % args.print_freq == 0:
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy_meter.val:.4f} '
                        'F-score {fscore_meter.val:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            loss_meter=loss_meter,
                            accuracy_meter=accuracy_meter,
                            fscore_meter=fscore_meter))

        writer.add_scalar('loss_train_batch', loss_meter.val, current_iter)
        writer.add_scalar('accuracy_train_batch', accuracy_meter.val,
                          current_iter)
        writer.add_scalar('fscore_train_batch', fscore_meter.val, current_iter)

    mAcc = accuracy_meter.avg
    mFscore = fscore_meter.avg
    logger.info(
        'Train result at epoch [{}/{}]: mAcc/mFscore {:.4f}/{:.4f}.'.format(
            epoch + 1, args.epochs, mAcc, mFscore))
    return loss_meter.avg, mAcc, mFscore
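
Examples #1 and #2 use a poly_learning_rate that returns the scheduled rate, which the caller then writes into the optimizer's param groups. A minimal sketch under the standard polynomial-decay assumption lr = base_lr * (1 - curr_iter / max_iter) ** power:

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Polynomial decay of the learning rate from base_lr toward zero."""
    return base_lr * (1 - float(curr_iter) / max_iter) ** power
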
Example #3
File: train.py Project: zots0127/ASGNet
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    vis_key = 0
    print('Warmup: {}'.format(args.warmup))
    for i, (input, target, s_input, s_mask, s_init_seed, subcls) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = epoch * len(train_loader) + i + 1
        index_split = -1
        if args.base_lr > 1e-6:
            poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power, index_split=index_split, warmup=args.warmup, warmup_step=len(train_loader)//2)

        s_input = s_input.cuda(non_blocking=True)
        s_mask = s_mask.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        s_init_seed = s_init_seed.cuda(non_blocking=True)
        
        output, main_loss, aux_loss = model(s_x=s_input, s_y=s_mask, x=input, y=target, s_seed=s_init_seed)

        if not args.multiprocessing_distributed:
            main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
        loss = main_loss + args.aux_weight * aux_loss
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        n = input.size(0)
        if args.multiprocessing_distributed:
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n 
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '                        
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('aux_loss_train_batch', aux_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    if main_process():
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
        for i in range(args.classes):
            logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))        
    return main_loss_meter.avg, aux_loss_meter.avg, mIoU, mAcc, allAcc
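
Example #3 calls a different poly_learning_rate variant that takes the optimizer itself and supports warmup and an index_split. A sketch of what that signature could look like — the linear warmup ramp and the 10x multiplier for the later param groups are assumptions, mirroring how the other examples scale their head groups:

def poly_learning_rate(optimizer, base_lr, curr_iter, max_iter, power=0.9,
                       index_split=-1, scale_lr=10.0, warmup=False, warmup_step=500):
    """Poly LR schedule applied in place to optimizer.param_groups."""
    if warmup and curr_iter < warmup_step:
        # linear ramp-up toward base_lr during warmup (assumed shape)
        lr = base_lr * (0.1 + 0.9 * curr_iter / warmup_step)
    else:
        lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
    for index, param_group in enumerate(optimizer.param_groups):
        # groups up to index_split (e.g. the backbone) get the base rate,
        # later groups (e.g. the head) get a scaled rate
        param_group['lr'] = lr if index <= index_split else lr * scale_lr
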
Example #4
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    reg_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    final_loss_meter = AverageMeter()
    model.train()

    # commented-out option: freeze and fix the backbone and heads
    '''for p in model.module.layer0.parameters():
     p.requires_grad = False
    for p in model.module.layer1.parameters():
     p.requires_grad = False
    for p in model.module.layer2.parameters():
     p.requires_grad = False
    for p in model.module.layer3.parameters():
     p.requires_grad = False
    for p in model.module.layer4.parameters():
     p.requires_grad = False
    for p in model.module.ppm.parameters():
     p.requires_grad = False
    for p in model.module.cls.parameters():
     p.requires_grad = False
    for p in model.module.aux.parameters():
     p.requires_grad = False

    model.module.layer0.eval()
    model.module.layer1.eval()
    model.module.layer2.eval()
    model.module.layer3.eval()
    model.module.ppm.eval()
    model.module.reg.eval()
    model.module.aux.eval()'''
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target, feat, featidx) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode='bilinear', align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        feat = feat.cuda(non_blocking=True)
        featidx = featidx.cuda(non_blocking=True)
        output, main_loss, aux_loss, reg_loss, final_loss = model(input, target, feat, featidx)
        if not args.multiprocessing_distributed:
            main_loss, aux_loss, reg_loss, final_loss = torch.mean(main_loss), torch.mean(aux_loss), torch.mean(reg_loss), torch.mean(final_loss)
        loss = main_loss + args.aux_weight * aux_loss + reg_loss + final_loss
        optimizer.zero_grad()
        # (commented-out apex mixed-precision path removed; see Example #6 for the live version)
        loss.backward()
        optimizer.step()
        n = input.size(0)
        if args.multiprocessing_distributed:
            # not considering ignore pixels
            main_loss, aux_loss, reg_loss, final_loss, loss = main_loss.detach() * n, aux_loss * n, reg_loss * n, final_loss * n, loss * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(reg_loss), dist.all_reduce(final_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, reg_loss, final_loss, loss = main_loss / n, aux_loss / n, reg_loss / n, final_loss / n, loss / n
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        reg_loss_meter.update(reg_loss.item(), n)
        final_loss_meter.update(final_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter, power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'RegLoss {reg_loss_meter.val:.4f} '
                        'FinalLoss {final_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          reg_loss_meter=reg_loss_meter,
                                                          final_loss_meter=final_loss_meter,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch+1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
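
intersectionAndUnionGPU computes per-class intersection and union areas on the GPU so they can be all_reduce'd before ever leaving the device. A sketch consistent with the widely used semantic-segmentation implementation (per-class counts via torch.histc, ignored pixels masked out of both tensors):

import torch

def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection/union/target areas; output and target hold labels in [0, K-1]."""
    assert output.shape == target.shape
    output = output.reshape(-1)
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index  # drop ignored pixels from the prediction too
    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(), bins=K, min=0, max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
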
Example #5
def train(train_loader, model, optimizer, epoch):
    """
    No MGDA -- whole iteration takes 0.31 sec.
    0.24 sec to run typical backward pass (with no MGDA)

    With MGDA -- whole iteration takes 1.10 sec.
    1.05 sec to run backward pass w/ MGDA subroutine -- scale_loss_and_gradients() in every iteration.

    TODO: Profile which part of Frank-Wolfe is slow

    """

    from util.avg_meter import AverageMeter, SegmentationAverageMeter
    from util.util import poly_learning_rate
    import torch.distributed as dist
    from multiobjective_opt.dist_mgda_utils import scale_loss_and_gradients
    import torch, os, math, time

    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    sam = SegmentationAverageMeter()

    model.train()
    # set bn to be eval() and see the norm
    # def set_bn_eval(m):
    #     classname = m.__class__.__name__
    #     if classname.find('BatchNorm') != -1:
    #         m.eval()
    # model.apply(set_bn_eval)
    end = time.time()
    max_iter = args.max_iters

    for i, (input, target, batch_domain_idxs) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode='bilinear', align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        batch_domain_idxs = batch_domain_idxs.cuda(non_blocking=True)

        if args.use_mgda:
            output, loss, main_loss, aux_loss, scales = forward_backward_mgda(input, target, model, optimizer, args)
        else:
            output, loss, main_loss, aux_loss = forward_backward_full_sync(input, target, model, optimizer, args, batch_domain_idxs)

        optimizer.step()

        n = input.size(0)
        if args.multiprocessing_distributed:
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n  # not considering ignore pixels
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        sam.update_metrics_gpu(output, target, args.classes, args.ignore_label, args.multiprocessing_distributed)

        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        if i > 0:
            batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1 + args.resume_iter
        current_lr = poly_learning_rate(args.base_lr, current_iter, max_iter, power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            if args.finetune:
                optimizer.param_groups[index]['lr'] = current_lr 
            else:
                optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          loss_meter=loss_meter,
                                                          accuracy=sam.accuracy) + f'current_iter: {current_iter}' + f' rank: {args.rank} ')
            if args.use_mgda and main_process():
                # Scales identical in each process, so print out only in main process.
                scales_str = [f'{d}: {scale:.2f}' for d,scale in scales.items()]
                scales_str = ' , '.join(scales_str)
                logger.info(f'Scales: {scales_str}')

        if main_process() and current_iter == max_iter - 5: # early exit to prevent iter number not matching between gpus
            break

    iou_class, accuracy_class, mIoU, mAcc, allAcc = sam.get_metrics()
    logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch+1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
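
main_process() gates logging and TensorBoard writes to a single process. A minimal sketch, assuming the usual rank-0 check these distributed training scripts use:

def main_process():
    """True only for the process that should log and write summaries."""
    return not args.multiprocessing_distributed or (
        args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0)
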
Example #6
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        output, main_loss, aux_loss = model(input, target)
        if not args.multiprocessing_distributed:
            main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
        loss = main_loss + args.aux_weight * aux_loss

        optimizer.zero_grad()
        if args.use_apex and args.multiprocessing_distributed:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        n = input.size(0)
        if args.multiprocessing_distributed:
            # not considering ignore pixels
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            main_loss_meter=main_loss_meter,
                            aux_loss_meter=aux_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              current_iter)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              current_iter)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (target + 1e-10)),
                              current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
            .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
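
Example #6's apex.amp.scale_loss path only works if the model and optimizer were wrapped beforehand. A sketch of the setup it presumes, using NVIDIA apex's documented API (the opt_level choice here is illustrative, not something this example fixes):

import apex

# after building model and optimizer, before any DistributedDataParallel wrapping
model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')
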
Example #7
def train(train_loader, model, optimizer, epoch, epoch_log, val_loader,
          criterion):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target, _) in enumerate(train_loader):

        current_iter = epoch * len(train_loader) + i

        if args.just_vis or (args.evaluate and args.val_every_iter != -1
                             and current_iter % args.val_every_iter == 0):
            loss_val, mIoU_val, mAcc_val, allAcc_val, return_dict = validate(
                val_loader, model, criterion, args)
            if main_process():
                writer.add_scalar('VAL/loss_val', loss_val, current_iter)
                writer.add_scalar('VAL/mIoU_val', mIoU_val, current_iter)
                writer.add_scalar('VAL/mAcc_val', mAcc_val, current_iter)
                writer.add_scalar('VAL/allAcc_val', allAcc_val, current_iter)

                for sample_idx in range(len(return_dict['image_name_list'])):
                    writer.add_text('VAL-image_name/%d' % sample_idx,
                                    return_dict['image_name_list'][sample_idx],
                                    current_iter)
                    writer.add_image('VAL-image/%d' % sample_idx,
                                     return_dict['im_list'][sample_idx],
                                     current_iter,
                                     dataformats='HWC')
                    writer.add_image('VAL-color_label/%d' % sample_idx,
                                     return_dict['color_GT_list'][sample_idx],
                                     current_iter,
                                     dataformats='HWC')
                    writer.add_image(
                        'VAL-color_pred/%d' % sample_idx,
                        return_dict['color_pred_list'][sample_idx],
                        current_iter,
                        dataformats='HWC')

            model.train()
            end = time.time()

        if args.save_every_iter != -1 and current_iter % args.save_every_iter == 0 and main_process():
            model.eval()
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '_tid_' + str(current_iter) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            # if epoch_log / args.save_freq > 2:
            #     deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
            #     os.remove(deletename)
            model.train()
            end = time.time()

        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(),
                                   size=(h, w),
                                   mode='bilinear',
                                   align_corners=True).squeeze(1).long()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        # if args.test_in_nyu_label_space:
        #     target = map_openrooms_nyu_gpu(target)

        output, main_loss, aux_loss = model(input, target)
        if not args.multiprocessing_distributed:
            main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
        loss = main_loss + args.aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = input.size(0)
        if args.multiprocessing_distributed:
            # not considering ignore pixels
            main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n

        # if args.test_in_nyu_label_space:
        #     intersection, union, target = intersectionAndUnionGPU(map_openrooms_nyu_gpu(output), map_openrooms_nyu_gpu(target), 41, args.ignore_label)
        # else:
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            main_loss_meter=main_loss_meter,
                            aux_loss_meter=aux_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('TRAIN/loss_train_batch', main_loss_meter.val,
                              current_iter)
            writer.add_scalar('TRAIN/mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              current_iter)
            writer.add_scalar('TRAIN/mAcc_train_batch',
                              np.mean(intersection / (target + 1e-10)),
                              current_iter)
            writer.add_scalar('TRAIN/allAcc_train_batch', accuracy,
                              current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
            .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
    return main_loss_meter.avg, mIoU, mAcc, allAcc
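
Example #7 saves checkpoints as {'epoch', 'state_dict', 'optimizer'} dictionaries. For completeness, a matching resume sketch (the map_location choice is illustrative):

checkpoint = torch.load(filename, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
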
Example #8
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    derain_loss_meter = AverageMeter()
    seg_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    psnr_meter = AverageMeter()
    ssim_meter = AverageMeter()

    list_multiply = lambda x, y: x * y
    assert len(args.seg_loss_step_weight) == args.num_steps

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (clear_label, rain_input) in enumerate(train_loader):
        data_time.update(time.time() - end)

        clear_label = clear_label.cuda(non_blocking=True)
        rain_input = rain_input.cuda(non_blocking=True)
        derain_output, derain_losses = model(rain_input, clear_label)
        derain_losses = map(list_multiply, derain_losses,
                            args.derain_loss_step_weight)
        derain_sum_loss = sum(derain_losses)
        if not args.multiprocessing_distributed:
            derain_sum_loss = torch.mean(derain_sum_loss)
        loss = args.derain_loss_weight * derain_sum_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = rain_input.size(0)
        if args.multiprocessing_distributed:
            # not considering ignore pixels
            derain_sum_loss, loss = derain_sum_loss.detach() * n, loss * n
            count = clear_label.new_tensor([n], dtype=torch.long)
            dist.all_reduce(derain_sum_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            derain_sum_loss, loss = derain_sum_loss / n, loss / n

        # intersection, union, target = intersectionAndUnionCPU(seg_output, seg_label, args.classes, args.ignore_label)
        psnr, ssim = batchPSNRandSSIMGPU(derain_output, clear_label)
        # if args.multiprocessing_distributed:
        #     dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target)
        # intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        # intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)
        psnr_meter.update(psnr), ssim_meter.update(ssim)

        # accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        accuracy = 0
        psnr_val = psnr_meter.val
        ssim_val = ssim_meter.val
        derain_loss_meter.update(derain_sum_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split_1):
            optimizer.param_groups[index]['lr'] = current_lr * 0
        for index in range(args.index_split_1, args.index_split_2):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        for index in range(args.index_split_2, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 0
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'DerainLoss {derain_loss_meter:.4f} '
                        'SegLoss {seg_loss_meter:.4f} '
                        'Loss {loss_meter:.4f} '
                        'Accuracy {accuracy:.4f} '
                        'PSNR {psnr_val:.2f} '
                        'SSIM {ssim_val:.4f}.'.format(
                            epoch + 1,
                            args.epochs,
                            i + 1,
                            len(train_loader),
                            batch_time=batch_time,
                            data_time=data_time,
                            remain_time=remain_time,
                            derain_loss_meter=derain_loss_meter.val,
                            seg_loss_meter=seg_loss_meter.val,
                            loss_meter=loss_meter.val,
                            accuracy=accuracy,
                            psnr_val=psnr_val,
                            ssim_val=ssim_val))

        if main_process():
            writer.add_scalar('derain_loss_train_batch', derain_loss_meter.val,
                              current_iter)
            writer.add_scalar('seg_loss_train_batch', seg_loss_meter.val,
                              current_iter)
            writer.add_scalar('loss_train_batch', loss_meter.val, current_iter)
            # writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            # writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)
            writer.add_scalar('psnr_train_batch', psnr_val, current_iter)
            writer.add_scalar('ssim_train_batch', ssim_val, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    # mIoU = np.mean(iou_class)
    # mAcc = np.mean(accuracy_class)
    # allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    mIoU = 0
    mAcc = 0
    allAcc = 0
    if main_process():
        logger.info(
            'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
            .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
        logger.info(
            'Train result at epoch [{}/{}]: PSNR/SSIM {:.4f}/{:.4f}.'.format(
                epoch + 1, args.epochs, psnr_meter.avg, ssim_meter.avg))
    return loss_meter.avg, mIoU, mAcc, allAcc, psnr_meter.avg, ssim_meter.avg
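
batchPSNRandSSIMGPU is project-specific, but its PSNR half reduces to the standard formula PSNR = 10 * log10(MAX^2 / MSE). A sketch of a batched PSNR on the GPU, assuming images scaled to [0, 1] (SSIM is omitted since it needs a windowed implementation):

import torch

def batch_psnr(pred, target, max_val=1.0):
    """Mean PSNR over a batch; pred and target are (N, C, H, W) tensors in [0, max_val]."""
    mse = ((pred - target) ** 2).flatten(1).mean(dim=1)  # per-image mean squared error
    psnr = 10.0 * torch.log10(max_val ** 2 / (mse + 1e-10))
    return psnr.mean()
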
Example #9
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)
    for i, (input, target) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        data_time.update(time.time() - end)
        if args.zoom_factor != 8:
            h = int((target.size()[1] - 1) / 8 * args.zoom_factor + 1)
            w = int((target.size()[2] - 1) / 8 * args.zoom_factor + 1)
            # 'nearest' mode doesn't support align_corners mode and 'bilinear' mode is fine for downsampling
            target = F.interpolate(target.unsqueeze(1).float(),
                                   size=(h, w),
                                   mode='bilinear',
                                   align_corners=True).squeeze(1).long()
        input = input.cuda()
        target = target.cuda()

        output, main_loss, aux_loss = model(input, target)
        main_loss = main_loss.mean()
        aux_loss = aux_loss.mean()
        loss = main_loss + args.aux_weight * aux_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = input.size(0)

        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()  # reset so data_time/batch_time measure per-iteration intervals

        # learning rate scheduling
        current_iter = epoch * len(train_loader) + i + 1
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)

        if (i + 1) % args.print_freq == 0:
            writer.add_scalar('loss/train_batch', main_loss_meter.val,
                              current_iter)
            writer.add_scalar('mIoU/train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              current_iter)
            writer.add_scalar('mAcc/train_batch',
                              np.mean(intersection / (target + 1e-10)),
                              current_iter)
            writer.add_scalar('allAcc/train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    logger.info(
        'Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'
        .format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
    logger.info(f'remaining time: {int(t_h)}h {int(t_m)}min {int(t_s)}sec')

    return main_loss_meter.avg, mIoU, mAcc, allAcc
Example #10
def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.epochs * len(train_loader)  # total iterations, used by the poly LR schedule
    vis_key = 0
    print('Warmup: {}'.format(args.warmup))

    for i, (input, target, nomimg, s_input, s_mask, subcls) in enumerate(train_loader):
        data_time.update(time.time() - end)
        current_iter = epoch * len(train_loader) + i + 1  # current iteration index

        # adjust the learning rate with the poly policy
        if args.base_lr > 1e-6:
            poly_learning_rate(optimizer, args.base_lr, current_iter, max_iter, power=args.power,
                               warmup=args.warmup, warmup_step=len(train_loader)//2)

        s_input = s_input.cuda(non_blocking=True)  # [b,1,3,473,473]
        s_mask = s_mask.cuda(non_blocking=True)    # [b,1,473,473]
        input = input.cuda(non_blocking=True)      # [b,3,473,473]
        target = target.cuda(non_blocking=True)    # [b,473,473]
        nomimg = nomimg.cuda(non_blocking=True)
        # predicted mask [b,473,473], loss [1,b]
        output, main_loss = model(s_x=s_input, s_y=s_mask, nom=nomimg, x=input, y=target)

        loss = main_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # compute intersection and union
        n = input.size(0)
        intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        # update accuracy and related metrics
        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
        # periodically log training info
        if (i + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'MainLoss {main_loss_meter.val:.4f} '
                        'AuxLoss {aux_loss_meter.val:.4f} '                        
                        'Loss {loss_meter.val:.4f} '
                        'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader),
                                                          batch_time=batch_time,
                                                          data_time=data_time,
                                                          remain_time=remain_time,
                                                          main_loss_meter=main_loss_meter,
                                                          aux_loss_meter=aux_loss_meter,
                                                          loss_meter=loss_meter,
                                                          accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val, current_iter)
            writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter)
            writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter)
            writer.add_scalar('allAcc_train_batch', accuracy, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    if main_process():
        logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch + 1, args.epochs, mIoU, mAcc, allAcc))
        for i in range(args.classes):
            logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))        
    return main_loss_meter.avg, mIoU, mAcc, allAcc
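
The epoch-level numbers at the end of every example follow the same recipe: per-class IoU and accuracy from the summed histograms, then unweighted means over classes. A small worked sketch with two classes and illustrative counts:

import numpy as np

intersection_sum = np.array([80.0, 30.0])  # summed per-class intersections over the epoch
union_sum = np.array([100.0, 60.0])        # summed per-class unions
target_sum = np.array([90.0, 50.0])        # summed per-class label areas

iou_class = intersection_sum / (union_sum + 1e-10)             # [0.8, 0.5]
accuracy_class = intersection_sum / (target_sum + 1e-10)       # [0.889, 0.6]
mIoU = np.mean(iou_class)                                      # 0.65
mAcc = np.mean(accuracy_class)                                 # ~0.744
allAcc = intersection_sum.sum() / (target_sum.sum() + 1e-10)   # 110 / 140 ~ 0.786
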