def main_worker(gpu, ngpus_per_node, argss):
    """Set up one training process and run the train/checkpoint/validate loop.

    Publishes ``argss`` as the module-level ``args``, optionally joins the
    distributed process group, builds a PSPNet or PSANet with an SGD
    optimizer, optionally restores weights or a checkpoint, builds the data
    loaders and then iterates from ``args.start_epoch`` to ``args.epochs``.

    Args:
        gpu: Device index passed to ``torch.cuda.set_device`` and to DDP
            ``device_ids``; also added to the rank when
            ``args.multiprocessing_distributed`` is set.
        ngpus_per_node: Divisor applied to the batch sizes and the worker
            count in distributed mode.
        argss: Run configuration. It is mutated in place (``rank``,
            ``batch_size``, ``batch_size_val``, ``workers``,
            ``index_split``, ``start_epoch``).

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss

    # Pick the BatchNorm class that is handed to the model constructor.
    if args.sync_bn:
        if args.multiprocessing_distributed:
            BatchNorm = apex.parallel.SyncBatchNorm
        else:
            from lib.sync_bn.modules import BatchNorm2d
            BatchNorm = BatchNorm2d
    else:
        BatchNorm = nn.BatchNorm2d

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, psa_type=args.psa_type,
                       compact=args.compact, shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, criterion=criterion,
                       BatchNorm=BatchNorm)
        modules_ori = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    else:
        # Fail here instead of with a NameError on `model` further down.
        raise ValueError(
            "unsupported arch '{}' (expected 'psp' or 'psa')".format(args.arch))

    # One param group per module: `modules_ori` at base_lr, `modules_new`
    # at ten times base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Index of the first `modules_new` param group (5 for both archs).
    # NOTE(review): presumably read by train() when rescheduling the LR.
    args.index_split = len(modules_ori)
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The logger and the summary writer only exist on the main process.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)

    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        # Round up so a small worker count is not truncated to zero.
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        if args.use_apex:
            model, optimizer = apex.amp.initialize(
                model.cuda(), optimizer, opt_level=args.opt_level,
                keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                loss_scale=args.loss_scale)
            model = apex.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    # Optional weights-only initialisation.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            # Same map_location as the resume path below, so tensors are
            # moved with `.cuda()` instead of going to the device recorded
            # in the file.
            checkpoint = torch.load(
                args.weight,
                map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # Optional full resume: model, optimizer and starting epoch.
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # Per-channel mean/std multiplied by 255 before being given to the
    # transforms.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand',
                       padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center',
                           padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=args.batch_size_val, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)

        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = (args.save_path + '/train_epoch_' + str(epoch_log)
                        + '.pth')
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()}, filename)
            # Drop the checkpoint written two save intervals ago.
            if epoch_log / args.save_freq > 2:
                deletename = (args.save_path + '/train_epoch_'
                              + str(epoch_log - args.save_freq * 2) + '.pth')
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    # Nothing to rotate out (e.g. the run was resumed into
                    # a different save directory); not worth aborting for.
                    pass

        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

    # Close the writer opened above so it is not left dangling at exit.
    if main_process():
        writer.close()
def main_worker(gpu, ngpus_per_node, argss):
    """Set up one training process and run the train/checkpoint/validate loop.

    Publishes ``argss`` as the module-level ``args``, optionally joins the
    distributed process group, builds a PSPNet or PSANet with an SGD
    optimizer, optionally converts BatchNorm layers with
    ``nn.SyncBatchNorm.convert_sync_batchnorm``, optionally restores weights
    or a checkpoint, builds the data loaders and then iterates from
    ``args.start_epoch`` to ``args.epochs``.

    Args:
        gpu: Device index passed to ``torch.cuda.set_device`` and to DDP
            ``device_ids``; also added to the rank when
            ``args.multiprocessing_distributed`` is set.
        ngpus_per_node: Divisor applied to the batch sizes and the worker
            count in distributed mode.
        argss: Run configuration. It is mutated in place (``rank``,
            ``batch_size``, ``batch_size_val``, ``workers``,
            ``index_split``, ``start_epoch``).

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'psa'``.
    """
    global args, logger, writer
    args = argss

    # ---- Step 1: distributed initialisation ----
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # ---- Step 2: build the network (adapt this section to your model) ----
    # Cross-entropy loss; replace as needed for your task.
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, criterion=criterion)
        modules_ori = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor, psa_type=args.psa_type,
                       compact=args.compact, shrink_factor=args.shrink_factor,
                       mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor,
                       psa_softmax=args.psa_softmax, criterion=criterion)
        modules_ori = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    else:
        # Fail here instead of with a NameError on `model` further down.
        raise ValueError(
            "unsupported arch '{}' (expected 'psp' or 'psa')".format(args.arch))

    # ---- Step 3: optimizer ----
    # One param group per module: the original backbone modules at base_lr,
    # the newly added prediction modules at ten times base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Index of the first `modules_new` param group (5 for both archs).
    # NOTE(review): presumably read by train() when rescheduling the LR.
    args.index_split = len(modules_ori)
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Switch to torch.nn.SyncBatchNorm when requested.
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # ---- Step 4: logging and parallel wrappers ----
    # The logger and the summary writer only exist on the main process.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)

    if args.distributed:
        torch.cuda.set_device(gpu)  # bind this process to device `gpu`
        # Per-process training / validation batch sizes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        # Per-process data-loader workers, rounded up.
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())

    # ---- Step 5: load weights ----
    # 5.1 Weights-only initialisation.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            # Same map_location as the resume path below, so tensors are
            # moved with `.cuda()` instead of going to the device recorded
            # in the file.
            checkpoint = torch.load(
                args.weight,
                map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # 5.2 Resume an unfinished run: model, optimizer and starting epoch.
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # ---- Step 6: data loaders ----
    # Per-channel mean/std multiplied by 255 before being given to the
    # transforms.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    # Training-time preprocessing / augmentation pipeline.
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand',
                       padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])

    # Training dataset (adapt this section to your data).
    train_data = dataset.SemData(split='train', data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)

    if args.evaluate:
        # Validation data.
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center',
                           padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=args.batch_size_val, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # ---- Step 7: main loop ----
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # 7.1 Train for one epoch (adapt `train` to your task).
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        # 7.2 Save a checkpoint.
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = (args.save_path + '/train_epoch_' + str(epoch_log)
                        + '.pth')
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()}, filename)
            # Drop the checkpoint written two save intervals ago.
            if epoch_log / args.save_freq > 2:
                deletename = (args.save_path + '/train_epoch_'
                              + str(epoch_log - args.save_freq * 2) + '.pth')
                try:
                    os.remove(deletename)
                except FileNotFoundError:
                    # Nothing to rotate out (e.g. the run was resumed into
                    # a different save directory); not worth aborting for.
                    pass

        # 7.3 Validate after each training epoch.
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

    # Close the writer opened above so it is not left dangling at exit.
    if main_process():
        writer.close()
def main_worker(local_rank, ngpus_per_node, argss):
    """Set up one training process and run the (optionally distilled) loop.

    Publishes ``argss`` as the module-level ``args``, joins the process
    group, builds the student model (PSPNet or BiseNet) and, when
    ``args.teacher_model_path`` is set, a PSPNet teacher. It then builds the
    SGD optimizer, optionally restores weights or a checkpoint, builds the
    data loaders and iterates from ``args.start_epoch`` to ``args.epochs``,
    validating and checkpointing along the way.

    Args:
        local_rank: Device index passed to ``torch.cuda.set_device``, DDP
            ``device_ids`` and ``.cuda(...)``; also forwarded to ``train``
            and ``validate``.
        ngpus_per_node: Divisor applied to the batch sizes and the worker
            count in distributed mode.
        argss: Run configuration. It is mutated in place (``save_path``,
            ``batch_size``, ``batch_size_val``, ``workers``,
            ``index_split``, ``start_epoch``).

    Raises:
        ValueError: If ``args.arch`` is neither ``'psp'`` nor ``'bise_v1'``.
    """
    global args, logger, writer
    args = argss
    dist.init_process_group(backend=args.dist_backend)

    # Optional teacher network; distillation runs are saved in a
    # sub-directory named after alpha and temperature.
    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        # Every process runs this: exist_ok avoids the check-then-mkdir
        # race, and makedirs also creates a missing parent directory.
        os.makedirs(args.save_path, exist_ok=True)

    if args.arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes,
                       zoom_factor=args.zoom_factor)
        modules_ori = [model.layer0, model.layer1, model.layer2,
                       model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [model.ffm, model.conv_out, model.conv_out16,
                       model.conv_out32]
    else:
        # Fail here instead of with a NameError on `model` further down.
        raise ValueError(
            "unsupported arch '{}' (expected 'psp' or 'bise_v1')".format(
                args.arch))

    # One param group per module: `modules_ori` at base_lr, `modules_new`
    # at ten times base_lr.
    params_list = [dict(params=module.parameters(), lr=args.base_lr)
                   for module in modules_ori]
    params_list += [dict(params=module.parameters(), lr=args.base_lr * 10)
                    for module in modules_new]
    # Index of the first `modules_new` param group. A fixed 5 only matches
    # 'psp'; 'bise_v1' has two `modules_ori` groups.
    # NOTE(review): presumably read by train() when rescheduling the LR -
    # confirm against train().
    args.index_split = len(modules_ori)
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                teacher_model)

    # The logger and the summary writer only exist on the main process.
    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)

    if args.distributed:
        torch.cuda.set_device(local_rank)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[local_rank])
    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(
            args.teacher_model_path,
            map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(
            args.teacher_model_path))

    criterion = nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)

    # Optional weights-only initialisation of the student.
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            # Same map_location as the other loads in this function, so
            # tensors are moved with `.cuda()` instead of going to the
            # device recorded in the file.
            checkpoint = torch.load(
                args.weight,
                map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))

    # Optional full resume: model, optimizer, epoch and best score so far.
    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU.
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Fall back to the initial value for checkpoints that do not
            # carry this key.
            best_mIoU_val = checkpoint.get('best_mIoU_val', best_mIoU_val)
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    # Per-channel mean/std multiplied by 255 before being given to the
    # transforms.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand',
                       padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        shuffle=(train_sampler is None), num_workers=args.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)

    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center',
                           padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=args.batch_size_val, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Give the sampler the epoch number so the partition is
            # reshuffled every epoch.
            train_sampler.set_epoch(epoch)

        loss_train, mIoU_train, mAcc_train, allAcc_train = train(
            local_rank, train_loader, model, teacher_model, criterion,
            kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        is_best = False
        # Stays None when validation is disabled, so the save log below
        # never touches an unbound name.
        mIoU_val = None
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                    logger.info(
                        '==>The best val mIoU: %.3f' % (best_mIoU_val))

        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val,
                },
                is_best,
                args.save_path,
            )
            if is_best:
                logger.info('Saving checkpoint to:' + args.save_path
                            + '/best.pth with mIoU: ' + str(best_mIoU_val))
            else:
                logger.info('Saving checkpoint to:' + args.save_path
                            + '/last.pth with mIoU: ' + str(mIoU_val))

    if main_process():
        # The writer must be closed explicitly; otherwise an EOFError
        # shows up at exit.
        writer.close()
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))