Example #1
def main(gpu, cfg, args):
    # Network Builders

    load_gpu = gpu + args.start_gpu
    rank = gpu
    torch.cuda.set_device(load_gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:{}'.format(args.port),
        world_size=args.gpu_num,
        rank=rank,
        timeout=datetime.timedelta(seconds=300))


    if args.use_float16:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
    else:
        scaler = None
        autocast = None
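
    # Note: the scaler/autocast pair is only created here; the actual
    # mixed-precision work (running the forward pass under autocast() and
    # calling scaler.scale(loss).backward(), scaler.step(optimizer) and
    # scaler.update()) presumably happens inside train(), which receives
    # both objects below.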

    label_num_ = args.num_class
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=label_num_,
        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=255)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit)

    if args.use_clipdataset:
        dataset_train = BaseDataset_longclip(args, 'train')
    else:
        dataset_train = BaseDataset(args, 'train')

    sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batchsize,
        shuffle=False,
        sampler=sampler_train,
        pin_memory=True,
        num_workers=args.workers)


    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    dataset_val = BaseDataset(args, 'val')
    sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batchsize,
        shuffle=False,
        sampler=sampler_val,
        pin_memory=True,
        num_workers=args.workers)
    

    # load nets into gpu

    segmentation_module = segmentation_module.cuda(load_gpu)

    segmentation_module = nn.SyncBatchNorm.convert_sync_batchnorm(segmentation_module)

    if args.resume_epoch != 0:
        to_load = torch.load(
            os.path.join('./resume', 'model_epoch_{}.pth'.format(args.resume_epoch)),
            map_location=torch.device("cuda:" + str(load_gpu)))
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in to_load.items():
            name = k[7:]  # strip the leading 'module.' added by DistributedDataParallel
            new_state_dict[name] = v
        cfg.TRAIN.start_epoch = args.resume_epoch
        segmentation_module.load_state_dict(new_state_dict)


    segmentation_module = torch.nn.parallel.DistributedDataParallel(
        segmentation_module,
        device_ids=[load_gpu],
        find_unused_parameters=True)
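    # find_unused_parameters=True lets DDP tolerate submodules whose outputs do
    # not contribute to the loss in every forward pass (e.g. a deep-supervision
    # branch), at the cost of an extra traversal to mark unused parameters.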

    # Set up optimizers
#    nets = (net_encoder, net_decoder, crit)
    nets = segmentation_module
    optimizers = create_optimizers(segmentation_module, cfg)
    if args.resume_epoch != 0:
        optimizers.load_state_dict(torch.load(
            os.path.join('./resume', 'opt_epoch_{}.pth'.format(args.resume_epoch)),
            map_location=torch.device("cuda:" + str(load_gpu))))
        print('resume from epoch {}'.format(args.resume_epoch))

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

#    test(segmentation_module,loader_val,args)
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        if dist.get_rank() == 0 and epoch == 0:
            checkpoint(nets, optimizers, history, args, epoch + 1)
        print('Epoch {}'.format(epoch))
        train(segmentation_module, loader_train, optimizers, history, epoch + 1,
              cfg, args, load_gpu, scaler=scaler, autocast=autocast)

        # checkpointing
        if dist.get_rank() == 0 and (epoch + 1) % 10 == 0:
            checkpoint(segmentation_module, optimizers, history, args, epoch + 1)
        if args.validation:
            test(segmentation_module, loader_val, args)

    print('Training Done!')
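
Example #1 is written for one process per GPU (NCCL backend, rank taken from the spawn index), so it is normally driven by torch.multiprocessing.spawn rather than called directly. The launcher below is only a sketch: the cfg object and the load_config helper stand in for this project's real config loading, and only the arguments main() actually reads are listed.

# Hypothetical launcher for Example #1; cfg/load_config are assumptions.
import argparse
import torch.multiprocessing as mp

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_num', type=int, default=2)    # world size
    parser.add_argument('--start_gpu', type=int, default=0)  # first physical GPU id
    parser.add_argument('--port', type=int, default=23456)   # TCP rendezvous port
    # ... batchsize, workers, num_class, resume_epoch, use_float16, validation, etc.
    args = parser.parse_args()

    cfg = load_config(args)  # assumed helper returning the cfg that main() expects

    # mp.spawn calls main(i, cfg, args) with i = 0 .. gpu_num-1, matching the
    # (gpu, cfg, args) signature above; each process then binds GPU i + start_gpu.
    mp.spawn(main, args=(cfg, args), nprocs=args.gpu_num)
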
Example #2
def main(cfg, gpu, args):

    num_class = args.num_class
    torch.cuda.set_device(gpu)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(arch=cfg.MODEL.arch_encoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(arch=cfg.MODEL.arch_decoder,
                                             fc_dim=cfg.MODEL.fc_dim,
                                             num_class=num_class,
                                             weights=cfg.MODEL.weights_decoder,
                                             use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    to_load = torch.load(args.load,
                         map_location=torch.device("cuda:" +
                                                   str(args.start_gpu)))
    new_state_dict = OrderedDict()
    for k, v in to_load.items():
        name = k[7:]  # strip the leading 'module.' added by DistributedDataParallel
        new_state_dict[name] = v

    segmentation_module.load_state_dict(new_state_dict)
    print('loaded model parameters')

    segmentation_module.cuda(args.start_gpu)
    with open(os.path.join(args.dataroot, args.split + '.txt')) as f:
        lines = f.readlines()
        videolists = [line[:-1] for line in lines]

    # Dataset and Loader
    evaluator = Evaluator(num_class)
    eval_video = Evaluator(num_class)
    evaluator.reset()
    eval_video.reset()
    total_vmIOU = 0.0
    total_vfwIOU = 0.0
    total_video = len(videolists)
    v = []
    n = []
    for video in videolists:
        eval_video.reset()
        dataset_test = TestDataset(args.dataroot, video, args)
        loader_test = torch.utils.data.DataLoader(dataset_test,
                                                  batch_size=args.batchsize,
                                                  shuffle=False,
                                                  num_workers=5,
                                                  drop_last=False)
        # Main loop
        test(segmentation_module, loader_test, gpu, args, evaluator,
             eval_video, video)
        if args.split != 'test':
            v_mIOU = eval_video.Mean_Intersection_over_Union()
            v.append(v_mIOU)
            n.append(video)
            print(video, v_mIOU)
            total_vmIOU += v_mIOU
            v_fwIOU = eval_video.Frequency_Weighted_Intersection_over_Union()
            total_vfwIOU += v_fwIOU
    if args.split != 'test':
        total_vmIOU = total_vmIOU / total_video
        total_vfwIOU = total_vfwIOU / total_video

        Acc = evaluator.Pixel_Accuracy()
        Acc_class = evaluator.Pixel_Accuracy_Class()
        mIoU = evaluator.Mean_Intersection_over_Union()
        FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
        print(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}, video mIOU: {}, video fwIOU: {}"
            .format(Acc, Acc_class, mIoU, FWIoU, total_vmIOU, total_vfwIOU))

    print('Inference done!')
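
Both examples rely on an Evaluator for pixel accuracy and (frequency-weighted) mean IoU. The class below is a minimal confusion-matrix sketch of those metrics, assuming ground truth and predictions arrive as integer label maps; it illustrates the definitions the loop above queries and is not necessarily identical to this project's Evaluator.

# Minimal confusion-matrix Evaluator sketch (illustrative, not the project's own).
import numpy as np

class Evaluator:
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class))

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))

    def add_batch(self, gt, pred):
        # Accumulate a num_class x num_class confusion matrix over valid labels.
        mask = (gt >= 0) & (gt < self.num_class)
        label = self.num_class * gt[mask].astype(int) + pred[mask].astype(int)
        count = np.bincount(label, minlength=self.num_class ** 2)
        self.confusion_matrix += count.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        cm = self.confusion_matrix
        return np.diag(cm).sum() / cm.sum()

    def Pixel_Accuracy_Class(self):
        cm = self.confusion_matrix
        return np.nanmean(np.diag(cm) / cm.sum(axis=1))

    def Mean_Intersection_over_Union(self):
        cm = self.confusion_matrix
        iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
        return np.nanmean(iou)  # classes absent from both gt and pred yield NaN and are skipped

    def Frequency_Weighted_Intersection_over_Union(self):
        cm = self.confusion_matrix
        freq = cm.sum(axis=1) / cm.sum()
        iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
        return (freq[freq > 0] * iou[freq > 0]).sum()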