Example #1
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.sync_bn:
        if args.multiprocessing_distributed:
            BatchNorm = apex.parallel.SyncBatchNorm
        else:
            from lib.sync_bn.modules import BatchNorm2d
            BatchNorm = BatchNorm2d
    else:
        BatchNorm = nn.BatchNorm2d
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion, BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, psa_type=args.psa_type,
                       compact=args.compact, shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax,
                       criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)
        if args.use_apex:
            model, optimizer = apex.amp.initialize(model.cuda(), optimizer, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu])

    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
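How the worker above is launched is not shown; below is a minimal launch sketch, assuming the common pattern of spawning one process per GPU with torch.multiprocessing. The main() wrapper and the args fields it reads (world_size, multiprocessing_distributed, train_gpu) are illustrative assumptions, not part of the example.

import torch
import torch.multiprocessing as mp

def main(args):
    # illustrative entry point; adapt the argument handling to your own setup
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one worker process per GPU; world_size becomes the total process count
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # single process; DataParallel handles multiple GPUs inside main_worker
        main_worker(args.train_gpu, ngpus_per_node, args)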
Example #2
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss

    ## step.1 Set up distributed-training parameters
    # 1.1 Distributed initialization
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)  # initialize the distributed process group

    ## step.2 Build the network
    # ---------------------------------------------- write this part to fit your setup ----------------------------------------------#
    criterion = nn.CrossEntropyLoss(
        ignore_index=args.ignore_label)  # cross-entropy loss; adjust as needed
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers,
                       classes=args.classes,
                       zoom_factor=args.zoom_factor,
                       psa_type=args.psa_type,
                       compact=args.compact,
                       shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h,
                       mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax,
                       criterion=criterion)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        modules_new = [model.psa, model.cls, model.aux]
    # ---------------------------------------------------- END ---------------------------------------------------#

    ## step.3 Set up the optimizer
    params_list = []  # list of model parameter groups
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(),
                                lr=args.base_lr))  # original backbone: learning rate 0.01
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(),
                 lr=args.base_lr * 10))  # newly added head: learning rate 0.1
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)  # SGD optimizer
    # 3.x Convert BatchNorm to torch.nn.SyncBatchNorm
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    ## step.4 Multi-process distributed setup
    # 4.1 If we are in the main process, set up logging and TensorBoard
    if main_process():
        global logger, writer
        logger = get_logger()  # set up the logger
        writer = SummaryWriter(args.save_path)  # set up the TensorBoard writer
        logger.info(args)  # log the argument list
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)  # log the model structure
    # 4.2 Distributed wrapping
    if args.distributed:
        torch.cuda.set_device(gpu)  # select the GPU with index gpu
        args.batch_size = int(args.batch_size /
                              ngpus_per_node)  # training batch size per GPU
        args.batch_size_val = int(args.batch_size_val /
                                  ngpus_per_node)  # validation batch size per GPU
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)  # data-loading workers per GPU
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu])  # wrap with torch DistributedDataParallel
    else:
        model = torch.nn.DataParallel(model.cuda())  # plain data parallelism

    ## step.5 Load network weights
    # 5.1 Load pretrained weights directly
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))
    # 5.2 Resume weights from an interrupted training run
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    ## step.7 Set up the data loaders
    # 7.1 Loader parameter setup
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])  # composed data-preprocessing pipeline

    # 7.2 Training data; modify or rewrite as needed
    # ---------------------------------------------- write this part to fit your setup ----------------------------------------------#
    train_data = dataset.SemData(split='train',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    # ---------------------------------------------------- END ---------------------------------------------------#
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)  # distributed data sampler
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    if args.evaluate:  # evaluation data
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=mean,
                           ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)
        ])
        val_data = dataset.SemData(split='val',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=args.batch_size_val,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=val_sampler)

    ## step.8 Main loop
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # 8.1 Training function
        # ---------------------------------------------- write this part to fit your setup ----------------------------------------------#
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        # ---------------------------------------------------- END ---------------------------------------------------#

        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        # 8.2 Save checkpoint
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(
                epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(
                    epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
        # evaluate after each training epoch
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
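The two parameter groups (backbone at base_lr, new modules at base_lr * 10) together with args.index_split = 5 are typically consumed by a poly learning-rate schedule inside train(). A minimal sketch of that update follows, under that assumption; the helper name poly_learning_rate and the curr_iter/max_iter bookkeeping are illustrative.

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    # standard "poly" decay: lr = base_lr * (1 - iter / max_iter) ** power
    return base_lr * (1 - float(curr_iter) / max_iter) ** power

# inside the training loop (illustrative):
# current_lr = poly_learning_rate(args.base_lr, curr_iter, max_iter)
# for index in range(0, args.index_split):                          # backbone groups
#     optimizer.param_groups[index]['lr'] = current_lr
# for index in range(args.index_split, len(optimizer.param_groups)):
#     optimizer.param_groups[index]['lr'] = current_lr * 10         # new-module groups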
Example #3
def main_worker(local_rank, ngpus_per_node, argss):
    global args
    args = argss

    dist.init_process_group(backend=args.dist_backend)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers, classes=args.classes, zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        if not os.path.exists(args.save_path):
            os.mkdir(args.save_path)
    if args.arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [model.ffm, model.conv_out, model.conv_out16, model.conv_out32]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path) # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)
    if args.distributed:
        torch.cuda.set_device(local_rank)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model.cuda(), device_ids=[local_rank])

    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())
    
    if teacher_model is not None:
        checkpoint = torch.load(args.teacher_model_path, map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(args.teacher_model_path))
    
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)
            
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> mp weight found at '{}'".format(args.weight))
    
    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['point']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))    
        
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
        
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])

    train_data = dataset.SemData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Use .set_epoch() method to reshuffle the dataset partition at every iteration
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(local_rank, train_loader, model, teacher_model, criterion, kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        
        is_best = False
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                    logger.info('==>The best val mIoU: %.3f' % (best_mIoU_val))

        
        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log, 
                    'state_dict': model.state_dict(), 
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val
                }, 
                is_best, 
                args.save_path
            )
            if is_best:
                logger.info('Saving checkpoint to: ' + args.save_path + '/best.pth with mIoU: ' + str(best_mIoU_val))
            else:
                logger.info('Saving checkpoint to: ' + args.save_path + '/last.pth with mIoU: ' + str(mIoU_val))

    if main_process():  
        writer.close()  # the writer must be closed, otherwise an EOFError may occur
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))
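Example #3 calls save_checkpoint(state, is_best, args.save_path), which is not shown. Below is a minimal sketch of a compatible helper, assuming the last.pth/best.pth filenames that the log messages mention; the exact implementation is an assumption.

import os
import shutil
import torch

def save_checkpoint(state, is_best, save_path, filename='last.pth'):
    # always overwrite the most recent checkpoint
    last_path = os.path.join(save_path, filename)
    torch.save(state, last_path)
    if is_best:
        # keep a separate copy of the best-scoring checkpoint
        shutil.copyfile(last_path, os.path.join(save_path, 'best.pth'))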