def main( gpu,cfg,args):
    """Distributed (one-process-per-GPU) training entry point.

    Initializes NCCL process-group communication, builds the encoder/decoder
    segmentation model, wraps it in SyncBatchNorm + DistributedDataParallel,
    optionally resumes model/optimizer state from './resume', then runs the
    epoch loop with rank-0 checkpointing and optional validation.

    Args:
        gpu: local process index in [0, args.gpu_num); also used as the rank.
        cfg: configuration node with MODEL and TRAIN sections.
        args: CLI namespace (port, batchsize, workers, resume_epoch, ...).
    """
    # Network Builders
    load_gpu = gpu+args.start_gpu  # physical CUDA device id for this process
    rank = gpu
    torch.cuda.set_device(load_gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:{}'.format(args.port),
        world_size=args.gpu_num,
        rank=rank,
        timeout=datetime.timedelta(seconds=300))
    # self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model).cuda(self.gpu)
    if args.use_float16:
        # Mixed-precision path: autocast/GradScaler are imported only when
        # fp16 is requested and are forwarded to train() below.
        from torch.cuda.amp import autocast as autocast, GradScaler
        scaler = GradScaler()
    else:
        scaler = None
        autocast = None  # NOTE(review): train() must tolerate autocast=None — confirm
    label_num_=args.num_class
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=label_num_,
        weights=cfg.MODEL.weights_decoder)
    # 255 is treated as the ignore label during training.
    crit = nn.NLLLoss(ignore_index=255)
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        # Deep-supervision decoders take an extra loss-scale argument.
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit)
    if args.use_clipdataset:
        dataset_train = BaseDataset_longclip(args,'train')
    else:
        dataset_train = BaseDataset( args, 'train' )
    # shuffle=False because DistributedSampler already shuffles/partitions.
    sampler_train =torch.utils.data.distributed.DistributedSampler(dataset_train)
    loader_train = torch.utils.data.DataLoader(dataset_train,
        batch_size=args.batchsize,
        shuffle=False,sampler=sampler_train,
        pin_memory=True,
        num_workers=args.workers)
    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))
    dataset_val = BaseDataset( args, 'val' )
    sampler_val =torch.utils.data.distributed.DistributedSampler(dataset_val)
    loader_val = torch.utils.data.DataLoader(dataset_val,
        batch_size=args.batchsize,
        shuffle=False,sampler=sampler_val,
        pin_memory=True,
        num_workers=args.workers)
    # loader_val = torch.utils.data.DataLoader(dataset_val,batch_size=args.batchsize,shuffle=False,num_workers=args.workers)
    # create loader iterator
    # load nets into gpu
    segmentation_module = segmentation_module.cuda(load_gpu)
    # Convert BatchNorm layers to SyncBatchNorm before DDP wrapping.
    segmentation_module= nn.SyncBatchNorm.convert_sync_batchnorm(segmentation_module)
    if args.resume_epoch!=0:
        # if dist.get_rank() == 0:
        to_load = torch.load(
            os.path.join('./resume','model_epoch_{}.pth'.format(args.resume_epoch)),
            map_location=torch.device("cuda:"+str(load_gpu)))
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in to_load.items():
            # Checkpoint was saved from a DDP wrapper, so each key starts with
            # 'module.'; slicing off the first 7 chars removes that prefix.
            name = k[7:]
            new_state_dict[name] = v
        cfg.TRAIN.start_epoch=args.resume_epoch
        segmentation_module.load_state_dict(new_state_dict)
    segmentation_module= torch.nn.parallel.DistributedDataParallel(
        segmentation_module,
        device_ids=[load_gpu],
        find_unused_parameters=True)
    # Set up optimizers
    # nets = (net_encoder, net_decoder, crit)
    nets = segmentation_module
    optimizers = create_optimizers(segmentation_module, cfg)
    if args.resume_epoch!=0:
        # if dist.get_rank() == 0:
        optimizers.load_state_dict(torch.load(
            os.path.join('./resume','opt_epoch_{}.pth'.format(args.resume_epoch)),
            map_location=torch.device("cuda:"+str(load_gpu))))
        print('resume from epoch {}'.format(args.resume_epoch))
    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
    # test(segmentation_module,loader_val,args)
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        # Rank 0 snapshots the initial state before any training step.
        if dist.get_rank() == 0 and epoch==0:
            checkpoint(nets,optimizers, history, args, epoch+1)
        print('Epoch {}'.format(epoch))
        train(segmentation_module, loader_train, optimizers, history, epoch+1,
              cfg,args,load_gpu,scaler=scaler,autocast=autocast)
        ###################
        # checkpointing: rank 0 saves (and optionally validates) every 10 epochs.
        if dist.get_rank() == 0 and (epoch+1)%10==0:
            checkpoint(segmentation_module,optimizers, history, args, epoch+1)
            # NOTE(review): source formatting was collapsed; validation is
            # assumed to share the rank-0 / every-10-epochs gate — confirm.
            if args.validation:
                test(segmentation_module,loader_val,args)
    print('Training Done!')
def main(cfg, gpu, args):
    """Single-GPU inference entry point.

    Loads a trained segmentation checkpoint, runs it over every video listed
    in `<args.dataroot>/<args.split>.txt`, and (for splits with ground truth,
    i.e. anything but 'test') prints per-video and aggregate IoU metrics.

    NOTE(review): this `def main` redefines the earlier training `main` if
    both live in the same module — confirm they belong to separate scripts.

    Args:
        cfg: configuration node with a MODEL section.
        gpu: CUDA device id used for inference tensors.
        args: CLI namespace (num_class, load, start_gpu, dataroot, split,
            batchsize, ...).
    """
    num_class = args.num_class
    torch.cuda.set_device(gpu)

    # Network builders (inference mode: decoder applies softmax).
    net_encoder = ModelBuilder.build_encoder(arch=cfg.MODEL.arch_encoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(arch=cfg.MODEL.arch_decoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             num_class=num_class,
                                             weights=cfg.MODEL.weights_decoder,
                                             use_softmax=True)
    crit = nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # The checkpoint was saved from a DistributedDataParallel wrapper, so each
    # key carries a 'module.' prefix; strip it before loading.
    to_load = torch.load(args.load,
                         map_location=torch.device("cuda:" + str(args.start_gpu)))
    new_state_dict = OrderedDict()
    for k, v in to_load.items():
        new_state_dict[k[7:]] = v  # drop leading 'module.'
    segmentation_module.load_state_dict(new_state_dict)
    print('load model parameters')
    segmentation_module.cuda(args.start_gpu)

    # One video name per line; line[:-1] strips the trailing newline.
    with open(os.path.join(args.dataroot, args.split + '.txt')) as f:
        videolists = [line[:-1] for line in f.readlines()]

    # Dataset and Loader
    evaluator = Evaluator(num_class)   # accumulates over the whole split
    eval_video = Evaluator(num_class)  # reset for each video
    evaluator.reset()
    eval_video.reset()
    total_vmIOU = 0.0
    total_vfwIOU = 0.0
    total_video = len(videolists)
    per_video_miou = []  # parallel lists: per-video mIoU and matching name
    video_names = []
    for video in videolists:
        eval_video.reset()
        dataset_test = TestDataset(args.dataroot, video, args)
        loader_test = torch.utils.data.DataLoader(dataset_test,
                                                  batch_size=args.batchsize,
                                                  shuffle=False,
                                                  num_workers=5,
                                                  drop_last=False)
        # Main loop
        test(segmentation_module, loader_test, gpu, args, evaluator, eval_video, video)
        if args.split != 'test':  # the 'test' split has no ground truth
            v_mIOU = eval_video.Mean_Intersection_over_Union()
            # BUGFIX: original did `v.append(v)`, appending the list to itself
            # instead of recording the per-video score (and `v` also shadowed
            # the state-dict loop variable above).
            per_video_miou.append(v_mIOU)
            video_names.append(video)
            print(video, v_mIOU)
            total_vmIOU += v_mIOU
            v_fwIOU = eval_video.Frequency_Weighted_Intersection_over_Union()
            total_vfwIOU += v_fwIOU

    # Guard against an empty video list to avoid ZeroDivisionError.
    if args.split != 'test' and total_video > 0:
        total_vmIOU = total_vmIOU / total_video
        total_vfwIOU = total_vfwIOU / total_video
        Acc = evaluator.Pixel_Accuracy()
        Acc_class = evaluator.Pixel_Accuracy_Class()
        mIoU = evaluator.Mean_Intersection_over_Union()
        FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, video mIOU: {}, video fwIOU: {}"
              .format(Acc, Acc_class, mIoU, FWIoU, total_vmIOU, total_vfwIOU))
    print('Inference done!')