Ejemplo n.º 1
0
def mean_iou(input: torch.Tensor,
             target: torch.Tensor,
             num_classes: int,
             eps: Optional[float] = 1e-6) -> torch.Tensor:
    r"""Calculate the per-class Intersection-over-Union (IoU) of each sample.

    The function internally computes the confusion matrix and derives, for
    every class, ``TP / (TP + FP + FN)``. Averaging the result over the class
    dimension gives the mean IoU (mIoU).

    Args:
        input (torch.Tensor): tensor with estimated targets returned by a
          classifier. The shape can be :math:`(B, *)` and must contain
          ``torch.int64`` values between 0 and K-1.
        target (torch.Tensor): tensor with ground truth (correct) target
          values. It must have the same shape, dtype and device as ``input``
          and contain integer class labels between 0 and K-1.
        num_classes (int): total possible number of classes in target (K).
          Must be at least 2.
        eps (float): value added to both the numerator and the denominator of
          the IoU ratio, so that a class absent from both the prediction and
          the ground truth does not produce a 0/0 division.

    Returns:
        torch.Tensor: a tensor with the intersection-over-union of every
        class, with shape :math:`(B, K)` where K is the number of classes.

    Raises:
        TypeError: if ``input`` or ``target`` is not a ``torch.Tensor`` or
          does not have ``torch.int64`` dtype.
        ValueError: if the shapes or devices of ``input`` and ``target``
          differ, or if ``num_classes`` is not an integer >= 2.
    """
    # The tensor check and the dtype check are done separately: combining
    # them with ``and`` would touch ``.dtype`` on non-tensors and would never
    # reject a tensor with the wrong dtype.
    for arg_name, arg in (("input", input), ("target", target)):
        if not torch.is_tensor(arg):
            raise TypeError("Input {} type is not a torch.Tensor with "
                            "torch.int64 dtype. Got {}".format(
                                arg_name, type(arg)))
        if arg.dtype != torch.int64:
            raise TypeError("Input {} type is not a torch.Tensor with "
                            "torch.int64 dtype. Got {}".format(
                                arg_name, arg.dtype))
    if not input.shape == target.shape:
        raise ValueError("Inputs input and target must have the same shape. "
                         "Got: {} - {}".format(input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError("Inputs must be in the same device. "
                         "Got: {} - {}".format(input.device, target.device))
    if not isinstance(num_classes, int) or num_classes < 2:
        raise ValueError("The number of classes must be an integer of at "
                         "least two. Got: {}".format(num_classes))
    # we first compute the confusion matrix
    # (indexed below as a batch of square matrices: dims 1 and 2 are reduced)
    conf_mat: torch.Tensor = confusion_matrix(input, target, num_classes)

    # compute the actual intersection over union:
    # the diagonal holds the intersection, the two marginals minus the
    # diagonal give the union.
    sum_over_row = torch.sum(conf_mat, dim=1)
    sum_over_col = torch.sum(conf_mat, dim=2)
    conf_mat_diag = torch.diagonal(conf_mat, dim1=-2, dim2=-1)
    denominator = sum_over_row + sum_over_col - conf_mat_diag

    # NOTE: we add epsilon so that samples that are neither in the
    # prediction or ground truth are taken into account.
    ious = (conf_mat_diag + eps) / (denominator + eps)
    return ious
def main():
    """Train a segmentation model with periodic validation and checkpoints.

    Steps performed:

    1. Seed the RNGs, read the global options and check that the number of
       visible GPUs matches ``opt.num_gpu``.
    2. Restore a pretrained setup (``opt.load_pretrain`` non-empty) or create
       a new model / loss / optimizer / scheduler.
    3. Build the training and validation datasets and loaders; the
       ``'deepgcn'`` model uses the PyTorch-Geometric flavoured dataset, every
       other model uses ``BigredDataSet``. Class weights are loaded into the
       loss in both cases.
    4. Run the epoch loop from ``opt.epoch_ckpt`` to ``opt.epoch_max``,
       logging metrics to Weights & Biases once ``epoch >= opt.unsave_epoch``.
    5. On every epoch with ``epoch % 10 == 1`` run validation, save a
       checkpoint into ``opt.save_root`` and refresh ``best_model.pth`` when
       the averaged validation mIoU improves.

    The function takes no arguments and returns nothing; all configuration
    comes from ``opt_global_inti()``.
    """
    setSeed(10)
    opt = opt_global_inti()

    num_gpu = torch.cuda.device_count()
    assert num_gpu == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"

    opt.gpu_list = [torch.cuda.get_device_name(i) for i in range(num_gpu)]

    if opt.load_pretrain != '':
        opt, model, f_loss, optimizer, scheduler, opt_deepgcn = load_pretrained(opt)
    else:
        opt, model, f_loss, optimizer, scheduler, opt_deepgcn = creating_new_model(opt)

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    if opt.model != 'deepgcn':
        train_dataset = BigredDataSet(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring,
        )

        f_loss.load_weight(train_dataset.labelweights)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))

        validation_dataset = BigredDataSet(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))
    else:
        train_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            new_dataset=False,
            test_code=opt.debug,
            pre_transform=torch_geometric.transforms.NormalizeScale(),
        )
        train_loader = DenseDataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=opt.num_workers)
        validation_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            new_dataset=False,
            test_code=opt.debug,
            num_channel=opt.num_channel,
            pre_transform=torch_geometric.transforms.NormalizeScale(),
        )
        validation_loader = DenseDataLoader(validation_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=False,
                                            num_workers=opt.num_workers)

        # Class weights: label frequencies over bins [0, 1) and [1, 2],
        # normalised, then cube root of (most frequent / frequency) so the
        # rarer class gets the larger weight.
        labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
        labelweights = labelweights.astype(np.float32)
        labelweights = labelweights / np.sum(labelweights)
        labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
        weights = torch.Tensor(labelweights).cuda()
        f_loss.load_weight(weights)

    # len(loader) is already the number of batches, so it must not be
    # divided by the batch size again.
    print('train dataset num_frame: ', len(train_dataset))
    print('num_batch: ', len(train_loader))

    print('validation dataset num_frame: ', len(validation_dataset))
    print('num_batch: ', len(validation_loader))

    print('Batch_size: ', opt.batch_size)

    print('----------------------Prepareing Training----------------------')
    metrics_list = ['Miou', 'Biou', 'Fiou', 'loss', 'OA',
                    'time_complexicity', 'storage_complexicity']
    manager_test = metrics_manager(metrics_list)

    metrics_list_train = ['Miou', 'Biou',
                          'Fiou', 'loss',
                          'storage_complexicity',
                          'time_complexicity']
    manager_train = metrics_manager(metrics_list_train)

    wandb.init(project=opt.wd_project, name=opt.model_name, resume=False)
    if(opt.wandb_history == False):
        best_value = 0
    else:
        # Resume the best score from the checkpoint stored in a previous run.
        temp = wandb.restore('best_model.pth', run_path=opt.wandb_id)
        best_value = torch.load(temp.name)['Miou_validation_ave']

    wandb.config.update(opt)

    if opt.epoch_ckpt == 0:
        opt.unsave_epoch = 0
    else:
        # Resuming: continue with the epoch after the checkpointed one.
        # NOTE(review): opt.unsave_epoch is expected to be provided by the
        # options / checkpoint in this case - confirm.
        opt.epoch_ckpt = opt.epoch_ckpt + 1

    for epoch in range(opt.epoch_ckpt, opt.epoch_max):
        manager_train.reset()
        model.train()
        tic_epoch = time.perf_counter()
        print('---------------------Training----------------------')
        print("Epoch: ", epoch)
        for i, data in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):

            if opt.model == 'deepgcn':
                # Stack positions and features along the channel axis and
                # keep the first opt.num_channel channels.
                # NOTE(review): points is not moved to the GPU here, unlike
                # the other branch - confirm the model handles placement.
                points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                    data.x.transpose(2, 1).unsqueeze(3)), 1)
                points = points[:, :opt.num_channel, :, :]
                target = data.y.cuda()
            else:
                # Per the original notes: points is [B,N,C], target is [B,N].
                points, target = data
                points = points.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            optimizer.zero_grad()
            # Only the forward pass is timed.
            # NOTE(review): no torch.cuda.synchronize() around the timer, so
            # asynchronous GPU work may not be fully included.
            tic = time.perf_counter()
            pred_mics = model(points)
            toc = time.perf_counter()

            loss = f_loss(pred_mics, target)

            if opt.apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            # pred_mics[0] holds the per-point class scores ([B,N,2] per the
            # original notes); take the index of the best class per point.
            pred, target = pred_mics[0].cpu(), target.cpu()
            pred = pred.data.max(dim=2)[1]

            # Per-class IoU averaged over the batch: background, foreground.
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # Logged as iterations per second (inverse of forward time).
            time_complexity = toc - tic
            throughput = float(1 / time_complexity)

            # Mean of the currently allocated memory over all GPUs.
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
            allocated = [torch.cuda.memory_allocated(k) for k in range(num_device)]
            RAM_usagePeak = torch.tensor(allocated).float().mean()

            manager_train.update('loss', loss.item())
            manager_train.update('Biou', Biou.item())
            manager_train.update('Fiou', Fiou.item())
            manager_train.update('Miou', miou.item())
            manager_train.update('time_complexicity', throughput)
            manager_train.update('storage_complexicity', RAM_usagePeak.item())

            log_dict = {'loss_online': loss.item(),
                        'Biou_online': Biou.item(),
                        'Fiou_online': Fiou.item(),
                        'Miou_online': miou.item(),
                        'time_complexicity_online': throughput,
                        'storage_complexicity_online': RAM_usagePeak.item()
                        }
            if epoch >= opt.unsave_epoch:
                wandb.log(log_dict)

        toc_epoch = time.perf_counter()
        epoch_seconds = toc_epoch - tic_epoch

        train_summary = manager_train.summary()
        log_train_end = {}
        for key in train_summary:
            log_train_end[key + '_train_ave'] = train_summary[key]
            print(key + '_train_ave: ', train_summary[key])

        log_train_end['Time_PerEpoch'] = epoch_seconds
        if epoch >= opt.unsave_epoch:
            wandb.log(log_train_end)
        else:
            print('No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'%(opt.unsave_epoch,epoch))

        scheduler.step()
        if epoch % 10 == 1:
            print('---------------------Validation----------------------')
            manager_test.reset()
            model.eval()
            print("Epoch: ", epoch)
            with torch.no_grad():
                for j, data in tqdm(enumerate(validation_loader), total=len(validation_loader), smoothing=0.9):

                    if opt.model == 'deepgcn':
                        points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                            data.x.transpose(2, 1).unsqueeze(3)), 1)
                        points = points[:, :opt.num_channel, :, :]
                        target = data.y.cuda()
                    else:
                        points, target = data
                        points = points.cuda(non_blocking=True)
                        target = target.cuda(non_blocking=True)

                    tic = time.perf_counter()
                    pred_mics = model(points)
                    toc = time.perf_counter()

                    pred, target = pred_mics[0].cpu(), target.cpu()

                    # The validation loss is not computed; a constant 0 is
                    # recorded so the metric key stays populated.
                    test_loss = 0

                    pred = pred.data.max(dim=2)[1]

                    # Overall accuracy from the batch-summed confusion
                    # matrix: trace / total.
                    cm = confusion_matrix(pred, target, num_classes=2).sum(dim=0)
                    overall_correct_site = torch.diag(cm).sum()
                    overall_reference_site = cm.sum()
                    oa = float(overall_correct_site / overall_reference_site)

                    Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
                    miou = (Biou + Fiou) / 2

                    time_complexity = toc - tic

                    num_device = torch.cuda.device_count()
                    assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
                    allocated = [torch.cuda.memory_allocated(k) for k in range(num_device)]
                    RAM_usagePeak = torch.tensor(allocated).float().mean()

                    manager_test.update('loss', test_loss)
                    manager_test.update('OA', oa)
                    manager_test.update('Biou', Biou.item())
                    manager_test.update('Fiou', Fiou.item())
                    manager_test.update('Miou', miou.item())
                    manager_test.update('time_complexicity', float(1 / time_complexity))
                    manager_test.update('storage_complexicity', RAM_usagePeak.item())

            val_summary = manager_test.summary()

            log_val_end = {}
            for key in val_summary:
                log_val_end[key + '_validation_ave'] = val_summary[key]
                print(key + '_validation_ave: ', val_summary[key])

            # Checkpoint content: weights, the scheduler and optimizer
            # objects themselves, the epoch, every option and the validation
            # averages. Later updates overwrite earlier keys of equal name.
            package = dict()
            package['state_dict'] = model.state_dict()
            package['scheduler'] = scheduler
            package['optimizer'] = optimizer
            package['epoch'] = epoch

            package.update(vars(opt))
            if opt_deepgcn is not None:
                package.update({k + '_opt2': v for k, v in vars(opt_deepgcn).items()})

            package.update(log_val_end)

            save_root = opt.save_root+'/val_miou%.4f_Epoch%s.pth'%(package['Miou_validation_ave'],package['epoch'])
            torch.save(package, save_root)

            is_best = package['Miou_validation_ave'] > best_value
            print('Is Best?: ', is_best)
            if is_best:
                best_value = package['Miou_validation_ave']
                save_root = opt.save_root + '/best_model.pth'
                torch.save(package, save_root)
            if epoch >= opt.unsave_epoch:
                wandb.log(log_val_end)
            else:
                print('No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'%(opt.unsave_epoch,epoch))
            if(opt.debug == True):
                # Drop into the debugger after each validation in debug mode.
                pdb.set_trace()
Ejemplo n.º 3
0
def main():
    """Evaluate the best checkpoint on the test split and export the results.

    Steps performed:

    1. Load ``saves/best_model.pth`` below ``opt.load_pretrain``, rebuild the
       model from the ``model`` module that lives next to the checkpoint and
       restore its weights.
    2. Ask interactively for a model name, store it in the checkpoint
       (migrating the legacy ``Validation_ave_miou`` key when present) and
       write the checkpoint back.
    3. Run the test set one frame at a time, collecting overall accuracy,
       IoU, inference time (milliseconds) and GPU memory, plus the mIoU per
       file / difficulty / single-ness tag.
    4. Save the point sets, labels and predictions as ``.npy`` files under
       ``visulization_data/<opt.model>``, generate the report and log the
       summary to Weights & Biases.

    The function takes no arguments and returns nothing.
    """
    setSeed(10)
    opt = opt_global_inti()
    print('----------------------Load ckpt----------------------')
    pretrained_model_path = os.path.join(opt.load_pretrain,
                                         'saves/best_model.pth')
    package = torch.load(pretrained_model_path)
    para_state_dict = package['state_dict']
    opt.num_channel = package['num_channel']
    opt.time = package['time']
    opt.epoch_ckpt = package['epoch']
    try:
        state_dict = convert_state_dict(para_state_dict)
    except Exception:
        # Fallback for checkpoints that stored the whole model object
        # instead of its state dict. Exception (rather than a bare except)
        # keeps KeyboardInterrupt / SystemExit propagating.
        para_state_dict = para_state_dict.state_dict()
        state_dict = convert_state_dict(para_state_dict)

    # opt.load_pretrain must be of the form "<package>/<run>"; the model
    # definition is imported from "<package>.<run>.model".
    ckpt_, ckpt_file_name = opt.load_pretrain.split("/")
    module_name = ckpt_ + '.' + ckpt_file_name + '.' + 'model'
    MODEL = importlib.import_module(module_name)

    print(opt.model)
    if opt.model == 'deepgcn':
        opt_deepgcn = OptInit_deepgcn().initialize()
        model = MODEL.get_model(opt2=opt_deepgcn,
                                input_channel=opt.num_channel)
    else:
        model = MODEL.get_model(input_channel=opt.num_channel,
                                is_synchoization='Instance')
    Model_Specification = MODEL.get_model_name(input_channel=opt.num_channel)
    print('----------------------Test Model----------------------')
    print('Root of prestrain model: ', pretrained_model_path)
    print('Model: ', opt.model)
    print('Pretrained model name: ', Model_Specification)
    print('Trained Date: ', opt.time)
    print('num_channel: ', opt.num_channel)
    name = input("Edit the name or press ENTER to skip: ")
    opt.model_name = name if name != '' else Model_Specification
    print('Pretrained model name: ', opt.model_name)
    package['name'] = opt.model_name
    # Rename the legacy metric key if this is an old checkpoint.
    if "Validation_ave_miou" in package:
        package["Miou_validation_ave"] = package.pop("Validation_ave_miou")

    save_model(package, pretrained_model_path)

    model.load_state_dict(state_dict)
    model.cuda()

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    # Frames are evaluated one by one; the same value is used for the
    # loaders, the printed info and the sanity check below.
    test_batch_size = 1

    print('opt.model', opt.model)
    print(opt.model == 'deepgcn')
    if opt.model != 'deepgcn':
        test_dataset = BigredDataSet(root=opt.dataset_root,
                                     is_train=False,
                                     is_validation=False,
                                     is_test=True,
                                     num_channel=opt.num_channel,
                                     test_code=opt.debug,
                                     including_ring=opt.including_ring,
                                     file_name=opt.file_name)
        result_sheet = test_dataset.result_sheet
        file_dict = test_dataset.file_dict
        tag_Getter = tag_getter(file_dict)
        testloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=test_batch_size,
            shuffle=False,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))
    else:
        test_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=False,
            is_validation=False,
            is_test=True,
            num_channel=opt.num_channel,
            new_dataset=True,
            test_code=opt.debug,
            pre_transform=torch_geometric.transforms.NormalizeScale(),
            file_name=opt.file_name)
        result_sheet = test_dataset.result_sheet
        file_dict = test_dataset.file_dict
        print(file_dict)
        tag_Getter = tag_getter(file_dict)

        testloader = DenseDataLoader(test_dataset,
                                     batch_size=test_batch_size,
                                     shuffle=False,
                                     num_workers=opt.num_workers)

    # Report the batch size actually used; len(loader) is already the
    # number of batches.
    print('num_frame: ', len(test_dataset))
    print('batch_size: ', test_batch_size)
    print('num_batch: ', len(testloader))

    print('----------------------Testing----------------------')
    metrics_list = [
        'Miou', 'Biou', 'Fiou', 'test_loss', 'OA', 'time_complexicity',
        'storage_complexicity'
    ]
    print(result_sheet)
    # One extra metric per entry of the dataset's result sheet.
    metrics_list.extend(result_sheet)
    print(metrics_list)
    manager = metrics_manager(metrics_list)

    model.eval()
    wandb.init(project="Test", name=package['name'])
    wandb.config.update(opt)

    prediction_set = []

    with torch.no_grad():
        for j, data in tqdm(enumerate(testloader),
                            total=len(testloader),
                            smoothing=0.9):
            if opt.model == 'deepgcn':
                # Stack positions and features along the channel axis and
                # keep the first opt.num_channel channels.
                points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                    data.x.transpose(2, 1).unsqueeze(3)), 1)
                points = points[:, :opt.num_channel, :, :].cuda()
                target = data.y.cuda()
            else:
                # Per the original notes: points is [B,N,C], target is [B,N].
                points, target = data
                points, target = points.cuda(), target.cuda()

            # Synchronise before and after the forward pass so the wall
            # clock covers the full GPU execution; time is in milliseconds.
            torch.cuda.synchronize()
            since = int(round(time.time() * 1000))
            pred_mics = model(points)
            torch.cuda.synchronize()
            time_complexity = int(round(time.time() * 1000)) - since

            # The test loss is not computed; a constant 0 is recorded.
            test_loss = 0

            # pred_mics[0] holds the per-point class scores ([B,N,2] per the
            # original notes); take the index of the best class per point.
            pred, target, points = (pred_mics[0].cpu(), target.cpu(),
                                    points.cpu())
            pred = pred.data.max(dim=2)[1]
            prediction_set.append(pred)

            # Overall accuracy from the batch-summed confusion matrix.
            cm = confusion_matrix(pred, target, num_classes=2).sum(dim=0)
            overall_correct_site = torch.diag(cm).sum()
            overall_reference_site = cm.sum()
            # Every point must be counted exactly once. The loader yields
            # test_batch_size frames per step, not opt.batch_size.
            assert overall_reference_site == test_batch_size * opt.num_points, "Confusion_matrix computing error"
            oa = float(overall_correct_site / overall_reference_site)

            # Per-class IoU averaged over the batch: background, foreground.
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # Mean of the currently allocated memory over all GPUs.
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
            allocated = [torch.cuda.memory_allocated(k) for k in range(num_device)]
            RAM_usagePeak = torch.tensor(allocated).float().mean()

            manager.update('test_loss', test_loss)
            manager.update('OA', oa)
            manager.update('Biou', Biou.item())
            manager.update('Fiou', Fiou.item())
            manager.update('Miou', miou.item())
            manager.update('time_complexicity', time_complexity)
            manager.update('storage_complexicity', RAM_usagePeak.item())
            # Attribute this frame's mIoU to its file / difficulty / single
            # tags (the loader index equals the frame index at batch size 1).
            difficulty, location, isSingle, file_name = tag_Getter.get_difficulty_location_isSingle(
                j)
            manager.update(file_name, miou.item())
            manager.update(difficulty, miou.item())
            manager.update(isSingle, miou.item())

    prediction_set = np.concatenate(prediction_set, axis=0)
    point_set, label_set, ermgering_set = test_dataset.getVis(prediction_set)

    # parents=True: the top-level 'visulization_data' directory may not
    # exist yet on a fresh checkout.
    experiment_dir = Path('visulization_data/' + opt.model)
    experiment_dir.mkdir(parents=True, exist_ok=True)

    root = 'visulization_data/' + opt.model

    exports = (
        ('/point_set.npy', point_set),
        ('/label_set.npy', label_set),
        ('/ermgering_set.npy', ermgering_set),
        ('/prediction_set.npy', prediction_set),
    )
    for suffix, array in exports:
        with open(root + suffix, 'wb') as f:
            np.save(f, array)

    summery_dict = manager.summary()
    generate_report(summery_dict, package)
    wandb.log(summery_dict)
Ejemplo n.º 4
0
def main():
    """Evaluate the best checkpoint and dump per-point results to HDF5.

    Steps performed:

    1. Load ``best_model.pth`` from ``opt.load_pretrain``, rebuild the model
       from the ``model`` module that lives next to the checkpoint and
       restore its weights.
    2. Ask interactively for a model name, store it in the checkpoint and
       write the checkpoint back.
    3. Run the test set, collecting overall accuracy, IoU, inference speed
       and GPU memory, plus the mIoU per file / difficulty / single-ness tag.
    4. For every batch, build three point tables (xyz + ground truth,
       xyz + prediction, xyz + importance counts) and write them to
       ``resluts.h5``; then generate the report and log the summary to
       Weights & Biases.

    The function takes no arguments and returns nothing.
    """

    def _index_counts(index_tensor, length):
        """Return a (length, 1) array with how often each index occurs.

        ``index_tensor`` is flattened; position ``i`` of the result holds the
        number of times the value ``i`` appears, zero when it never does.
        """
        flat = index_tensor.view(-1, )
        values, _, counts = torch.unique(flat,
                                         sorted=True,
                                         return_inverse=True,
                                         return_counts=True,
                                         dim=None)
        histogram = np.zeros(length)
        histogram[values] = counts
        return histogram.reshape(-1, 1)

    setSeed(10)
    opt = opt_global_inti()
    print('----------------------Load ckpt----------------------')
    pretrained_model_path = os.path.join(opt.load_pretrain, 'best_model.pth')
    package = torch.load(pretrained_model_path)
    para_state_dict = package['state_dict']
    opt.num_channel = package['num_channel']
    opt.time = package['time']
    opt.epoch_ckpt = package['epoch']

    state_dict = convert_state_dict(para_state_dict)

    # opt.load_pretrain must be of the form "<package>/<run>"; the model
    # definition is imported from "<package>.<run>.model".
    ckpt_, ckpt_file_name = opt.load_pretrain.split("/")
    module_name = ckpt_ + '.' + ckpt_file_name + '.' + 'model'
    MODEL = importlib.import_module(module_name)
    model = MODEL.get_model(input_channel=opt.num_channel)
    Model_Specification = MODEL.get_model_name(input_channel=opt.num_channel)
    print('----------------------Test Model----------------------')
    print('Root of prestrain model: ', pretrained_model_path)
    print('Model: ', opt.model)
    print('Pretrained model name: ', Model_Specification)
    print('Trained Date: ', opt.time)
    print('num_channel: ', opt.num_channel)
    name = input("Edit the name or press ENTER to skip: ")
    opt.model_name = name if name != '' else Model_Specification
    print('Pretrained model name: ', opt.model_name)
    package['name'] = opt.model_name
    save_model(package, pretrained_model_path)

    model.load_state_dict(state_dict)
    model.cuda()

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    test_dataset = BigredDataSet(root=opt.dataset_root,
                                 is_train=False,
                                 is_validation=False,
                                 is_test=True,
                                 num_channel=opt.num_channel,
                                 test_code=opt.debug,
                                 including_ring=opt.including_ring)
    result_sheet = test_dataset.result_sheet
    file_dict = test_dataset.file_dict
    tag_Getter = tag_getter(file_dict)

    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=True,
                                             num_workers=int(opt.num_workers))

    # len(loader) is already the number of batches, so it must not be
    # divided by the batch size again.
    print('num_frame: ', len(test_dataset))
    print('batch_size: ', opt.batch_size)
    print('num_batch: ', len(testloader))

    print('----------------------Testing----------------------')
    metrics_list = [
        'Miou', 'Biou', 'Fiou', 'test_loss', 'OA', 'time_complexicity',
        'storage_complexicity'
    ]
    # One extra metric per entry of the dataset's result sheet.
    metrics_list.extend(result_sheet)

    manager = metrics_manager(metrics_list)

    model.eval()
    wandb.init(project="Test", name=package['name'])
    wandb.config.update(opt)

    points_gt_list = []
    points_pd_list = []
    points_important_list = []

    with torch.no_grad():
        for j, data in tqdm(enumerate(testloader),
                            total=len(testloader),
                            smoothing=0.9):
            # Per the original notes: points is [B,N,C], target is [B,N].
            points, target = data
            points, target = points.cuda(), target.cuda()
            # NOTE(review): no torch.cuda.synchronize() around the timer, so
            # asynchronous GPU work may not be fully included.
            tic = time.perf_counter()
            pred_mics = model(points)
            toc = time.perf_counter()

            # The test loss is not computed; a constant 0 is recorded.
            test_loss = 0

            # pred_mics[0] holds the per-point class scores ([B,N,2] per the
            # original notes); pred_mics[2] is used as a tensor of point
            # indices for the importance histogram below.
            pred, target, points = (pred_mics[0].cpu(), target.cpu(),
                                    points.cpu())
            imp_glob = pred_mics[2].cpu()

            pred = pred.data.max(dim=2)[1]

            # Overall accuracy from the batch-summed confusion matrix.
            cm = confusion_matrix(pred, target, num_classes=2).sum(dim=0)
            overall_correct_site = torch.diag(cm).sum()
            overall_reference_site = cm.sum()
            assert overall_reference_site == opt.batch_size * opt.num_points, "Confusion_matrix computing error"
            oa = float(overall_correct_site / overall_reference_site)

            # Per-class IoU averaged over the batch: background, foreground.
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # Logged as batches per second (inverse of forward time).
            time_complexity = toc - tic

            # Mean of the currently allocated memory over all GPUs.
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
            allocated = [torch.cuda.memory_allocated(k) for k in range(num_device)]
            RAM_usagePeak = torch.tensor(allocated).float().mean()

            manager.update('test_loss', test_loss)
            manager.update('OA', oa)
            manager.update('Biou', Biou.item())
            manager.update('Fiou', Fiou.item())
            manager.update('Miou', miou.item())
            manager.update('time_complexicity', float(1 / time_complexity))
            manager.update('storage_complexicity', RAM_usagePeak.item())
            # Attribute this batch's mIoU to its file / difficulty / single
            # tags.
            # NOTE(review): j is the batch index; confirm the tag getter
            # expects that when opt.batch_size > 1.
            difficulty, location, isSingle, file_name = tag_Getter.get_difficulty_location_isSingle(
                j)
            manager.update(file_name, miou.item())
            manager.update(difficulty, miou.item())
            manager.update(isSingle, miou.item())

            # Flatten the batch into per-point rows.
            dim_num = points.shape[2]
            points = points.view(-1, dim_num).numpy()
            pred = pred.view(-1, 1).numpy()
            target = target.view(-1, 1).numpy()
            num_rows = len(target)
            xyz = points[:, [0, 1, 2]]

            points_gt = np.concatenate((xyz, target), axis=1)
            points_pd = np.concatenate((xyz, pred), axis=1)
            points_important = np.concatenate(
                (xyz, _index_counts(imp_glob, num_rows)), axis=1)

            if opt.including_ring:
                # pred_mics[3] provides a second index tensor; its counts
                # are appended as an additional column.
                imp_ring = pred_mics[3].cpu()
                points_important = np.concatenate(
                    (points_important, _index_counts(imp_ring, num_rows)),
                    axis=1)

            points_gt_list.append(points_gt)
            points_pd_list.append(points_pd)
            points_important_list.append(points_important)

    # The context manager closes the file even if a dataset write fails.
    with h5py.File('resluts.h5', 'w') as f:
        f.create_dataset('points_gt_list', data=np.array(points_gt_list))
        f.create_dataset('points_pd_list', data=np.array(points_pd_list))
        f.create_dataset('points_important_list',
                         data=np.array(points_important_list))

    summery_dict = manager.summary()
    generate_report(summery_dict, package)
    wandb.log(summery_dict)
Ejemplo n.º 5
0
def _unpack_batch(data, opt):
    """Split one loader batch into ``(points, target)`` for the current model.

    For ``opt.model == 'deepgcn'`` the batch is expected to expose ``pos``,
    ``x`` and ``y`` attributes; positions and features are concatenated along
    dim 1 and truncated to ``opt.num_channel`` channels. Only ``target`` is
    moved to the GPU in that branch (unchanged from the original code).
    For every other model the batch is a ``(points, target)`` pair and both
    tensors are moved to the GPU.
    """
    if opt.model == 'deepgcn':
        points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                            data.x.transpose(2, 1).unsqueeze(3)), 1)
        points = points[:, :opt.num_channel, :, :]
        target = data.y.cuda()
    else:
        points, target = data
        points = points.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
    return points, target


def _mean_allocated_memory(expected_gpus):
    """Return ``torch.cuda.memory_allocated`` averaged over all visible GPUs.

    Raises:
        AssertionError: if the number of visible devices differs from
            ``expected_gpus``.
    """
    num_device = torch.cuda.device_count()
    assert num_device == expected_gpus, "opt.num_gpu NOT equals torch.cuda.device_count()"
    allocated = [torch.cuda.memory_allocated(k) for k in range(num_device)]
    return torch.tensor(allocated).float().mean().item()


def main():
    """Fine-tune a pretrained checkpoint and log the metrics to wandb.

    Steps performed:
      1. Read ``best_model.pth`` from ``opt.load_pretrain`` and copy
         ``num_channel``, ``time`` and ``epoch`` from it onto ``opt``.
      2. Import the ``model`` module that sits next to the checkpoint and
         build the model and its loss from it.
      3. Optionally rename the checkpoint (interactive prompt) and save it
         back to disk.
      4. Build optimizer, scheduler, train and validation loaders.
      5. Train from ``opt.epoch_ckpt`` to ``opt.epoch_max``; whenever
         ``epoch % 10 == 1`` run validation, store a checkpoint under
         ``ckpt/<model_name>/saves`` and update ``best_model.pth`` when the
         validation mIoU improves.
    """
    setSeed(10)
    opt = opt_global_inti()
    print('----------------------Load ckpt----------------------')
    pretrained_model_path = os.path.join(opt.load_pretrain, 'best_model.pth')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    package = torch.load(pretrained_model_path)
    para_state_dict = package['state_dict']
    opt.num_channel = package['num_channel']
    opt.time = package['time']
    opt.epoch_ckpt = package['epoch']
    # NOTE(review): the converted weights are never handed to the model inside
    # this function -- confirm that MODEL.get_model()/ini_ft() restore them,
    # otherwise a model.load_state_dict(state_dict) call is missing.
    state_dict = convert_state_dict(para_state_dict)

    # opt.load_pretrain must be exactly "<package>/<sub_package>": the two
    # parts are turned into the dotted module path "<package>.<sub_package>.model".
    ckpt_, ckpt_file_name = opt.load_pretrain.split("/")
    module_name = ckpt_ + '.' + ckpt_file_name + '.' + 'model'
    MODEL = importlib.import_module(module_name)
    # None means "no deepgcn options"; checked again when saving checkpoints.
    opt_deepgcn = None
    print(opt.model)
    if opt.model == 'deepgcn':
        opt_deepgcn = OptInit_deepgcn().initialize()
        model = MODEL.get_model(opt2=opt_deepgcn,
                                input_channel=opt.num_channel)
    else:
        model = MODEL.get_model(input_channel=opt.num_channel)
    Model_Specification = MODEL.get_model_name(input_channel=opt.num_channel)
    f_loss = MODEL.get_loss(input_channel=opt.num_channel)

    print('----------------------Test Model----------------------')
    print('Root of prestrain model: ', pretrained_model_path)
    print('Model: ', opt.model)
    print('Pretrained model name: ', Model_Specification)
    print('Trained Date: ', opt.time)
    print('num_channel: ', opt.num_channel)
    name = input("Edit the name or press ENTER to skip: ")
    opt.model_name = name if name != '' else Model_Specification
    print('Pretrained model name: ', opt.model_name)
    package['name'] = opt.model_name
    save_model(package, pretrained_model_path)

    print(
        '----------------------Configure optimizer and scheduler----------------------'
    )
    experiment_dir = Path('ckpt/').joinpath(opt.model_name, 'saves')
    experiment_dir.mkdir(parents=True, exist_ok=True)
    opt.save_root = str(experiment_dir)

    model.ini_ft()
    model.frozen_ft()

    # The optimizer is created exactly once per branch. Previously a second
    # Adam instance was created after this block and replaced the one patched
    # by amp.initialize(), so amp.scale_loss() received an optimizer Amp had
    # never seen.
    use_apex = bool(opt.apex)
    if use_apex:
        model.cuda()
        f_loss.cuda()
        optimizer = optim.Adam(model.parameters(),
                               lr=0.001,
                               betas=(0.9, 0.999))
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
        f_loss.cuda()
        optimizer = optim.Adam(model.parameters(),
                               lr=0.001,
                               betas=(0.9, 0.999))
    # Built after the (possibly Amp-patched) optimizer so both stay in sync.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    if opt.model != 'deepgcn':
        train_dataset = BigredDataSet_finetune(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring)

        f_loss.load_weight(train_dataset.labelweights)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))

        validation_dataset = BigredDataSet_finetune(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))
    else:
        train_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            new_dataset=False,
            test_code=opt.debug,
            pre_transform=torch_geometric.transforms.NormalizeScale())
        train_loader = DenseDataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=opt.num_workers)
        validation_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            new_dataset=False,
            test_code=opt.debug,
            num_channel=opt.num_channel,
            pre_transform=torch_geometric.transforms.NormalizeScale())
        validation_loader = DenseDataLoader(validation_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=False,
                                            num_workers=opt.num_workers)

        # Class weights: cube root of (largest class frequency / frequency),
        # computed from a two-bin histogram of the training labels.
        labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
        labelweights = labelweights.astype(np.float32)
        labelweights = labelweights / np.sum(labelweights)
        labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
        weights = torch.Tensor(labelweights).cuda()
        f_loss.load_weight(weights)

    # len(loader) already is the number of batches; it used to be divided by
    # the batch size a second time, which under-reported the count.
    print('train dataset num_frame: ', len(train_dataset))
    print('num_batch: ', len(train_loader))

    print('validation dataset num_frame: ', len(validation_dataset))
    print('num_batch: ', len(validation_loader))

    print('Batch_size: ', opt.batch_size)

    print('----------------------Prepareing Training----------------------')
    metrics_list = [
        'Miou', 'Biou', 'Fiou', 'loss', 'OA', 'time_complexicity',
        'storage_complexicity'
    ]
    manager_test = metrics_manager(metrics_list)

    metrics_list_train = [
        'Miou', 'Biou', 'Fiou', 'loss', 'storage_complexicity',
        'time_complexicity'
    ]
    manager_train = metrics_manager(metrics_list_train)

    wandb.init(project=opt.wd_project, name=opt.model_name, resume=False)
    if opt.wandb_history == False:
        best_value = 0
    else:
        temp = wandb.restore('best_model.pth', run_path=opt.wandb_id)
        best_value = torch.load(temp.name)['Miou_validation_ave']

    # NOTE(review): this unconditional reset discards the value restored from
    # wandb above. Kept as-is (fine-tuning starts its "best" tracking from
    # scratch) -- confirm that this is intended.
    best_value = 0
    wandb.config.update(opt)

    if opt.epoch_ckpt == 0:
        opt.unsave_epoch = 0
    else:
        # Resume with the epoch following the one stored in the checkpoint.
        opt.epoch_ckpt = opt.epoch_ckpt + 1

    for epoch in range(opt.epoch_ckpt, opt.epoch_max):
        manager_train.reset()
        model.train()
        tic_epoch = time.perf_counter()
        print('---------------------Training----------------------')
        print("Epoch: ", epoch)
        for data in tqdm(train_loader,
                         total=len(train_loader),
                         smoothing=0.9):
            points, target = _unpack_batch(data, opt)

            optimizer.zero_grad()
            # Only the forward pass is timed.
            tic = time.perf_counter()
            pred_mics = model(points)
            toc = time.perf_counter()

            loss = f_loss(pred_mics, target)

            if use_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            # The first model output holds the per-point class scores; the
            # argmax over the last dim gives the predicted label per point.
            pred, target = pred_mics[0].cpu(), target.cpu()
            pred = pred.data.max(dim=2)[1]

            # Per-class IoU averaged over the batch, then over the two classes.
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # Logged as forward passes per second (inverse of the latency).
            forward_rate = float(1 / (toc - tic))
            memory_usage = _mean_allocated_memory(opt.num_gpu)

            loss_value = loss.item()
            biou_value = Biou.item()
            fiou_value = Fiou.item()
            miou_value = miou.item()

            manager_train.update('loss', loss_value)
            manager_train.update('Biou', biou_value)
            manager_train.update('Fiou', fiou_value)
            manager_train.update('Miou', miou_value)
            manager_train.update('time_complexicity', forward_rate)
            manager_train.update('storage_complexicity', memory_usage)

            log_dict = {
                'loss_online': loss_value,
                'Biou_online': biou_value,
                'Fiou_online': fiou_value,
                'Miou_online': miou_value,
                'time_complexicity_online': forward_rate,
                'storage_complexicity_online': memory_usage
            }
            if epoch - opt.unsave_epoch >= 0:
                wandb.log(log_dict)

        toc_epoch = time.perf_counter()
        time_tensor = toc_epoch - tic_epoch

        summery_dict = manager_train.summary()
        log_train_end = {}
        for key in summery_dict:
            log_train_end[key + '_train_ave'] = summery_dict[key]
            print(key + '_train_ave: ', summery_dict[key])

        log_train_end['Time_PerEpoch'] = time_tensor
        if epoch - opt.unsave_epoch >= 0:
            wandb.log(log_train_end)
        else:
            print(
                'No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'
                % (opt.unsave_epoch, epoch))

        scheduler.step()
        if epoch % 10 == 1:
            print('---------------------Validation----------------------')
            manager_test.reset()
            model.eval()
            print("Epoch: ", epoch)
            with torch.no_grad():
                for data in tqdm(validation_loader,
                                 total=len(validation_loader),
                                 smoothing=0.9):
                    points, target = _unpack_batch(data, opt)

                    tic = time.perf_counter()
                    pred_mics = model(points)
                    toc = time.perf_counter()

                    pred, target = pred_mics[0].cpu(), target.cpu()

                    # No loss is computed during validation; 0 is logged.
                    test_loss = 0

                    pred = pred.data.max(dim=2)[1]
                    cm = confusion_matrix(pred, target,
                                          num_classes=2).sum(dim=0)
                    # Overall accuracy = trace / total. Both counts are turned
                    # into Python floats first so the ratio never depends on
                    # how the installed PyTorch divides integer tensors.
                    overall_correct_site = torch.diag(cm).sum()
                    overall_reference_site = cm.sum()
                    oa = (float(overall_correct_site) /
                          float(overall_reference_site))

                    Biou, Fiou = mean_iou(pred, target,
                                          num_classes=2).mean(dim=0)
                    miou = (Biou + Fiou) / 2

                    forward_rate = float(1 / (toc - tic))
                    memory_usage = _mean_allocated_memory(opt.num_gpu)

                    manager_test.update('loss', test_loss)
                    manager_test.update('OA', oa)
                    manager_test.update('Biou', Biou.item())
                    manager_test.update('Fiou', Fiou.item())
                    manager_test.update('Miou', miou.item())
                    manager_test.update('time_complexicity', forward_rate)
                    manager_test.update('storage_complexicity', memory_usage)

            summery_dict = manager_test.summary()

            log_val_end = {}
            for key in summery_dict:
                log_val_end[key + '_validation_ave'] = summery_dict[key]
                print(key + '_validation_ave: ', summery_dict[key])

            # The checkpoint bundles weights, training state, every option and
            # the validation averages in one flat dict.
            package = dict()
            package['state_dict'] = model.state_dict()
            package['scheduler'] = scheduler
            package['optimizer'] = optimizer
            package['epoch'] = epoch

            opt_temp = vars(opt)
            for k in opt_temp:
                package[k] = opt_temp[k]
            # The test used to be "is None", which could never succeed (and
            # would have raised on vars(None)); the deepgcn options are now
            # actually stored when they exist.
            if opt_deepgcn is not None:
                opt_temp = vars(opt_deepgcn)
                for k in opt_temp:
                    package[k + '_opt2'] = opt_temp[k]

            for k in log_val_end:
                package[k] = log_val_end[k]

            save_root = opt.save_root + '/val_miou%.4f_Epoch%s.pth' % (
                package['Miou_validation_ave'], package['epoch'])
            torch.save(package, save_root)

            is_best = package['Miou_validation_ave'] > best_value
            print('Is Best?: ', is_best)
            if is_best:
                best_value = package['Miou_validation_ave']
                save_root = opt.save_root + '/best_model.pth'
                torch.save(package, save_root)
            if epoch - opt.unsave_epoch >= 0:
                wandb.log(log_val_end)
            else:
                print(
                    'No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'
                    % (opt.unsave_epoch, epoch))
            if opt.debug == True:
                pdb.set_trace()