Example #1
class Visualization_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(
                checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(
                checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.output_dir = output_dir

    def run_on_images(self, imgs, sid, mid, iid, sampled_idx):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        deprocess = imagenet_deprocess(rescale_image=False)
        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)

        mesh = cubify(generated_volume, 0.3)
        #         mesh = voxel_to_world(meshes)
        save_mesh = os.path.join(dir1, "%s_%s.obj" % (iid, sampled_idx))
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save_mesh, verts, faces)

        generated_volume = generated_volume.squeeze()
        img = image_to_numpy(deprocess(imgs[0][0]))
        save_img = os.path.join(dir1, "%02d.png" % (iid))
        #         cv2.imwrite(save_img, img[:, :, ::-1])
        cv2.imwrite(save_img, img)
        img1 = image_to_numpy(deprocess(imgs[0][1]))
        save_img1 = os.path.join(dir1, "%02d.png" % (sampled_idx))
        cv2.imwrite(save_img1, img1)
        #         cv2.imwrite(save_img1, img1[:, :, ::-1])
        get_volume_views(generated_volume, dir1, iid, sampled_idx)
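
A minimal usage sketch for the demo class above, assuming a Pix2Vox-style cfg object loaded from the repository's config module; the tensor shape, IDs, and output path below are placeholders, not values from the original snippet.

import torch

# Hypothetical driver: cfg is assumed to be built elsewhere (e.g. from config.py).
demo = Visualization_demo(cfg, output_dir='./demo_output')
imgs = torch.rand(1, 2, 3, 224, 224)  # [batch, n_views, C, H, W] (assumed shape)
with torch.no_grad():
    demo.run_on_images(imgs, sid='02691156', mid='1a04e3e', iid=0, sampled_idx=1)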
Example #2
class Quantitative_analysis_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)
        #         self.thresh = cfg.VOXEL_THRESH
        self.th = cfg.TEST.VOXEL_THRESH

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(
                checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(
                checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        self.output_dir = output_dir

    def calculate_iou(self, imgs, GT_voxels, sid, mid, iid):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)
        generated_volume = generated_volume.squeeze()

        sample_iou = []
        for th in self.th:
            _volume = torch.ge(generated_volume, th).float()
            intersection = torch.sum(_volume.mul(GT_voxels)).float()
            union = torch.sum(torch.ge(_volume.add(GT_voxels), 1)).float()
            sample_iou.append((intersection / union).item())
        return sample_iou
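
The IoU loop above is plain voxel counting once the prediction is thresholded: binarize at th, count voxels that are on in both volumes, and divide by voxels that are on in either. A self-contained numeric check on a 2x2x2 grid (toy values, not from the dataset):

import torch

pred = torch.tensor([[[0.9, 0.2], [0.6, 0.1]],
                     [[0.8, 0.4], [0.3, 0.3]]])  # predicted occupancy probabilities
gt = torch.tensor([[[1., 0.], [1., 0.]],
                   [[1., 0.], [0., 1.]]])        # binary ground-truth voxels

th = 0.5
_volume = torch.ge(pred, th).float()             # binarize: 3 voxels switched on
intersection = torch.sum(_volume.mul(gt))        # on in both volumes -> 3.0
union = torch.sum(torch.ge(_volume.add(gt), 1))  # on in either volume -> 4
print((intersection / union).item())             # 3 / 4 = 0.75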
Example #3
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)
        #         self.thresh = cfg.VOXEL_THRESH
        self.th = cfg.TEST.VOXEL_THRESH
        
        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)
        
        self.output_dir = output_dir
Example #4
    def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
        super().__init__()
        self.cfg_network = cfg_network
        self.cfg_tester = cfg_tester

        # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
        torch.backends.cudnn.benchmark = True

        # Set up networks
        self.encoder = Encoder(cfg_network)
        self.decoder = Decoder(cfg_network)
        self.refiner = Refiner(cfg_network)
        self.merger = Merger(cfg_network)
        
        # Initialize weights of networks
        self.encoder.apply(utils.network_utils.init_weights)
        self.decoder.apply(utils.network_utils.init_weights)
        self.refiner.apply(utils.network_utils.init_weights)
        self.merger.apply(utils.network_utils.init_weights)
        
        self.bce_loss = nn.BCELoss()
Example #5
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(
                checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(
                checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.output_dir = output_dir
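
Every constructor in these snippets pipes the checkpoint through clean_state_dict before load_state_dict. The helper itself is never shown; a plausible sketch, assuming its job is to strip the 'module.' prefix that torch.nn.DataParallel prepends to parameter names (the same prefix Examples #7 and #10 strip by hand):

from collections import OrderedDict

def clean_state_dict(state_dict):
    """Strip a leading 'module.' from every key (a DataParallel artifact)."""
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[len('module.'):] if k.startswith('module.') else k] = v
    return cleaned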
Example #6
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solver
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                 encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' %
                        (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']

        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' \
                 % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss',
                                    encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(),
                                    n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' % \
                (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, \
                    batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                                epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
            (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, \
                encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' % \
                (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'best-ckpt.pth'), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
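
One detail in the loop above deserves a note: when the refiner is active, encoder_loss.backward(retain_graph=True) runs before refiner_loss.backward(), because both losses share the encoder/decoder subgraph and the first backward would otherwise free it. A stand-alone illustration of that pattern with toy tensors (not the Pix2Vox networks):

import torch

x = torch.randn(4, requires_grad=True)
shared = x * 2                 # shared subgraph (stands in for encoder/decoder)
loss_a = shared.sum()          # stands in for encoder_loss
loss_b = (shared ** 2).sum()   # stands in for refiner_loss

loss_a.backward(retain_graph=True)  # keep the graph alive for the second backward
loss_b.backward()                   # gradients from both losses accumulate in x.grad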
Example #7
def test_single_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = 'D:/Pix2Vox/Pix2Vox/pretrained/Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])

    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])


    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = 'D:/Pix2Vox/Pix2Vox/rand/minecraft.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = 'D:/Pix2Vox/Pix2Vox/output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    print(gv_new)
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, img_dir, epoch_idx)
Example #8
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None, \
        test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    for sample_idx, (taxonomy_id, sample_name, rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            print("vox shape {}".format(generated_volume.shape))

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx, rendering_views, epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gtv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx, rendering_views, epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' % \
                (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(), \
                    refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/A', end='\t\t')

        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
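
The mean_iou aggregation above is a sample-weighted average across taxonomies: each per-taxonomy mean IoU vector (one entry per threshold) is weighted by its sample count, and the sum is divided by the total number of samples. A small numeric check with made-up taxonomies and two thresholds:

import numpy as np

test_iou = {
    'chair': {'n_samples': 3, 'iou': [[0.5, 0.6], [0.7, 0.8], [0.6, 0.7]]},
    'plane': {'n_samples': 1, 'iou': [[0.9, 0.9]]},
}
n_samples = 4

mean_iou = []
for tid in test_iou:
    test_iou[tid]['iou'] = np.mean(test_iou[tid]['iou'], axis=0)
    mean_iou.append(test_iou[tid]['iou'] * test_iou[tid]['n_samples'])
mean_iou = np.sum(mean_iou, axis=0) / n_samples
print(mean_iou)  # [0.675 0.75 ] : (3 * [0.6, 0.7] + 1 * [0.9, 0.9]) / 4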
Example #9
def train_net(cfg):
    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = paddle.io.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        # num_workers=0,  # errors when cfg.TRAIN.NUM_WORKER > 0 because /dev/shm is too small  https://blog.csdn.net/ctypyb2002/article/details/107914643
        #pin_memory=True,
        use_shared_memory=False,
        shuffle=True,
        drop_last=True)
    val_data_loader = paddle.io.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        #num_workers=1,
        #pin_memory=True,
        shuffle=False)

    # Set up networks # paddle.Model prepare fit save
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)
    refiner = Refiner(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))

    # # Initialize weights of networks # Paddle initializes parameters differently; see its API docs
    # encoder.apply(utils.network_utils.init_weights)
    # decoder.apply(utils.network_utils.init_weights)
    # merger.apply(utils.network_utils.init_weights)

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.ENCODER_LEARNING_RATE,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    decoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.DECODER_LEARNING_RATE,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    merger_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.MERGER_LEARNING_RATE,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    refiner_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.REFINER_LEARNING_RATE,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    # Set up solver
    # if cfg.TRAIN.POLICY == 'adam':
    encoder_solver = paddle.optimizer.Adam(learning_rate=encoder_lr_scheduler,
                                           parameters=encoder.parameters())
    decoder_solver = paddle.optimizer.Adam(learning_rate=decoder_lr_scheduler,
                                           parameters=decoder.parameters())
    merger_solver = paddle.optimizer.Adam(learning_rate=merger_lr_scheduler,
                                          parameters=merger.parameters())
    refiner_solver = paddle.optimizer.Adam(learning_rate=refiner_lr_scheduler,
                                           parameters=refiner.parameters())

    # if torch.cuda.is_available():
    #     encoder = torch.nn.DataParallel(encoder).cuda()
    #     decoder = torch.nn.DataParallel(decoder).cuda()
    #     merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        # load
        encoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder_solver.pdopt"))
        encoder.set_state_dict(encoder_state_dict)
        encoder_solver.set_state_dict(encoder_solver_state_dict)
        decoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder_solver.pdopt"))
        decoder.set_state_dict(decoder_state_dict)
        decoder_solver.set_state_dict(decoder_solver_state_dict)

        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger_solver.pdopt"))
            merger.set_state_dict(merger_state_dict)
            merger_solver.set_state_dict(merger_solver_state_dict)

        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner.pdparams"))
            refiner_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner_solver.pdopt"))
            refiner.set_state_dict(refiner_state_dict)
            refiner_solver.set_state_dict(refiner_solver_state_dict)

        print(
            '[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
            % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    # train_writer = SummaryWriter()
    # val_writer = SummaryWriter(os.path.join(log_dir, 'test'))
    train_writer = LogWriter(os.path.join(log_dir, 'train'))
    val_writer = LogWriter(os.path.join(log_dir, 'val'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)

        # print("****debug: length of train data loder",n_batches)
        for batch_idx, (rendering_images, ground_truth_volumes) in enumerate(
                train_data_loader()):
            # # debug
            # if batch_idx>1:
            #     break

            # Measure data time
            data_time.update(time() - batch_end_time)
            # print("****debug: batch_idx",batch_idx)
            # print(rendering_images.shape)
            # print(ground_truth_volumes.shape)
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = paddle.mean(generated_volumes, axis=1)

            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder_solver.clear_grad()
            decoder_solver.clear_grad()
            merger_solver.clear_grad()
            refiner_solver.clear_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            merger_solver.step()
            refiner_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.numpy())
            refiner_losses.update(refiner_loss.numpy())

            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar(tag='EncoderDecoder/BatchLoss',
                                    step=n_itr,
                                    value=encoder_loss.numpy())
            train_writer.add_scalar('Refiner/BatchLoss',
                                    value=refiner_loss.numpy(),
                                    step=n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            if (batch_idx % int(cfg.CONST.INFO_BATCH)) == 0:
                print(
                    '[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                    % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                       batch_idx + 1, n_batches, batch_time.val, data_time.val,
                       encoder_loss.numpy(), refiner_loss.numpy()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar(tag='EncoderDecoder/EpochLoss',
                                step=epoch_idx + 1,
                                value=encoder_losses.avg)
        train_writer.add_scalar('Refiner/EpochLoss',
                                value=refiner_losses.avg,
                                step=epoch_idx + 1)

        # Update learning-rate schedulers once per epoch
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        merger_lr_scheduler.step()
        refiner_lr_scheduler.step()

        # Tick / tock
        epoch_end_time = time()
        print(
            '[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
            % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time -
               epoch_start_time, encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES,
                   n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, merger, refiner)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir,
                                  'ckpt-epoch-%04d' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder,
                decoder_solver, merger, merger_solver, refiner, refiner_solver,
                best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt'), epoch_idx + 1,
                encoder, encoder_solver, decoder, decoder_solver, merger,
                merger_solver, refiner, refiner_solver, best_iou, best_epoch)
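
Note how the Paddle version inverts the PyTorch pairing: each MultiStepDecay is passed to paddle.optimizer.Adam as its learning_rate, and scheduler.step() is called once per epoch after the batches, rather than the scheduler wrapping the optimizer. A condensed sketch of that pattern, assuming Paddle 2.x (the layer and milestones are placeholders):

import paddle

model = paddle.nn.Linear(8, 1)
scheduler = paddle.optimizer.lr.MultiStepDecay(
    learning_rate=1e-3, milestones=[150], gamma=0.5)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                  parameters=model.parameters())

for epoch in range(2):
    # ... per batch: loss.backward(); optimizer.step(); optimizer.clear_grad()
    scheduler.step()  # decay once per epoch, as in the training loop above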
Example #10
def test_single_img_net(cfg):

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' %
          (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])

    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/media/caig/FECA2C89CA2C406F/dataset/ShapeNetRendering_copy/03001627/1a74a83fa6d24b3cacd67ce2c72c02e/rendering/00.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/media/caig/FECA2C89CA2C406F/sketch3D/sketch3D/test_output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, os.path.join(img_dir), epoch_idx)
Example #11
def test_img(cfg):

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = '/Users/pranavpomalapally/Downloads/new-Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    print()
    # fix_checkpoint = {}
    # fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    # fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    # fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    # fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    # fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['encoder_state_dict'].items())
    # fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['decoder_state_dict'].items())
    # fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['refiner_state_dict'].items())
    # fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    # encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    # decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    # if cfg.NETWORK.USE_REFINER:
    #  print('Use refiner')
    #  refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])

    print('Use refiner')
    refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        # merger.load_state_dict(fix_checkpoint['merger_state_dict'])
        merger.load_state_dict(checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    #img1_path = '/Users/pranavpomalapally/Downloads/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
    img1_path = '/Users/pranavpomalapally/Downloads/09 copy.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        # if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        #     generated_volume = refiner(generated_volume)
    generated_volume = refiner(generated_volume)
    generated_volume = generated_volume.squeeze(0)

    img_dir = '/Users/pranavpomalapally/Downloads/outputs'
    # gv = generated_volume.cpu().numpy()
    gv = generated_volume.cpu().detach().numpy()
    gv_new = np.swapaxes(gv, 2, 1)

    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, img_dir, epoch_idx)
Example #12
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             merger=None):
   
    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = paddle.io.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetType.TEST,
                cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            # num_workers=1,
            shuffle=False)
        mode = 'test'
    else:
        mode = 'val'

    
    # paddle.io.Dataset does not support 'str' inputs
    dataset_taxonomy = None
    rendering_image_path_template = cfg.DATASETS.SHAPENET.RENDERING_PATH
    volume_path_template = cfg.DATASETS.SHAPENET.VOXEL_PATH

    # Load all taxonomies of the dataset
    with open('./datasets/ShapeNet.json', encoding='utf-8') as file:
        dataset_taxonomy = json.loads(file.read())
        # print("[INFO]TEST-- open TAXONOMY_FILE_PATH succeess")

    all_test_taxonomy_id_and_sample_name = []
    # Load data for each category
    for taxonomy in dataset_taxonomy:
        taxonomy_folder_name = taxonomy['taxonomy_id']
        # print('[INFO] %set -- Collecting files of Taxonomy[ID=%s, Name=%s]' %
        #         (mode, taxonomy['taxonomy_id'], taxonomy['taxonomy_name']))
        samples = taxonomy[mode]
        for sample in samples:
            all_test_taxonomy_id_and_sample_name.append([taxonomy_folder_name, sample])
    # print(len(all_test_taxonomy_id_and_sample_name))
    # print(all_test_taxonomy_id_and_sample_name)
    print('[INFO] Collected files of %s set' % mode)
    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        merger = Merger(cfg)

        # if torch.cuda.is_available():
        #     encoder = paddle.DataParallel(encoder)
        #     decoder = paddle.DataParallel(decoder)
        #     merger = paddle.DataParallel(merger)

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        encoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        # encoder_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder_solver.pdopt"))
        encoder.set_state_dict(encoder_state_dict)
        # encoder_solver.set_state_dict(encoder_solver_state_dict)
        decoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        # decoder_solver_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder_solver.pdopt"))
        decoder.set_state_dict(decoder_state_dict)
        # decoder_solver.set_state_dict(decoder_solver_state_dict)

        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger.set_state_dict(merger_state_dict)

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    merger.eval()

    for sample_idx, (rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = all_test_taxonomy_id_and_sample_name[sample_idx][0]
        sample_name = all_test_taxonomy_id_and_sample_name[sample_idx][1]
        # print("all_test_taxonomy_id_and_sample_name")
        # print(taxonomy_id)
        # print(sample_name)

        with paddle.no_grad():

            # Test the encoder, decoder and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = paddle.mean(generated_volume, axis=1)

            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                # Binarize the predicted volume at the current threshold
                _volume = paddle.greater_equal(generated_volume, paddle.to_tensor(th)).astype("float32")
                # Intersection: voxels occupied in both prediction and ground truth
                intersection = paddle.sum(paddle.multiply(_volume, ground_truth_volume))
                # Union: voxels occupied in either volume
                union = paddle.sum(paddle.greater_equal(
                    paddle.add(_volume, ground_truth_volume).astype("float32"),
                    paddle.to_tensor(1., dtype='float32')).astype("float32"))
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if output_dir and sample_idx < 1:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gv, os.path.join(img_dir, 'Reconstructed'),
                                                                              epoch_idx)
                test_writer.add_image(tag='Reconstructed', img=rendering_views, step=epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gtv, os.path.join(img_dir, 'GroundTruth'),
                                                                              epoch_idx)
                test_writer.add_image(tag='GroundTruth', img=rendering_views, step=epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f IoU = %s' %
                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
                   ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/A', end='\t\t')

        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar(tag='EncoderDecoder/EpochLoss', value=encoder_losses.avg, step=epoch_idx)
        test_writer.add_scalar(tag='EncoderDecoder/IoU', value=max_iou, step=epoch_idx)

    return max_iou
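
The thresholded voxel IoU used above reduces to a set comparison between two binary occupancy grids. A minimal, self-contained sketch of the same metric (NumPy instead of Paddle for brevity; the random volumes below are placeholders, not real data):

import numpy as np

def voxel_iou(pred, gt, th=0.3):
    """IoU between a predicted occupancy grid and a binary ground-truth grid."""
    occ = (pred >= th).astype(np.float32)    # binarize the prediction at the threshold
    intersection = np.sum(occ * gt)          # voxels occupied in both grids
    union = np.sum((occ + gt) >= 1)          # voxels occupied in either grid
    return float(intersection / union)

# Hypothetical 32^3 volumes, matching the resolution used in these examples.
pred = np.random.rand(32, 32, 32).astype(np.float32)
gt = (np.random.rand(32, 32, 32) > 0.5).astype(np.float32)
print(voxel_iou(pred, gt, th=0.3))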
Example #13
def demo_net(cfg, imgs_path):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' %
          (dt.now(), cfg.CONST.WEIGHTS))
    encoder_state_dict = paddle.load(
        os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
    encoder.set_state_dict(encoder_state_dict)
    decoder_state_dict = paddle.load(
        os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
    decoder.set_state_dict(decoder_state_dict)

    if cfg.NETWORK.USE_MERGER:
        merger_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
        merger.set_state_dict(merger_state_dict)

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    merger.eval()

    rendering_images = []
    if os.path.isfile(imgs_path):
        print("demo img")
        rendering_image = cv2.imread(imgs_path, cv2.IMREAD_UNCHANGED).astype(
            np.float32) / 255.
        rendering_image = np.asarray(rendering_image)[np.newaxis, :, :, :]
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        rendering_image = test_transforms(rendering_image)
        # Reshape to [batch, n_views, C, H, W] as expected by the encoder
        rendering_image = paddle.reshape(rendering_image, [1, 1, 3, 224, 224])
        with paddle.no_grad():
            # Get data from data loader
            rendering_image = utils.network_utils.var_or_cuda(rendering_image)

            # Test the encoder, decoder and merger
            image_features = encoder(rendering_image)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = paddle.mean(generated_volume, axis=1)

            for th in cfg.TEST.DEMO_VOXEL_THRESH:
                _volume = paddle.greater_equal(
                    generated_volume, paddle.to_tensor(th)).astype("float32")
                _volume = paddle.reshape(_volume, [32, 32, 32])
                # Save the thresholded volume as an .obj mesh
                if cfg.DIR.OUT_PATH:
                    pred_file_name = os.path.join(
                        cfg.DIR.OUT_PATH,
                        imgs_path.split('/')[-1].split('.')[0] + '.obj')
                    print("save ", pred_file_name)
                    utils.voxel.voxel2obj(pred_file_name,
                                          _volume.cpu().numpy())

    elif os.path.isdir(imgs_path):
        print("demo dir")
        rendering_files_path = os.listdir(imgs_path)
        for rendering_file_path in rendering_files_path:
            if not rendering_file_path.endswith('.png'):
                continue
            print(os.path.join(imgs_path, rendering_file_path))
            rendering_image = cv2.imread(
                os.path.join(imgs_path, rendering_file_path),
                cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            rendering_image = np.asarray(rendering_image)[np.newaxis, :, :, :]
            IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
            CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
            test_transforms = utils.data_transforms.Compose([
                utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
                utils.data_transforms.RandomBackground(
                    cfg.TEST.RANDOM_BG_COLOR_RANGE),
                utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                                std=cfg.DATASET.STD),
                utils.data_transforms.ToTensor(),
            ])

            rendering_image = test_transforms(rendering_image)
            # Reshape to [batch, n_views, C, H, W] as expected by the encoder
            rendering_image = paddle.reshape(rendering_image,
                                             [1, 1, 3, 224, 224])
            with paddle.no_grad():
                # Get data from data loader
                rendering_image = utils.network_utils.var_or_cuda(
                    rendering_image)

                # Test the encoder, decoder and merger
                image_features = encoder(rendering_image)
                raw_features, generated_volume = decoder(image_features)

                if cfg.NETWORK.USE_MERGER:
                    generated_volume = merger(raw_features, generated_volume)
                else:
                    generated_volume = paddle.mean(generated_volume, axis=1)


                # Save the generated volume as an .obj mesh
                if cfg.DIR.OUT_PATH:
                    gv = generated_volume.detach().cpu().numpy()
                    pred_file_name = os.path.join(
                        cfg.DIR.OUT_PATH, imgs_path,
                        rendering_file_path.split('.')[0] + '.obj')
                    utils.voxel.voxel2obj(
                        pred_file_name, gv[0, 1] > cfg.TEST.DEMO_VOXEL_THRESH)
    else:
        raise Exception("error input path")
Example #14
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None,
             test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(
            cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH,
            encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetType.TEST,
                cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))

        if torch.cuda.is_available():
            checkpoint = torch.load(cfg.CONST.WEIGHTS)
        else:
            map_location = torch.device('cpu')
            checkpoint = torch.load(cfg.CONST.WEIGHTS,
                                    map_location=map_location)

        epoch_idx = checkpoint['epoch_idx']
        print('Epoch ID of the current model is {}'.format(epoch_idx))
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    print("test data loader type is {}".format(type(test_data_loader)))
    for sample_idx, (taxonomy_id, sample_name,
                     rendering_images) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        print("sample IDx {}".format(sample_idx))
        print("taxonomy id {}".format(taxonomy_id))
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)

            print("Shape of the loaded images {}".format(
                rendering_images.shape))

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)

            if cfg.NETWORK.USE_REFINER:
                generated_volume = refiner(generated_volume)

            print("vox shape {}".format(generated_volume.shape))

            gv = generated_volume.cpu().numpy()

            rendering_views = utils.binvox_visualization.get_volume_views(
                gv,
                os.path.join('./LargeDatasets/inference_images/', 'inference'),
                sample_idx)
    print("gv shape is {}".format(gv.shape))
    return gv, rendering_images
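
test_net above returns the raw sigmoid volume for the last sample; converting it into a file is left to the caller. A minimal sketch, assuming gv comes out as (1, 32, 32, 32), a 0.3 threshold, and the binvox-rw-py package (imported as br, as in the next example); the output filename is arbitrary:

import binvox_rw as br

gv, _ = test_net(cfg)          # assumed shape: (1, 32, 32, 32)
occupancy = gv[0] >= 0.3       # binarize the sigmoid occupancies

with open('inference.binvox', 'wb') as f:
    v = br.Voxels(occupancy, (32, 32, 32), (0, 0, 0), 1, 'xyz')
    v.write(f)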
Example #15
def test_net(cfg,
             model_type,
             dataset_type,
             results_file_name,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None,
             save_results_to_file=False,
             show_voxels=False,
             path_to_times_csv=None):
    if model_type == Pix2VoxTypes.Pix2Vox_A or model_type == Pix2VoxTypes.Pix2Vox_Plus_Plus_A:
        use_refiner = True
    else:
        use_refiner = False

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(dataset_type,
                                               cfg.CONST.N_VIEWS_RENDERING,
                                               test_transforms),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg, model_type)
        decoder = Decoder(cfg, model_type)
        if use_refiner:
            refiner = Refiner(cfg)
        merger = Merger(cfg, model_type)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            if use_refiner:
                refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        logging.info('Loading weights from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    if use_refiner:
        refiner_losses = AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    if use_refiner:
        refiner.eval()
    merger.eval()

    samples_names = []
    edlosses = []
    rlosses = []
    ious_dict = {}
    for iou_threshold in cfg.TEST.VOXEL_THRESH:
        ious_dict[iou_threshold] = []

    if path_to_times_csv is not None:
        n_view_list = []
        times_list = []

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volume = utils.helpers.var_or_cuda(
                ground_truth_volume)

            if path_to_times_csv is not None:
                start_time = time.time()

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume,
                                        ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            if path_to_times_csv is not None:
                end_time = time.time()
                n_view_list.append(rendering_images.size()[1])
                times_list.append(end_time - start_time)

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            if use_refiner:
                refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(
                    _volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume),
                                           1)).float()
                sample_iou.append((intersection / union).item())

                ious_dict[th].append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if show_voxels:
                with open("model.binvox", "wb") as f:
                    v = br.Voxels(
                        torch.ge(generated_volume,
                                 0.2).float().cpu().numpy()[0], (32, 32, 32),
                        (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

                with open("model.binvox", "wb") as f:
                    v = br.Voxels(ground_truth_volume.cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

            # Print sample loss and IoU
            logging.info(
                'Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s'
                % (sample_idx + 1, n_samples, taxonomy_id, sample_name,
                   encoder_loss.item(), refiner_loss.item(),
                   ['%.4f' % si for si in sample_iou]))

            samples_names.append(sample_name)
            edlosses.append(encoder_loss.item())
            if use_refiner:
                rlosses.append(refiner_loss.item())

    if save_results_to_file:
        save_test_results_to_csv(samples_names,
                                 edlosses,
                                 rlosses,
                                 ious_dict,
                                 path_to_csv=results_file_name)

    if path_to_times_csv is not None:
        save_times_to_csv(times_list,
                          n_view_list,
                          path_to_csv=path_to_times_csv)

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'],
                                               axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] *
                        test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                               epoch_idx)
        if use_refiner:
            test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                   epoch_idx)
            test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
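
Several of these examples instantiate a bare AverageMeter to track running loss averages. A minimal sketch compatible with the calls above (update with a scalar, read .avg); the project's own utils.network_utils.AverageMeter may differ in detail:

class AverageMeter:
    """Running average of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # running total
        self.count = 0    # number of accumulated values
        self.avg = 0.0    # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count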
Example #16
class Model(pl.LightningModule):

    def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
        super().__init__()
        self.cfg_network = cfg_network
        self.cfg_tester = cfg_tester

        # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
        torch.backends.cudnn.benchmark = True

        # Set up networks
        self.encoder = Encoder(cfg_network)
        self.decoder = Decoder(cfg_network)
        self.refiner = Refiner(cfg_network)
        self.merger = Merger(cfg_network)
        
        # Initialize weights of networks
        self.encoder.apply(utils.network_utils.init_weights)
        self.decoder.apply(utils.network_utils.init_weights)
        self.refiner.apply(utils.network_utils.init_weights)
        self.merger.apply(utils.network_utils.init_weights)
        
        self.bce_loss = nn.BCELoss()

    def configure_optimizers(self):
        params = self.cfg_network.optimization
        # Set up solver
        if params.policy == 'adam':
            encoder_solver = optim.Adam(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                            lr=params.encoder_lr,
                                            betas=params.betas)
            decoder_solver = optim.Adam(self.decoder.parameters(),
                                            lr=params.decoder_lr,
                                            betas=params.betas)
            refiner_solver = optim.Adam(self.refiner.parameters(),
                                            lr=params.refiner_lr,
                                            betas=params.betas)
            merger_solver = optim.Adam(self.merger.parameters(),
                                             lr=params.merger_lr,
                                             betas=params.betas)
        elif params.policy == 'sgd':
            encoder_solver = optim.SGD(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                            lr=params.encoder_lr,
                                            momentum=params.momentum)
            decoder_solver = optim.SGD(self.decoder.parameters(),
                                            lr=params.decoder_lr,
                                            momentum=params.momentum)
            refiner_solver = optim.SGD(self.refiner.parameters(),
                                            lr=params.refiner_lr,
                                            momentum=params.momentum)
            merger_solver = optim.SGD(self.merger.parameters(),
                                            lr=params.merger_lr,
                                            momentum=params.momentum)
        else:
            raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), params.policy))

        # Set up learning rate schedulers to decay learning rates dynamically
        encoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                    milestones=params.encoder_lr_milestones,
                                                                    gamma=params.gamma)
        decoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                    milestones=params.decoder_lr_milestones,
                                                                    gamma=params.gamma)
        refiner_lr_scheduler = optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                                    milestones=params.refiner_lr_milestones,
                                                                    gamma=params.gamma)
        merger_lr_scheduler = optim.lr_scheduler.MultiStepLR(merger_solver,
                                                                milestones=params.merger_lr_milestones,
                                                                gamma=params.gamma)
        
        return [encoder_solver, decoder_solver, refiner_solver, merger_solver], \
               [encoder_lr_scheduler, decoder_lr_scheduler, refiner_lr_scheduler, merger_lr_scheduler]
    
    def _fwd(self, batch):
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch

        image_features = self.encoder(rendering_images)
        raw_features, generated_volumes = self.decoder(image_features)

        if self.cfg_network.use_merger and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_merger:
            generated_volumes = self.merger(raw_features, generated_volumes)
        else:
            generated_volumes = torch.mean(generated_volumes, dim=1)
        encoder_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        
        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            generated_volumes = self.refiner(generated_volumes)
            refiner_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        else:
            refiner_loss = encoder_loss
        
        return generated_volumes, encoder_loss, refiner_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        (opt_enc, opt_dec, opt_ref, opt_merg) = self.optimizers()
        
        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)
        
        self.log('loss/EncoderDecoder', encoder_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('loss/Refiner', refiner_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            self.manual_backward(encoder_loss, opt_enc, retain_graph=True)
            self.manual_backward(refiner_loss, opt_ref)
        else:
            self.manual_backward(encoder_loss, opt_enc)
            
        for opt in self.optimizers():
            opt.step()
            opt.zero_grad()

    def training_epoch_end(self, outputs) -> None:
        # Update Rendering Views
        if self.cfg_network.update_n_views_rendering:
            n_views_rendering = self.trainer.datamodule.update_n_views_rendering()
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), self.current_epoch + 2, self.trainer.max_epochs, n_views_rendering))

    def _eval_step(self, batch, batch_idx):
        # SUPPORTS ONLY BATCH_SIZE=1
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch
        taxonomy_id = taxonomy_names[0]
        sample_name = sample_names[0]

        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)

        self.log('val_loss/EncoderDecoder', encoder_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        self.log('val_loss/Refiner', refiner_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        # IoU per sample
        sample_iou = []
        for th in self.cfg_tester.voxel_thresh:
            _volume = torch.ge(generated_volumes, th).float()
            intersection = torch.sum(_volume.mul(ground_truth_volumes)).float()
            union = torch.sum(
                torch.ge(_volume.add(ground_truth_volumes), 1)).float()
            sample_iou.append((intersection / union).item())

        # Print sample loss and IoU
        n_samples = -1  # total sample count is not known inside a single eval step
        print('\n[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
              (dt.now(), batch_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
               refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

        return {
            'taxonomy_id': taxonomy_id,
            'sample_name': sample_name,
            'sample_iou': sample_iou
        }
        
    def _eval_epoch_end(self, outputs):
        # Load taxonomies of dataset
        taxonomies = []
        taxonomy_path = self.trainer.datamodule.get_test_taxonomy_file_path()
        with open(taxonomy_path, encoding='utf-8') as file:
            taxonomies = json.loads(file.read())
        taxonomies = {t['taxonomy_id']: t for t in taxonomies}

        test_iou = {}
        for output in outputs:
            taxonomy_id, sample_name, sample_iou = output[
                'taxonomy_id'], output['sample_name'], output['sample_iou']
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

        mean_iou = []
        for taxonomy_id in test_iou:
            test_iou[taxonomy_id]['iou'] = torch.mean(
                torch.tensor(test_iou[taxonomy_id]['iou']), dim=0)
            mean_iou.append(test_iou[taxonomy_id]['iou']
                            * test_iou[taxonomy_id]['n_samples'])
        n_samples = len(outputs)
        mean_iou = torch.stack(mean_iou)
        mean_iou = torch.sum(mean_iou, dim=0) / n_samples

        # Print header
        print('============================ TEST RESULTS ============================')
        print('Taxonomy', end='\t')
        print('#Sample', end='\t')
        print('Baseline', end='\t')
        for th in self.cfg_tester.voxel_thresh:
            print('t=%.2f' % th, end='\t')
        print()
        # Print body
        for taxonomy_id in test_iou:
            print('%s' % taxonomies[taxonomy_id]
                  ['taxonomy_name'].ljust(8), end='\t')
            print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
            if 'baseline' in taxonomies[taxonomy_id]:
                n_views_rendering = self.trainer.datamodule.get_n_views_rendering()
                print('%.4f' % taxonomies[taxonomy_id]['baseline']
                      ['%d-view' % n_views_rendering], end='\t\t')
            else:
                print('N/a', end='\t\t')

            for ti in test_iou[taxonomy_id]['iou']:
                print('%.4f' % ti, end='\t')
            print()
        # Print mean IoU for each threshold
        print('Overall ', end='\t\t\t\t')
        for mi in mean_iou:
            print('%.4f' % mi, end='\t')
        print('\n')

        max_iou = torch.max(mean_iou)
        self.log('Refiner/IoU', max_iou, prog_bar=True, on_epoch=True)
    
    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def validation_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)
        
    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def test_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)

    def get_progress_bar_dict(self):
        # don't show the loss as it's None
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items
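
A hypothetical wiring sketch for the LightningModule above. The datamodule class and config layout are assumptions (the model expects a datamodule exposing update_n_views_rendering, get_n_views_rendering, and get_test_taxonomy_file_path), and an older PyTorch Lightning release is assumed, one where training_epoch_end, manual_backward(loss, optimizer), and the automatic_optimization Trainer flag still exist:

import pytorch_lightning as pl
from omegaconf import OmegaConf

cfg = OmegaConf.load('config.yaml')            # assumed Hydra/OmegaConf config file
model = Model(cfg.network, cfg.tester)
datamodule = ShapeNetDataModule(cfg.dataset)   # hypothetical datamodule class

trainer = pl.Trainer(
    max_epochs=250,
    gpus=1,
    automatic_optimization=False,  # training_step drives manual_backward itself
)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)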