Example #1
def inference_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TEST),
        batch_size=1,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=False)

    # Setup networks and initialize networks
    grnet = GRNet(cfg)

    if torch.cuda.is_available():
        grnet = torch.nn.DataParallel(grnet).cuda()

    # Load the pretrained model from a checkpoint
    logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS)
    grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # The inference loop
    n_samples = len(test_data_loader)
    for model_idx, (taxonomy_id, model_id,
                    data) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        model_id = model_id[0]

        with torch.no_grad():
            for k, v in data.items():
                data[k] = utils.helpers.var_or_cuda(v)

            sparse_ptcloud, dense_ptcloud = grnet(data)
            output_folder = os.path.join(cfg.DIR.OUT_PATH, 'benchmark',
                                         taxonomy_id)
            if not os.path.exists(output_folder):
                os.makedirs(output_folder)

            output_file_path = os.path.join(output_folder, '%s.h5' % model_id)
            utils.io.IO.put(output_file_path,
                            dense_ptcloud.squeeze().cpu().numpy())

            logging.info('Test[%d/%d] Taxonomy = %s Sample = %s File = %s' %
                         (model_idx + 1, n_samples, taxonomy_id, model_id,
                          output_file_path))
Example #2
def rescale_pc_parts(part_pc, num_points=2048):
    # NOTE: the opening of this helper is missing from the listing; the signature
    # and the duplication loop below are a plausible reconstruction (assuming an
    # (N, 3) input array), based on how rescale_pc_parts() is called elsewhere
    # in these examples.
    full_part_pc = part_pc
    now_points = part_pc.shape[0]
    while now_points < num_points:  # tile the cloud until enough points exist
        full_part_pc = np.concatenate((full_part_pc, full_part_pc), axis=0)
        now_points = now_points * 2
    if now_points > num_points:
        # randomly subsample down to exactly num_points points
        idx_selected = np.arange(now_points)
        np.random.shuffle(idx_selected)
        full_part_pc = full_part_pc[idx_selected[:num_points]]
    return full_part_pc


chamferLoss = ChamferDistance()
chamfer_dists = []
avg_chamfer_dist = []

n_points = 2048
n_shape = 1

grnet = GRNet(cfg)
grnet.eval()

# ShapeNet: pred & gt both have 16384 points, so no rescaling is needed
# (though the 16384-point resolution may inflate the CD values?)

for view in range(1):
    # print("------------------- view: %d ---------------------" % view)
    for root, dirs, files in os.walk(pred_path):
        len_files = len(files)
        all_gt = np.zeros((len_files, n_points, 3))
        all_pred = np.zeros((len_files, n_points, 3))
        idx = -1
        tot = 0

        pred_batch = np.zeros((1, n_points, 3))
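        # --- The listing is truncated here. Below is a hedged sketch of how this
        # --- evaluation loop plausibly continues, modelled on the commented-out
        # --- CD-evaluation block in Example #8. gt_path, h5py and the 'pts'/'data'
        # --- keys are assumptions, not part of the original excerpt.
        gt_path = './gt/'  # hypothetical location of the ground-truth .h5 files
        gt_batch = np.zeros((1, n_points, 3))
        for file in files:
            file_id = os.path.splitext(file)[0]
            idx += 1
            pred_batch[0] = np.load(os.path.join(pred_path, file))['pts']
            gt = h5py.File(os.path.join(gt_path, file_id + '.h5'), 'r')['data'][:]
            gt_batch[0] = np.array(gt).astype(np.float32)
            all_pred[idx] = pred_batch[0]  # fill the per-shape arrays allocated above
            all_gt[idx] = gt_batch[0]
            with torch.no_grad():
                cd = chamferLoss(
                    torch.tensor(pred_batch, dtype=torch.float32).cuda(),
                    torch.tensor(gt_batch, dtype=torch.float32).cuda())
            tot += cd.item()
            chamfer_dists.append(cd.item())
        avg_chamfer_dist.append(tot / len_files)
        print('avg CD: ', tot / len_files)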
Example #3
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    # choose ShapeNet
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    # get_dataset() takes the dataset subset (TRAIN=0, TEST=1, VAL=2)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.TRAIN),  # train/test/val
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.VAL),
        batch_size=1,
        num_workers=cfg.CONST.NUM_WORKERS,
        collate_fn=utils.data_loaders.collate_fn,
        pin_memory=True,
        shuffle=False)

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())  # output_dir
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    txt_dir = output_dir % 'txt'
    if not os.path.exists(txt_dir):
        os.makedirs(txt_dir)
    f_record = open(txt_dir + '/record.txt', 'w')
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # Create the networks
    grnet = GRNet(cfg)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' %
                  utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    if torch.cuda.is_available():
        grnet = torch.nn.DataParallel(grnet).cuda()

    # Create the optimizers
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(  # lgtm [py/unused-local-variable]
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)

    # Load pretrained model if exists
    init_epoch = 0  # resume from a checkpoint if cfg.CONST.WEIGHTS is set
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_index']
        best_metrics = Metrics(cfg.TEST.METRIC_NAME,
                               checkpoint['best_metrics'])
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    # Training/Testing the network
    first_epoch = True
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['SparseLoss', 'DenseLoss'])
        # losses = AverageMeter(['GridLoss', 'DenseLoss'])

        grnet.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_ids, model_ids,
                        data) in enumerate(train_data_loader):
            # print('batch_size: ', data['partial_cloud'].shape)
            data_time.update(time() - batch_end_time)
            for k, v in data.items():
                data[k] = utils.helpers.var_or_cuda(v)
            sparse_ptcloud, dense_ptcloud = grnet(data)
            sparse_loss = chamfer_dist(sparse_ptcloud, data['gtcloud'])
            # grid_loss = gridding_loss(dense_ptcloud, data['gtcloud'])
            dense_loss = chamfer_dist(dense_ptcloud, data['gtcloud'])
            _loss = sparse_loss + dense_loss
            losses.update(
                [sparse_loss.item() * 1000,
                 dense_loss.item() * 1000])
            # _loss = grid_loss + dense_loss
            # losses.update([grid_loss.item() * 1000, dense_loss.item() * 1000])
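            # _loss is the plain sum of the sparse and dense Chamfer distances
            # (the commented-out lines show the gridding-loss variant); the
            # *1000 factor on the logged values is only for readability.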

            grnet.zero_grad()
            _loss.backward()
            grnet_optimizer.step()

            n_itr = (epoch_idx - 1) * n_batches + batch_idx
            train_writer.add_scalar('Loss/Batch/Sparse',
                                    sparse_loss.item() * 1000, n_itr)
            # train_writer.add_scalar('Loss/Batch/Grid', grid_loss.item() * 1000, n_itr)
            train_writer.add_scalar('Loss/Batch/Dense',
                                    dense_loss.item() * 1000, n_itr)

            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            ###

            f_record.write(
                '\n[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches,
                   batch_time.val(), data_time.val(),
                   ['%.4f' % l for l in losses.val()]))
            logging.info(
                '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches,
                   batch_time.val(), data_time.val(),
                   ['%.4f' % l for l in losses.val()]))

        grnet_lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0), epoch_idx)
        # train_writer.add_scalar('Loss/Epoch/Grid', losses.avg(0), epoch_idx)
        train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1), epoch_idx)
        f_record.write('\n[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                       (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time -
                        epoch_start_time, ['%.4f' % l for l in losses.avg()]))
        logging.info('[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                     (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time -
                      epoch_start_time, ['%.4f' % l for l in losses.avg()]))

        # Validate the current model
        # if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0:
        # metrics = test_net(cfg, epoch_idx, val_data_loader, val_writer, grnet)

        # Save ckeckpoints
        # if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or metrics.better_than(best_metrics):

        if first_epoch:
            metrics = test_net(cfg, epoch_idx, val_data_loader, val_writer,
                               grnet)
            best_metrics = metrics
            first_epoch = False

        if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0:
            metrics = test_net(cfg, epoch_idx, val_data_loader, val_writer,
                               grnet)
            file_name = 'best-ckpt.pth' if metrics.better_than(
                best_metrics) else 'epoch-%03d.pth' % (epoch_idx + 1)
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save({
                'epoch_index': epoch_idx,
                'best_metrics': metrics.state_dict(),
                'grnet': grnet.state_dict()
            }, output_path)  # yapf: disable

            logging.info('Saved checkpoint to %s ...' % output_path)
            if metrics.better_than(best_metrics):
                best_metrics = metrics

    train_writer.close()
    val_writer.close()
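# A minimal, hedged sketch of how train_net() above might be invoked. It assumes an
# attribute-style config (e.g. easydict.EasyDict, as GRNet-style configs use); the
# field names are taken from the accesses inside train_net()/test_net(), while the
# values are illustrative placeholders only.
from easydict import EasyDict as edict


def build_example_cfg():
    cfg = edict()
    cfg.CONST = edict(NUM_WORKERS=4)
    cfg.DIR = edict(OUT_PATH='./output')
    cfg.DATASET = edict(TRAIN_DATASET='ShapeNet', TEST_DATASET='ShapeNet')
    cfg.NETWORK = edict(GRIDDING_LOSS_SCALES=[128], GRIDDING_LOSS_ALPHAS=[0.1])
    cfg.TRAIN = edict(BATCH_SIZE=32, LEARNING_RATE=1e-4, WEIGHT_DECAY=0,
                      BETAS=(0.9, 0.999), LR_MILESTONES=[50], GAMMA=0.5,
                      N_EPOCHS=150, SAVE_FREQ=25)
    cfg.TEST = edict(METRIC_NAME='ChamferDistance')
    return cfg


# train_net(build_example_cfg())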
Example #4
def test_net(cfg,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             grnet=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    if test_data_loader is None:
        # Set up data loader
        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        # change the dataset used here by editing data_loader.py
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetSubset.TEST),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKERS,
            collate_fn=utils.data_loaders.collate_fn,
            pin_memory=True,
            shuffle=False)

    # Setup networks and initialize networks
    if grnet is None:
        grnet = GRNet(cfg)

        if torch.cuda.is_available():
            grnet = torch.nn.DataParallel(grnet).cuda()

        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)  # lgtm [py/unused-import]

    # Testing loop
    n_samples = len(test_data_loader)
    test_losses = AverageMeter(['SparseLoss', 'DenseLoss'])
    # test_losses = AverageMeter(['GridLoss', 'DenseLoss'])
    test_metrics = AverageMeter(Metrics.names())  # F-Score, CD
    category_metrics = dict()

    # Testing loop: feed each data batch from test_data_loader through GRNet
    # to obtain sparse_ptcloud and dense_ptcloud

    tot_recall, tot_precision, tot_emd = 0.0, 0.0, 0.0
    tot_shapes = 0

    score_dict = {}

    for model_idx, (taxonomy_id, model_id,
                    data) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        model_id = model_id[0]

        with torch.no_grad():
            for k, v in data.items():
                data[k] = utils.helpers.var_or_cuda(v)

            sparse_ptcloud, dense_ptcloud = grnet(data)
            # print('--------dense: ', type(dense_ptcloud), dense_ptcloud.shape)
            # print('--------gt: ', type(data['gtcloud']), data['gtcloud'.shape])
            sparse_loss = chamfer_dist(sparse_ptcloud, data['gtcloud'])
            # grid_loss = gridding_loss(dense_ptcloud, data['gtcloud'])
            dense_loss = chamfer_dist(dense_ptcloud, data['gtcloud'])

            # Fsore
            fscore_pred = o3d.geometry.PointCloud()
            # print(type(dense_ptcloud))
            # print(dense_ptcloud.shape)
            # print(data['gtcloud'].shape)
            # print(type(data['gtcloud']))
            fscore_pred.points = o3d.utility.Vector3dVector(
                np.array(dense_ptcloud.squeeze().cpu().detach().numpy()))
            fscore_gt = o3d.geometry.PointCloud()
            fscore_gt.points = o3d.utility.Vector3dVector(
                data['gtcloud'].squeeze().cpu().detach().numpy())

            dist1 = fscore_pred.compute_point_cloud_distance(fscore_gt)
            dist2 = fscore_gt.compute_point_cloud_distance(fscore_pred)

            th = 0.01
            recall = float(sum(d < th for d in dist2)) / float(len(dist2))
            precision = float(sum(d < th for d in dist1)) / float(len(dist1))
            tot_recall += recall
            tot_precision += precision

            # Compute EMD
            # dense_pts = np.array(dense_ptcloud.cpu())
            # num_points = dense_pts.shape[1]
            # EMD_loss = earth_mover_distance(dense_ptcloud, data['gtcloud'], transpose=False) / num_points
            # EMD_loss = EMD_loss.mean().item()
            # tot_emd += EMD_loss

            tot_shapes += 1

            # print('dense_pc: ', dense_ptcloud.shape,  type(dense_ptcloud))
            test_losses.update(
                [sparse_loss.item() * 1000,
                 dense_loss.item() * 1000])
            # test_losses.update([grid_loss.item() * 1000, dense_loss.item() * 1000])
            _metrics = Metrics.get(dense_ptcloud,
                                   data['gtcloud'])  # return: values
            test_metrics.update(_metrics)

            if taxonomy_id not in category_metrics:
                category_metrics[taxonomy_id] = AverageMeter(Metrics.names())
            category_metrics[taxonomy_id].update(_metrics)

            # No need to save outputs during training.
            # Save .npz files:
            '''
            save_path = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grnet_chair_ep300_npz_16384d/'
            save_path2 = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grent_chair_ep300_npz_2048d/'

            part_name = 'part_7'

            # only the final results (dense_ptcloud) are saved
            save_npz_path = save_path + part_name + '/'
            save_npz_path2 = save_path2 + part_name + '/'
            if not os.path.exists(save_npz_path):
                os.makedirs(save_npz_path)
            if not os.path.exists(save_npz_path2):
                os.makedirs(save_npz_path2)

            dense_pts = np.array(dense_ptcloud.cpu())
            dense_pts2 = rescale_pc_parts(dense_pts, 2048) # rescale
            dense_pts /= 0.45  # scale back up to our original size
            dense_pts2 /= 0.45
            np.savez(save_npz_path + '%s.npz' % model_id, pts = dense_pts)
            np.savez(save_npz_path2 + '%s.npz' % model_id, pts = dense_pts2)
            '''

            # Save .npz (GRNet's data), Completion3D, no per-part split

            # save_path = '/home2/wuruihai/GRNet_FILES/Results/ShapeNet_grnet_pretrained_model_VAL_npz/'
            # if not os.path.exists(save_path):
            #     os.makedirs(save_path)
            # dense_pts = np.array(dense_ptcloud.cpu())
            # np.savez(save_path + '%s.npz' % model_id, pts=dense_pts)

            # Save the scores to a txt file

            dense_pts = np.array(dense_ptcloud.cpu())
            CD_loss = dense_loss.item()

            num_points = dense_pts.shape[1]
            EMD_loss = earth_mover_distance(
                dense_ptcloud, data['gtcloud'], transpose=False) / num_points
            EMD_loss = EMD_loss.mean().item()

            fscore = 2 * recall * precision / (
                recall + precision) if recall + precision else 0
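            # fscore is the harmonic mean of precision (pred->gt coverage) and
            # recall (gt->pred coverage) at the 0.01 distance threshold above,
            # i.e. the commonly reported F-Score@1%.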

            score_dict[model_id] = (CD_loss, EMD_loss, precision, recall,
                                    fscore)
            # print(score_dict)

            # Save .png renderings
            '''
            save_path = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_GRNet_1003/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            plt.figure()


            pc_ptcloud = data['partial_cloud'].squeeze().cpu().numpy()
            pc_ptcloud_img = utils.helpers.get_ptcloud_img(pc_ptcloud)
            matplotlib.image.imsave(save_path + '%s_1_pc.png' % model_id,
                                    pc_ptcloud_img)
        
            
            # sparse_ptcloud = sparse_ptcloud.squeeze().cpu().numpy()
            # sparse_ptcloud_img = utils.helpers.get_ptcloud_img(sparse_ptcloud)
            # matplotlib.image.imsave(save_path+'%s_sps.png' % model_id,
            #                         sparse_ptcloud_img)
            

            dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
            dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
            matplotlib.image.imsave(save_path + '%s_2_dns.png' % model_id,
                                    dense_ptcloud_img)

            
            gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
            gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
            matplotlib.image.imsave(save_path+'%s_3_gt.png' % model_id,
                                    gt_ptcloud_img)
            '''
            '''
            if model_idx in range(510, 600):

                now_num=model_idx-499
                # if test_writer is not None and model_idx < 3:
                # sparse_ptcloud = sparse_ptcloud.squeeze().cpu().numpy()
                sparse_ptcloud = sparse_ptcloud.squeeze().numpy()
                sparse_ptcloud_img = utils.helpers.get_ptcloud_img(sparse_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_sps.png'%(model_idx,model_id), sparse_ptcloud_img)

                # dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
                dense_ptcloud = dense_ptcloud.squeeze().numpy()
                dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_dns.png' % (model_idx, model_id),
                                        dense_ptcloud_img)


                # gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
                gt_ptcloud = data['gtcloud'].squeeze().numpy()
                gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_gt.png'%(model_idx,model_id), gt_ptcloud_img)

                cv.imwrite("/home2/wuruihai/GRNet_FILES/out3.png", sparse_ptcloud_img)
                im = Image.fromarray(sparse_ptcloud_img).convert('RGB')
                im.save("/home2/wuruihai/GRNet_FILES/out.jpeg")
            
                test_writer.add_image('Model%02d/SparseReconstruction' % model_idx, sparse_ptcloud_img, epoch_idx)
                dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
                dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
                test_writer.add_image('Model%02d/DenseReconstruction' % model_idx, dense_ptcloud_img, epoch_idx)

                gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
                gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
                test_writer.add_image('Model%02d/GroundTruth' % model_idx, gt_ptcloud_img, epoch_idx)
            '''

            logging.info(
                'Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s'
                %
                (model_idx + 1, n_samples, taxonomy_id, model_id,
                 ['%.4f' % l
                  for l in test_losses.val()], ['%.4f' % m for m in _metrics]))
    plt.show()
    plt.savefig('/raid/wuruihai/GRNet_FILES/results.png')
    # Print testing results
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    for metric in test_metrics.items:
        print(metric, end='\t')
    print()

    # Save CD and EMD values to a txt file
    # print(score_dict)
    # fname = '/home2/wuruihai/GRNet_FILES/Results/ShapeNet_grnet_pretrained_model_VAL_scores.txt'
    # fw = open(fname, 'w')
    # # print(score_dict)
    # for idx in score_dict.keys():
    #     fw.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (idx, score_dict[idx][0], score_dict[idx][1], score_dict[idx][2], score_dict[idx][3], score_dict[idx][4]))  # model_id \t CD \t EMD

    for taxonomy_id in category_metrics:
        print(taxonomy_id, end='\t')
        print(category_metrics[taxonomy_id].count(0), end='\t')
        for value in category_metrics[taxonomy_id].avg():
            print('%.4f' % value, end='\t')
        print()

    print('Overall', end='\t\t\t')
    for value in test_metrics.avg():
        print('%.4f' % value, end='\t')
    print('\n')

    print('recall: ', tot_recall / tot_shapes)
    print('precision: ', tot_precision / tot_shapes)
    # print('EMD: ', tot_emd / tot_shapes)

    # Add testing results to TensorBoard
    if test_writer is not None:
        test_writer.add_scalar('Loss/Epoch/Sparse', test_losses.avg(0),
                               epoch_idx)
        # test_writer.add_scalar('Loss/Epoch/Grid', test_losses.avg(0), epoch_idx)
        test_writer.add_scalar('Loss/Epoch/Dense', test_losses.avg(1),
                               epoch_idx)
        for i, metric in enumerate(test_metrics.items):
            test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i),
                                   epoch_idx)

    return Metrics(cfg.TEST.METRIC_NAME, test_metrics.avg())
Example #5
def test_net(cfg,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             grnet=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    if test_data_loader is None:
        # Set up data loader
        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetSubset.TEST),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKERS,
            collate_fn=utils.data_loaders.collate_fn,
            pin_memory=True,
            shuffle=False)

    # Setup networks and initialize networks
    if grnet is None:
        grnet = GRNet(cfg)

        if torch.cuda.is_available():
            grnet = torch.nn.DataParallel(grnet).cuda()

        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)  # lgtm [py/unused-import]

    # Testing loop
    n_samples = len(test_data_loader)
    test_losses = AverageMeter(['SparseLoss', 'DenseLoss'])
    test_metrics = AverageMeter(Metrics.names())
    category_metrics = dict()

    # Testing loop
    for model_idx, (taxonomy_id, model_id,
                    data) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        model_id = model_id[0]

        with torch.no_grad():
            for k, v in data.items():
                data[k] = utils.helpers.var_or_cuda(v)

            sparse_ptcloud, dense_ptcloud = grnet(data)
            sparse_loss = chamfer_dist(sparse_ptcloud, data['gtcloud'])
            dense_loss = chamfer_dist(dense_ptcloud, data['gtcloud'])
            test_losses.update(
                [sparse_loss.item() * 1000,
                 dense_loss.item() * 1000])
            _metrics = Metrics.get(dense_ptcloud, data['gtcloud'])
            test_metrics.update(_metrics)

            # save predicted point cloud
            if cfg.TEST.SAVE_PRED:
                if cfg.DATASET.TEST_DATASET == 'FrankaScan':
                    dirname, obj_idx = model_id.split('-')
                    out_ptcloud = dense_ptcloud[0].cpu()
                    IO.put(
                        cfg.DATASETS.FRANKASCAN.PREDICTION_PATH %
                        (dirname, obj_idx), out_ptcloud)

            if taxonomy_id not in category_metrics:
                category_metrics[taxonomy_id] = AverageMeter(Metrics.names())
            category_metrics[taxonomy_id].update(_metrics)

            if test_writer is not None and model_idx < 3:
                sparse_ptcloud = sparse_ptcloud.squeeze().cpu().numpy()
                sparse_ptcloud_img = utils.helpers.get_ptcloud_img(
                    sparse_ptcloud)
                test_writer.add_image(
                    'Model%02d/SparseReconstruction' % model_idx,
                    sparse_ptcloud_img, epoch_idx)
                dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
                dense_ptcloud_img = utils.helpers.get_ptcloud_img(
                    dense_ptcloud)
                test_writer.add_image(
                    'Model%02d/DenseReconstruction' % model_idx,
                    dense_ptcloud_img, epoch_idx)
                gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
                gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
                test_writer.add_image('Model%02d/GroundTruth' % model_idx,
                                      gt_ptcloud_img, epoch_idx)

            logging.info(
                'Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s'
                %
                (model_idx + 1, n_samples, taxonomy_id, model_id,
                 ['%.4f' % l
                  for l in test_losses.val()], ['%.4f' % m for m in _metrics]))

    # Print testing results
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    for metric in test_metrics.items:
        print(metric, end='\t')
    print()

    for taxonomy_id in category_metrics:
        print(taxonomy_id, end='\t')
        print(category_metrics[taxonomy_id].count(0), end='\t')
        for value in category_metrics[taxonomy_id].avg():
            print('%.4f' % value, end='\t')
        print()

    print('Overall', end='\t\t\t')
    for value in test_metrics.avg():
        print('%.4f' % value, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    if test_writer is not None:
        test_writer.add_scalar('Loss/Epoch/Sparse', test_losses.avg(0),
                               epoch_idx)
        test_writer.add_scalar('Loss/Epoch/Dense', test_losses.avg(1),
                               epoch_idx)
        for i, metric in enumerate(test_metrics.items):
            test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i),
                                   epoch_idx)

    return Metrics(cfg.TEST.METRIC_NAME, test_metrics.avg())
Example #6
def train_net_new(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Set up data loader
    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    class_name = "Pistol"

    train_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='train')
    train_data_loader = torch.utils.data.DataLoader(train_dataset_loader,
                                                    batch_size=batchSize,
                                                    shuffle=True,
                                                    num_workers=int(workers))

    test_dataset_loader = shapenet_part_loader.PartDataset(
        root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
        classification=False,
        class_choice=class_name,
        npoints=pnum,
        split='test')
    val_data_loader = torch.utils.data.DataLoader(test_dataset_loader,
                                                  batch_size=batchSize,
                                                  shuffle=True,
                                                  num_workers=int(workers))

    # Set up folders for logs and checkpoints
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s',
                              datetime.now().isoformat())
    cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints'
    cfg.DIR.LOGS = output_dir % 'logs'
    if not os.path.exists(cfg.DIR.CHECKPOINTS):
        os.makedirs(cfg.DIR.CHECKPOINTS)

    # Create tensorboard writers
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # Create the networks
    grnet = GRNet(cfg, seg_class_no)
    grnet.apply(utils.helpers.init_weights)
    logging.debug('Parameters in GRNet: %d.' %
                  utils.helpers.count_parameters(grnet))

    # Move the network to GPU if possible
    grnet = grnet.to(device)

    # Create the optimizers
    grnet_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              grnet.parameters()),
                                       lr=cfg.TRAIN.LEARNING_RATE,
                                       weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                       betas=cfg.TRAIN.BETAS)
    grnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        grnet_optimizer,
        milestones=cfg.TRAIN.LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(  # lgtm [py/unused-local-variable]
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)
    seg_criterion = torch.nn.CrossEntropyLoss().cuda()

    # Load pretrained model if exists
    init_epoch = 0
    best_metrics = None
    if 'WEIGHTS' in cfg.CONST:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_index']
        grnet.load_state_dict(checkpoint['grnet'])
        logging.info(
            'Recover complete. Current epoch = #%d; best metrics = %s.' %
            (init_epoch, best_metrics))

    train_seg_on_sparse = False
    train_seg_on_dense = False

    miou = 0

    # Training/Testing the network
    for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1):
        epoch_start_time = time()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['SparseLoss', 'DenseLoss'])

        grnet.train()

        if epoch_idx == 5:
            train_seg_on_sparse = True

        if epoch_idx == 7:
            train_seg_on_dense = True

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (
                data,
                seg,
                model_ids,
        ) in enumerate(train_data_loader):
            data_time.update(time() - batch_end_time)

            input_cropped1 = torch.FloatTensor(data.size()[0], pnum, 3)
            input_cropped1 = input_cropped1.data.copy_(data)

            if batch_idx == 10:
                pass  #break

            data = data.to(device)
            seg = seg.to(device)

            input_cropped1 = input_cropped1.to(device)

            # remove points to make input incomplete
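            # Each shape gets a random synthetic viewpoint from `choice`; the
            # crop_point_num points closest to that viewpoint are overwritten
            # with the origin, turning the complete cloud into a partial input.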
            choice = [
                torch.Tensor([1, 0, 0]),
                torch.Tensor([0, 0, 1]),
                torch.Tensor([1, 0, 1]),
                torch.Tensor([-1, 0, 0]),
                torch.Tensor([-1, 1, 0])
            ]
            for m in range(data.size()[0]):
                index = random.sample(choice, 1)
                p_center = index[0].to(device)
                distances = torch.sum((data[m] - p_center)**2, dim=1)
                order = torch.argsort(distances)

                zero_point = torch.FloatTensor([0, 0, 0]).to(device)
                input_cropped1.data[m, order[:crop_point_num]] = zero_point

            if save_crop_mode:
                np.save(class_name + "_orig", data[0].detach().cpu().numpy())
                np.save(class_name + "_cropped",
                        input_cropped1[0].detach().cpu().numpy())
                sys.exit()

            sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg, dense_seg = grnet(
                input_cropped1)

            data_seg = get_data_seg(data, full_seg)
            seg_loss = seg_criterion(torch.transpose(data_seg, 1, 2), seg)
            if train_seg_on_sparse and train_seg:
                gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                seg_loss += seg_criterion(torch.transpose(sparse_seg, 1, 2),
                                          gt_seg)
                seg_loss /= 2

            if train_seg_on_dense and train_seg:
                gt_seg = get_seg_gts(seg, data, dense_ptcloud)
                dense_seg_loss = seg_criterion(
                    torch.transpose(dense_seg, 1, 2), gt_seg)
                print(dense_seg_loss.item())

            if draw_mode:
                plot_ptcloud(data[0], seg[0], "orig")
                plot_ptcloud(input_cropped1[0], seg[0], "cropped")
                plot_ptcloud(sparse_ptcloud[0],
                             torch.argmax(sparse_seg[0], dim=1), "sparse_pred")
                if not train_seg_on_sparse:
                    gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                #plot_ptcloud(sparse_ptcloud[0], gt_seg[0], "sparse_gt")
                #if not train_seg_on_dense:
                #gt_seg = get_seg_gts(seg, data, sparse_ptcloud)
                print(dense_seg.size())
                plot_ptcloud(dense_ptcloud[0], torch.argmax(dense_seg[0],
                                                            dim=1),
                             "dense_pred")
                sys.exit()

            print(seg_loss.item())

            lamb = 0.8
            sparse_loss = chamfer_dist(sparse_ptcloud, data).to(device)
            dense_loss = chamfer_dist(dense_ptcloud, data).to(device)
            grid_loss = gridding_loss(sparse_ptcloud, data).to(device)
            if train_seg:
                _loss = lamb * (sparse_loss + dense_loss +
                                grid_loss) + (1 - lamb) * seg_loss
            else:
                _loss = (sparse_loss + dense_loss + grid_loss)
            if train_seg_on_dense and train_seg:
                _loss += (1 - lamb) * dense_seg_loss
            _loss.to(device)
            losses.update(
                [sparse_loss.item() * 1000,
                 dense_loss.item() * 1000])

            grnet.zero_grad()
            _loss.backward()
            grnet_optimizer.step()

            n_itr = (epoch_idx - 1) * n_batches + batch_idx
            train_writer.add_scalar('Loss/Batch/Sparse',
                                    sparse_loss.item() * 1000, n_itr)
            train_writer.add_scalar('Loss/Batch/Dense',
                                    dense_loss.item() * 1000, n_itr)

            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            logging.info(
                '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s'
                % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches,
                   batch_time.val(), data_time.val(),
                   ['%.4f' % l for l in losses.val()]))

        # Validate the current model
        if train_seg:
            miou_new = test_net_new(cfg, epoch_idx, val_data_loader,
                                    val_writer, grnet)
        else:
            miou_new = 0

        grnet_lr_scheduler.step()
        epoch_end_time = time()
        train_writer.add_scalar('Loss/Epoch/Sparse', losses.avg(0), epoch_idx)
        train_writer.add_scalar('Loss/Epoch/Dense', losses.avg(1), epoch_idx)
        logging.info('[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' %
                     (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time -
                      epoch_start_time, ['%.4f' % l for l in losses.avg()]))

        if not train_seg or miou_new > miou:
            file_name = class_name + 'noseg-ckpt-epoch.pth'
            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            torch.save({
                'epoch_index': epoch_idx,
                'grnet': grnet.state_dict()
            }, output_path)  # yapf: disable

            logging.info('Saved checkpoint to %s ...' % output_path)
            miou = miou_new

    train_writer.close()
    val_writer.close()
def test_net_new(cfg,
                 epoch_idx=-1,
                 test_data_loader=None,
                 test_writer=None,
                 grnet=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pnum = 2048
    crop_point_num = 512
    workers = 1
    batchSize = 16

    if test_data_loader is None:
        test_dataset_loader = shapenet_part_loader.PartDataset(
            root='./dataset/shapenetcore_partanno_segmentation_benchmark_v0/',
            classification=False,
            class_choice=save_name,
            npoints=pnum,
            split='test')
        test_data_loader = torch.utils.data.DataLoader(
            test_dataset_loader,
            batch_size=batchSize,
            shuffle=True,
            num_workers=int(workers))

    # Setup networks and initialize networks
    if grnet is None:
        grnet = GRNet(cfg, 4)

        if torch.cuda.is_available():
            grnet = grnet.to(device)

        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(
        scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
        alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)  # lgtm [py/unused-import]
    seg_criterion = torch.nn.CrossEntropyLoss().cuda()

    total_sparse_cd = 0
    total_dense_cd = 0

    total_sparse_ce = 0
    total_dense_ce = 0

    total_sparse_miou = 0
    total_dense_miou = 0

    total_sparse_acc = 0
    total_dense_acc = 0

    # Testing loop
    for batch_idx, (
            data,
            seg,
            model_ids,
    ) in enumerate(test_data_loader):
        model_id = model_ids[0]

        with torch.no_grad():
            input_cropped1 = torch.FloatTensor(data.size()[0], pnum, 3)
            input_cropped1 = input_cropped1.data.copy_(data)

            if batch_idx == 200:
                pass  # break

            data = data.to(device)
            seg = seg.to(device)

            input_cropped1 = input_cropped1.to(device)

            # remove points to make input incomplete
            choice = [
                torch.Tensor([1, 0, 0]),
                torch.Tensor([0, 0, 1]),
                torch.Tensor([1, 0, 1]),
                torch.Tensor([-1, 0, 0]),
                torch.Tensor([-1, 1, 0])
            ]
            for m in range(data.size()[0]):
                index = random.sample(choice, 1)
                p_center = index[0].to(device)
                distances = torch.sum((data[m] - p_center)**2, dim=1)
                order = torch.argsort(distances)

                zero_point = torch.FloatTensor([0, 0, 0]).to(device)
                input_cropped1.data[m, order[:crop_point_num]] = zero_point

            sparse_ptcloud, dense_ptcloud, sparse_seg, full_seg, dense_seg = grnet(
                input_cropped1)

            if save_mode:
                np.save("./saved_results/original_" + save_name,
                        data.detach().cpu().numpy())
                np.save("./saved_results/original_seg_" + save_name,
                        seg.detach().cpu().numpy())
                np.save("./saved_results/cropped_" + save_name,
                        input_cropped1.detach().cpu().numpy())
                np.save("./saved_results/sparse_" + save_name,
                        sparse_ptcloud.detach().cpu().numpy())
                np.save("./saved_results/sparse_seg_" + save_name,
                        sparse_seg.detach().cpu().numpy())
                np.save("./saved_results/dense_" + save_name,
                        dense_ptcloud.detach().cpu().numpy())
                np.save("./saved_results/dense_seg_" + save_name,
                        dense_seg.detach().cpu().numpy())
                sys.exit()

            total_sparse_cd += chamfer_dist(sparse_ptcloud, data).to(device)
            total_dense_cd += chamfer_dist(dense_ptcloud, data).to(device)

            sparse_seg_gt = get_seg_gts(seg, data, sparse_ptcloud)
            sparse_miou, sparse_acc = miou(torch.argmax(sparse_seg, dim=2),
                                           sparse_seg_gt)
            total_sparse_miou += sparse_miou
            total_sparse_acc += sparse_acc

            print(batch_idx)

            total_sparse_ce += seg_criterion(torch.transpose(sparse_seg, 1, 2),
                                             sparse_seg_gt)

            dense_seg_gt = get_seg_gts(seg, data, dense_ptcloud)
            dense_miou, dense_acc = miou(torch.argmax(dense_seg, dim=2),
                                         dense_seg_gt)
            total_dense_miou += dense_miou
            print(dense_miou)
            total_dense_acc += dense_acc
            total_dense_ce += seg_criterion(torch.transpose(dense_seg, 1, 2),
                                            dense_seg_gt)

    length = len(test_data_loader)
    print("sparse cd: " + str(total_sparse_cd * 1000 / length))
    print("dense cd: " + str(total_dense_cd * 1000 / length))
    print("sparse acc: " + str(total_sparse_acc / length))
    print("dense acc: " + str(total_dense_acc / length))
    print("sparse miou: " + str(total_sparse_miou / length))
    print("dense miou: " + str(total_dense_miou / length))
    print("sparse ce: " + str(total_sparse_ce / length))
    print("dense ce: " + str(total_dense_ce / length))

    return total_dense_miou / length
Example #8
def test_net(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, grnet=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    if test_data_loader is None:
        # Set up data loader
        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        # change the dataset used here by editing data_loader.py
        test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset(
            utils.data_loaders.DatasetSubset.VAL),
                                                       batch_size=1,
                                                       num_workers=cfg.CONST.NUM_WORKERS,
                                                       collate_fn=utils.data_loaders.collate_fn,
                                                       pin_memory=True,
                                                       shuffle=False)

    # Setup networks and initialize networks
    if grnet is None:
        grnet = GRNet(cfg)

        if torch.cuda.is_available():
            grnet = torch.nn.DataParallel(grnet).cuda()

        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        grnet.load_state_dict(checkpoint['grnet'])

    # Switch models to evaluation mode
    grnet.eval()

    # Set up loss functions
    chamfer_dist = ChamferDistance()
    gridding_loss = GriddingLoss(scales=cfg.NETWORK.GRIDDING_LOSS_SCALES,
                                 alphas=cfg.NETWORK.GRIDDING_LOSS_ALPHAS)    # lgtm [py/unused-import]

    # Testing loop
    n_samples = len(test_data_loader)
    test_losses = AverageMeter(['SparseLoss', 'DenseLoss'])
    # test_losses = AverageMeter(['GridLoss', 'DenseLoss'])
    test_metrics = AverageMeter(Metrics.names())  # F-Score, CD
    category_metrics = dict()


    # Testing loop: feed each data batch from test_data_loader through GRNet
    # to obtain sparse_ptcloud and dense_ptcloud

    '''
    gt_path = '/raid/wuruihai/GRNet_FILES/xkh/Completion3D/val/gt/03001627/'
    # pred_path = '/raid/wuruihai/GRNet_FILES/Results/Completion3D_grnet_data_ep300_npz_2048d/'  # 0.0033
    pred_path = '/raid/wuruihai/GRNet_FILES/Results/Completion3D_grnet_alldata_ep300_npz_small_2048d/'   # 0.0030
    n_points = 2048
    for root, dirs, files in os.walk(pred_path):
        len_files = len(files)
        pred_batch = np.zeros((1, n_points, 3))
        gt_batch = np.zeros((1, n_points, 3))
        idx = -1
        tot = 0

        for file in files:
            file_id = os.path.splitext(file)[0]
            idx += 1

            pred = np.load(pred_path + file)['pts']
            # pred = rescale_pc_parts(pred, num_points=n_points)
            # pred = pred.reshape(n_points, 3)
            pred_batch[0] = pred

            gt = h5py.File(gt_path + file_id + '.h5', 'r')['data'][:]  # Completion3D
            gt = np.array(gt).astype(np.float32)
            # gt = rescale_pc_parts(gt, num_points=n_points)
            # gt = gt.reshape(n_points, 3)
            gt_batch[0] = gt


            with torch.no_grad():
                cd = chamfer_dist(torch.tensor(pred_batch, dtype=torch.float32).cuda(), torch.tensor(gt_batch, dtype=torch.float32).cuda())
            print(cd)
            tot += cd
        print('avg: ', tot/len_files)

        return
    '''



    for model_idx, (taxonomy_id, model_id, data) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        model_id = model_id[0]

        with torch.no_grad():
            for k, v in data.items():
                data[k] = utils.helpers.var_or_cuda(v)



            sparse_ptcloud, dense_ptcloud = grnet(data)
            sparse_loss = chamfer_dist(sparse_ptcloud, data['gtcloud'])
            # grid_loss = gridding_loss(dense_ptcloud, data['gtcloud'])
            dense_loss = chamfer_dist(dense_ptcloud, data['gtcloud'])
            print(dense_ptcloud.shape, data['gtcloud'].shape)
            test_losses.update([sparse_loss.item() * 1000, dense_loss.item() * 1000])
            # test_losses.update([grid_loss.item() * 1000, dense_loss.item() * 1000])
            _metrics = Metrics.get(dense_ptcloud, data['gtcloud']) # return: values
            test_metrics.update(_metrics)

            if taxonomy_id not in category_metrics:
                category_metrics[taxonomy_id] = AverageMeter(Metrics.names())
            category_metrics[taxonomy_id].update(_metrics)


            # No need to save outputs during training.
            # Save .npz files:

            '''
            save_path = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grnet_chair_ep300_npz_16384d/'
            save_path2 = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grent_chair_ep300_npz_2048d/'

            part_name = 'part_7'

            # only the final results (dense_ptcloud) are saved
            save_npz_path = save_path + part_name + '/'
            save_npz_path2 = save_path2 + part_name + '/'
            if not os.path.exists(save_npz_path):
                os.makedirs(save_npz_path)
            if not os.path.exists(save_npz_path2):
                os.makedirs(save_npz_path2)

            dense_pts = np.array(dense_ptcloud.cpu())
            dense_pts2 = rescale_pc_parts(dense_pts, 2048) # rescale
            dense_pts /= 0.45  # scale back up to our original size
            dense_pts2 /= 0.45
            np.savez(save_npz_path + '%s.npz' % model_id, pts = dense_pts)
            np.savez(save_npz_path2 + '%s.npz' % model_id, pts = dense_pts2)
            '''



            # Save .npz (GRNet's data), Completion3D, no per-part split
            # compared against GRNet's own dataset, so no upscaling (/0.45) is needed

            '''
            save_path = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grnet_alldata_ep300_npz_small_16384d/'
            save_path2 = '/home2/wuruihai/GRNet_FILES/Results/Completion3D_grnet_alldata_ep300_npz_small_2048d/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            if not os.path.exists(save_path2):
                os.makedirs(save_path2)
            dense_pts = np.array(dense_ptcloud.cpu())
            dense_pts2 = rescale_pc_parts(dense_pts, 2048)  # rescale
            np.savez(save_path + '%s.npz' % model_id, pts=dense_pts)
            np.savez(save_path2 + '%s.npz' % model_id, pts = dense_pts2)
            '''




            # Save .png renderings
            '''
            save_path = '/home2/wuruihai/GRNet_FILES/Results/ShapeNet_zy_chair_ep500_part0_16384d_png/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            plt.figure()


            pc_ptcloud = data['partial_cloud'].squeeze().cpu().numpy()
            pc_ptcloud_img = utils.helpers.get_ptcloud_img(pc_ptcloud)
            matplotlib.image.imsave(save_path + '%s_1_pc.png' % model_id,
                                    pc_ptcloud_img)
        
            
            # sparse_ptcloud = sparse_ptcloud.squeeze().cpu().numpy()
            # sparse_ptcloud_img = utils.helpers.get_ptcloud_img(sparse_ptcloud)
            # matplotlib.image.imsave(save_path+'%s_sps.png' % model_id,
            #                         sparse_ptcloud_img)
            

            dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
            dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
            matplotlib.image.imsave(save_path+'%s_2_dns.png' % model_id,
                                    dense_ptcloud_img)

            
            gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
            gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
            matplotlib.image.imsave(save_path+'%s_3_gt.png' % model_id,
                                    gt_ptcloud_img)
            '''




            '''
            if model_idx in range(510, 600):

                now_num=model_idx-499
                # if test_writer is not None and model_idx < 3:
                # sparse_ptcloud = sparse_ptcloud.squeeze().cpu().numpy()
                sparse_ptcloud = sparse_ptcloud.squeeze().numpy()
                sparse_ptcloud_img = utils.helpers.get_ptcloud_img(sparse_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_sps.png'%(model_idx,model_id), sparse_ptcloud_img)

                # dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
                dense_ptcloud = dense_ptcloud.squeeze().numpy()
                dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_dns.png' % (model_idx, model_id),
                                        dense_ptcloud_img)


                # gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
                gt_ptcloud = data['gtcloud'].squeeze().numpy()
                gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
                matplotlib.image.imsave('/home2/wuruihai/GRNet_FILES/results2/%s_%s_gt.png'%(model_idx,model_id), gt_ptcloud_img)

                cv.imwrite("/home2/wuruihai/GRNet_FILES/out3.png", sparse_ptcloud_img)
                im = Image.fromarray(sparse_ptcloud_img).convert('RGB')
                im.save("/home2/wuruihai/GRNet_FILES/out.jpeg")
            
                test_writer.add_image('Model%02d/SparseReconstruction' % model_idx, sparse_ptcloud_img, epoch_idx)
                dense_ptcloud = dense_ptcloud.squeeze().cpu().numpy()
                dense_ptcloud_img = utils.helpers.get_ptcloud_img(dense_ptcloud)
                test_writer.add_image('Model%02d/DenseReconstruction' % model_idx, dense_ptcloud_img, epoch_idx)

                gt_ptcloud = data['gtcloud'].squeeze().cpu().numpy()
                gt_ptcloud_img = utils.helpers.get_ptcloud_img(gt_ptcloud)
                test_writer.add_image('Model%02d/GroundTruth' % model_idx, gt_ptcloud_img, epoch_idx)
            '''

            logging.info('Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s' %
                         (model_idx + 1, n_samples, taxonomy_id, model_id, ['%.4f' % l for l in test_losses.val()
                                                                            ], ['%.4f' % m for m in _metrics]))
    plt.show()
    plt.savefig('/raid/wuruihai/GRNet_FILES/results.png')
    # Print testing results
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    for metric in test_metrics.items:
        print(metric, end='\t')
    print()

    for taxonomy_id in category_metrics:
        print(taxonomy_id, end='\t')
        print(category_metrics[taxonomy_id].count(0), end='\t')
        for value in category_metrics[taxonomy_id].avg():
            print('%.4f' % value, end='\t')
        print()

    print('Overall', end='\t\t\t')
    for value in test_metrics.avg():
        print('%.4f' % value, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    if test_writer is not None:
        test_writer.add_scalar('Loss/Epoch/Sparse', test_losses.avg(0), epoch_idx)
        # test_writer.add_scalar('Loss/Epoch/Grid', test_losses.avg(0), epoch_idx)
        test_writer.add_scalar('Loss/Epoch/Dense', test_losses.avg(1), epoch_idx)
        for i, metric in enumerate(test_metrics.items):
            test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i), epoch_idx)

    return Metrics(cfg.TEST.METRIC_NAME, test_metrics.avg())