コード例 #1
0
def detect(path, encoder=None, decoder=None):
    """Run the encoder/decoder on every item of ``LoadImages(path)`` and plot the result.

    Args:
        path: Source location handed to ``LoadImages``.
        encoder: Optional encoder to use. If either ``encoder`` or ``decoder``
            is ``None``, both are rebuilt from ``config`` and passed through
            ``load_checkpoint`` with ``config.CHECKPOINT_FILE``.
        decoder: Optional decoder to use (see ``encoder``).

    Returns:
        None. Each prediction is passed to ``plot_volumes``.
    """
    torch.backends.cudnn.benchmark = True

    dataset = LoadImages(path, img_size=config.IMAGE_SIZE, used_layers=config.USED_LAYERS)

    # Compare with None explicitly: the truthiness of an nn.Module depends on
    # __len__ for container modules, so "not encoder" is not a reliable test.
    if encoder is None or decoder is None:
        in_channels = num_channels(config.USED_LAYERS)
        encoder = Encoder(in_channels=in_channels).to(config.DEVICE)
        decoder = Decoder(num_classes=config.NUM_CLASSES+1).to(config.DEVICE)

        _, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    encoder.eval()
    decoder.eval()

    # Use a dedicated name for the per-item path so the ``path`` argument is
    # not silently overwritten while iterating.
    for _, layers, item_path in dataset:
        with torch.no_grad():
            layers = torch.from_numpy(layers).to(config.DEVICE, non_blocking=True)
            if layers.ndimension() == 3:
                # Add the missing leading dimension before calling the encoder.
                layers = layers.unsqueeze(0)

            features = encoder(layers)
            predictions = decoder(features)
            out = predictions.sigmoid()

            plot_volumes(to_volume(out, config.VOXEL_THRESH).cpu(), [item_path], config.NAMES)
コード例 #2
0
ファイル: preprocess.py プロジェクト: Alqatf/BUTD_model
def extract_imgs_feat():
    """Encode every image in ``opt.imgs_dir`` and store the features as HDF5.

    Two files are written under ``opt.out_feats_dir``:
    ``<dataset_name>_fc.h5`` and ``<dataset_name>_att.h5``. Each receives one
    dataset per image, keyed by the image file name, holding the first and
    second output of the encoder respectively.

    Raises:
        Whatever the image reader, the encoder or h5py raises; a separator
        line is printed before the exception is propagated.
    """
    encoder = Encoder(opt.resnet101_file)
    encoder.to(opt.device)
    encoder.eval()

    imgs = sorted(os.listdir(opt.imgs_dir))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(opt.out_feats_dir, exist_ok=True)
    fc_path = os.path.join(opt.out_feats_dir, '%s_fc.h5' % opt.dataset_name)
    att_path = os.path.join(opt.out_feats_dir, '%s_att.h5' % opt.dataset_name)

    # The mode must be explicit: h5py >= 3.0 opens files read-only by default,
    # which makes create_dataset() fail. 'a' is the former default
    # (read/write, create if missing).
    with h5py.File(fc_path, 'a') as file_fc, h5py.File(att_path, 'a') as file_att:
        try:
            for img_nm in tqdm.tqdm(imgs, ncols=100):
                img = skimage.io.imread(os.path.join(opt.imgs_dir, img_nm))
                with torch.no_grad():
                    img = encoder.preprocess(img)
                    img = img.to(opt.device)
                    img_fc, img_att = encoder(img)
                file_fc.create_dataset(img_nm,
                                       data=img_fc.cpu().float().numpy())
                file_att.create_dataset(img_nm,
                                        data=img_att.cpu().float().numpy())
        except BaseException:
            # The ``with`` statement closes both files; only the visual
            # separator is emitted here before re-raising with the original
            # traceback intact.
            print(
                '--------------------------------------------------------------------'
            )
            raise
コード例 #3
0
def test(encoder=None, decoder=None):
    """Evaluate the encoder/decoder on the ``/test`` split and show running metrics.

    Args:
        encoder: Optional encoder to evaluate. If either ``encoder`` or
            ``decoder`` is ``None``, both are rebuilt from ``config`` and
            passed through ``load_checkpoint`` with ``config.CHECKPOINT_FILE``.
        decoder: Optional decoder to evaluate (see ``encoder``).

    Returns:
        None. The running mean loss and mean IoU are shown on the tqdm bar;
        when ``config.PLOT`` is set, the first batch is also plotted.
    """
    torch.backends.cudnn.benchmark = True

    _, dataloader = create_dataloader(config.IMG_DIR + "/test", config.MESH_DIR + "/test",
                                      batch_size=config.BATCH_SIZE, used_layers=config.USED_LAYERS,
                                      img_size=config.IMAGE_SIZE, map_size=config.MAP_SIZE,
                                      augment=config.AUGMENT, workers=config.NUM_WORKERS,
                                      pin_memory=config.PIN_MEMORY, shuffle=False)

    # Compare with None explicitly: the truthiness of an nn.Module depends on
    # __len__ for container modules, so "not encoder" is not a reliable test.
    if encoder is None or decoder is None:
        in_channels = num_channels(config.USED_LAYERS)
        encoder = Encoder(in_channels=in_channels).to(config.DEVICE)
        decoder = Decoder(num_classes=config.NUM_CLASSES+1).to(config.DEVICE)

        _, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    loss_fn = LossFunction()

    loop = tqdm(dataloader, leave=True)
    losses = []
    ious = []

    encoder.eval()
    decoder.eval()

    for i, (_, layers, volumes, img_files) in enumerate(loop):
        with torch.no_grad():
            layers = layers.to(config.DEVICE, non_blocking=True)
            volumes = volumes.to(config.DEVICE, non_blocking=True)

            features = encoder(layers)
            predictions = decoder(features)

            loss = loss_fn(predictions, volumes)
            losses.append(loss.item())

            iou = predictions_iou(to_volume(predictions, config.VOXEL_THRESH), volumes)
            ious.append(iou)

            # Report the mean over all batches seen so far, not the last batch.
            mean_iou = sum(ious) / len(ious)
            mean_loss = sum(losses) / len(losses)
            loop.set_postfix(loss=mean_loss, mean_iou=mean_iou)

            # Plot predictions and ground truth for the first batch only.
            if i == 0 and config.PLOT:
                plot_volumes(to_volume(predictions, config.VOXEL_THRESH).cpu(), img_files, config.NAMES)
                plot_volumes(volumes.cpu(), img_files, config.NAMES)
コード例 #4
0
# coding:utf8
import torch
import skimage.io

from opts import parse_opt
from models.decoder import Decoder
from models.encoder import Encoder

# Parse the run options; a decoder checkpoint (test_model) and an input image
# (image_file) are both required.
opt = parse_opt()
assert opt.test_model, 'please input test_model'
assert opt.image_file, 'please input image_file'

# Build the encoder from opt.resnet101_file, move it to the target device and
# switch it to evaluation mode.
encoder = Encoder(opt.resnet101_file)
encoder.to(opt.device)
encoder.eval()

# Read the image and run it through the encoder without tracking gradients.
# The encoder returns two feature tensors, kept as fc_feat and att_feat.
img = skimage.io.imread(opt.image_file)
with torch.no_grad():
    img = encoder.preprocess(img)
    img = img.to(opt.device)
    fc_feat, att_feat = encoder(img)

# Load the decoder checkpoint. The identity map_location callable returns each
# storage unchanged, so tensors are not moved to the device recorded in the
# checkpoint while loading.
print("====> loading checkpoint '{}'".format(opt.test_model))
chkpoint = torch.load(opt.test_model, map_location=lambda s, l: s)
# The checkpoint carries the constructor arguments ('idx2word', 'settings')
# alongside the weights ('model') and bookkeeping ('epoch', 'train_mode').
decoder = Decoder(chkpoint['idx2word'], chkpoint['settings'])
decoder.load_state_dict(chkpoint['model'])
print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
    opt.test_model, chkpoint['epoch'], chkpoint['train_mode']))
# Move the restored decoder to the target device and put it in evaluation mode.
decoder.to(opt.device)
decoder.eval()
コード例 #5
0
def test_single_img(cfg):
    """Reconstruct a volume from one hard-coded image and render its views.

    The weights file, the input image and the output directory are fixed
    paths inside the function; ``cfg.CONST.WEIGHTS`` is overwritten with the
    weights path as a side effect.

    Args:
        cfg: Configuration object providing the ``CONST``, ``NETWORK``,
            ``TEST``, ``TRAIN`` and ``DATASET`` sections read below.

    Returns:
        None. The volume is printed and passed to
        ``utils.binvox_visualization.get_volume_views``.
    """

    def strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without a leading ``module.`` on its keys."""
        # Only the leading prefix is removed. Splitting on 'module.' would
        # also cut keys that contain the text later on (for example
        # 'module.submodule.weight') and would raise IndexError for keys
        # that carry no prefix at all.
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = 'D:/Pix2Vox/Pix2Vox/pretrained/Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(strip_module_prefix(checkpoint['encoder_state_dict']))
    decoder.load_state_dict(strip_module_prefix(checkpoint['decoder_state_dict']))

    # The optional networks are read from the checkpoint only when enabled, so
    # a checkpoint without them still loads.
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(strip_module_prefix(checkpoint['refiner_state_dict']))
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(strip_module_prefix(checkpoint['merger_state_dict']))

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = 'D:/Pix2Vox/Pix2Vox/rand/minecraft.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

    # The transforms operate on an array of views; wrap the single image.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        # The merger and refiner are applied only when enabled and when the
        # checkpoint's epoch has reached the configured start epoch.
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = 'D:/Pix2Vox/Pix2Vox/output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    print(gv_new)
    utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)
def main(config):
    """Encode the input audio once and decode it with every selected decoder.

    Checkpoints matching ``<config.checkpoint>_*.pth`` whose id is listed in
    ``config.decoders`` are loaded. The encoder weights come from the first
    matching checkpoint; one ``SampleGenerator`` is built per checkpoint.
    Each input file is encoded and decoded by every generator, and the result
    is saved to ``config.out_dir / <decoder id> / <input name>.wav``.

    Args:
        config: Run configuration. The attributes read here are ``checkpoint``,
            ``decoders``, ``batch_size``, ``rate``, ``out_dir``, ``sample_dir``,
            ``sample_len`` and ``split_size``. ``sample_len`` is written back
            when it was unset (see below).

    Raises:
        AssertionError: If no checkpoint matches or ``config.out_dir`` is None.
        Exception: If an input file is neither ``.wav`` nor ``.h5``.
    """
    print('Starting')

    checkpoints = config.checkpoint.parent.glob(config.checkpoint.name +
                                                '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in config.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_config = torch.load(config.checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_config.encoder)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    generators = []
    generator_ids = []
    for checkpoint in checkpoints:
        decoder = Decoder(model_config.decoder)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()

        generator = SampleGenerator(decoder,
                                    config.batch_size,
                                    wav_freq=config.rate)

        generators.append(generator)
        generator_ids.append(extract_id(checkpoint))

    xs = []
    assert config.out_dir is not None

    # A single directory argument is searched recursively; otherwise the
    # arguments are taken as the list of files.
    if len(config.sample_dir) == 1 and config.sample_dir[0].is_dir():
        top = config.sample_dir[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = config.sample_dir

    print("File paths to be used:", file_paths)
    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, _ = librosa.load(file_path, sr=config.rate)
            data = helper_functions.mu_law(data)
        elif file_path.suffix == '.h5':
            # Read the samples inside a context manager so the HDF5 handle is
            # released right away instead of leaking one handle per file.
            with h5py.File(file_path, 'r') as h5_file:
                raw = h5_file['wav'][:]
            data = helper_functions.mu_law(raw / (2**15))
            # Trim the tail so the length is a whole multiple of the rate.
            if data.shape[-1] % config.rate != 0:
                data = data[:-(data.shape[-1] % config.rate)]
            assert data.shape[-1] % config.rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        # When no sample length was configured, the first file fixes it and
        # every later file is cut to that length.
        if config.sample_len:
            data = data[:config.sample_len]
        else:
            config.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_idx, filepath):
        """Invert the mu-law encoding of ``x`` and write it as a ``.wav`` file."""
        wav = helper_functions.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        save_audio(wav.squeeze(),
                   config.out_dir / str(decoder_idx) /
                   filepath.with_suffix('.wav').name,
                   rate=config.rate)

    yy = {}
    with torch.no_grad():
        # Encode all inputs once; the encodings are shared by every decoder.
        zz = []
        for xs_batch in torch.split(xs, config.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        for generator_id, generator in zip(generator_ids, generators):
            yy[generator_id] = []
            for zz_batch in torch.split(zz, config.batch_size):
                print("Batch shape:", zz_batch.shape)
                # Generate along the last axis in pieces of split_size and
                # join the pieces again afterwards.
                splits = torch.split(zz_batch, config.split_size, -1)
                audio_data = []
                generator.reset()
                for cond in tqdm.tqdm(splits):
                    audio_data += [generator.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[generator_id] += [audio_data]
            yy[generator_id] = torch.cat(yy[generator_id], dim=0)

            for sample_result, filepath in zip(yy[generator_id], file_paths):
                save(sample_result, generator_id, filepath)
コード例 #7
0
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None,
             test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    """Evaluate the reconstruction networks on the test set and report IoU.

    Args:
        cfg: Configuration object (``CONST``, ``DATASET``, ``DATASETS``,
            ``NETWORK``, ``TEST`` and ``TRAIN`` sections are read).
        epoch_idx: Epoch used for the merger/refiner gates and for logging.
            Replaced by the checkpoint's value when weights are loaded here.
        output_dir: Optional ``%``-style template; ``output_dir % 'images'``
            receives renderings of the first three samples.
        test_data_loader: Optional loader; one is built from ``cfg`` if None.
        test_writer: Optional writer exposing ``add_image`` / ``add_scalar``.
        encoder: Optional network. If this or ``decoder`` is None, all four
            networks are rebuilt and loaded from ``cfg.CONST.WEIGHTS``.
        decoder: Optional network (see ``encoder``).
        refiner: Optional network; used as given when ``encoder`` and
            ``decoder`` are both supplied.
        merger: Optional network; used as given when ``encoder`` and
            ``decoder`` are both supplied.

    Returns:
        The largest mean IoU over the thresholds in ``cfg.TEST.VOXEL_THRESH``.
    """
    # Let cuDNN pick the fastest algorithms for the observed input sizes.
    torch.backends.cudnn.benchmark = True

    # Index the taxonomy descriptions of the test dataset by taxonomy id.
    taxonomy_path = cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH
    with open(taxonomy_path, encoding='utf-8') as taxonomy_file:
        taxonomy_entries = json.load(taxonomy_file)
    taxonomies = {entry['taxonomy_id']: entry for entry in taxonomy_entries}

    # Build the default test loader when the caller did not pass one.
    if test_data_loader is None:
        img_size = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        crop_size = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(img_size, crop_size),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        loader_factory = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_dataset = loader_factory.get_dataset(utils.data_loaders.DatasetType.TEST,
                                                  cfg.CONST.N_VIEWS_RENDERING, transforms)
        test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                       batch_size=1,
                                                       num_workers=1,
                                                       pin_memory=True,
                                                       shuffle=False)

    # Build and restore the networks when they were not handed in.
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    criterion = torch.nn.BCELoss()

    # Bookkeeping for the evaluation loop.
    n_samples = len(test_data_loader)
    test_iou = {}
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Evaluation mode for all four networks.
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    for sample_idx, batch in enumerate(test_data_loader):
        taxonomy_id, sample_name, rendering_images, ground_truth_volume = batch
        first_id = taxonomy_id[0]
        taxonomy_id = first_id if isinstance(first_id, str) else first_id.item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Move the batch to the device used by the networks.
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Forward pass: encoder -> decoder -> (merger | mean) -> (refiner).
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            use_merger = cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER
            if use_merger:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = criterion(generated_volume, ground_truth_volume) * 10

            use_refiner = cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER
            if use_refiner:
                generated_volume = refiner(generated_volume)
                refiner_loss = criterion(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            print("vox shape {}".format(generated_volume.shape))

            # Accumulate the running loss averages.
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # One IoU value per configured threshold for this sample.
            sample_iou = []
            for threshold in cfg.TEST.VOXEL_THRESH:
                binarized = torch.ge(generated_volume, threshold).float()
                intersection = torch.sum(binarized.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(binarized.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # Group the per-sample IoU values by taxonomy.
            taxonomy_stats = test_iou.setdefault(taxonomy_id, {'n_samples': 0, 'iou': []})
            taxonomy_stats['n_samples'] += 1
            taxonomy_stats['iou'].append(sample_iou)

            # Render the first three samples (prediction, then ground truth).
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                render_dir = os.path.join(img_dir, 'test')

                predicted = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(predicted, render_dir,
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx, rendering_views, epoch_idx)

                expected = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(expected, render_dir,
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx, rendering_views, epoch_idx)

            # Per-sample progress line.
            formatted_iou = ['%.4f' % si for si in sample_iou]
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
                   refiner_loss.item(), formatted_iou))

    # Per-taxonomy mean IoU, then the sample-weighted overall mean.
    weighted_iou = []
    for taxonomy_stats in test_iou.values():
        taxonomy_stats['iou'] = np.mean(taxonomy_stats['iou'], axis=0)
        weighted_iou.append(taxonomy_stats['iou'] * taxonomy_stats['n_samples'])
    mean_iou = np.sum(weighted_iou, axis=0) / n_samples

    # Table header.
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', '#Sample', 'Baseline', sep='\t', end='\t')
    for threshold in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % threshold, end='\t')
    print()

    # One table row per taxonomy.
    for taxonomy_id, taxonomy_stats in test_iou.items():
        taxonomy = taxonomies[taxonomy_id]
        print('%s' % taxonomy['taxonomy_name'].ljust(8), end='\t')
        print('%d' % taxonomy_stats['n_samples'], end='\t')
        if 'baseline' in taxonomy:
            print('%.4f' % taxonomy['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/a', end='\t\t')

        for threshold_iou in taxonomy_stats['iou']:
            print('%.4f' % threshold_iou, end='\t')
        print()

    # Final row: overall mean IoU per threshold.
    print('Overall ', end='\t\t\t\t')
    for overall_iou in mean_iou:
        print('%.4f' % overall_iou, end='\t')
    print('\n')

    # Log the epoch summary.
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
コード例 #8
0
def test_single_img_net(cfg):
    """Reconstruct a volume from one hard-coded image and render its views.

    Weights are read from ``cfg.CONST.WEIGHTS`` onto the CPU. The input image
    and the output directory are fixed paths inside the function.

    Args:
        cfg: Configuration object providing the ``CONST``, ``NETWORK``,
            ``TEST``, ``TRAIN`` and ``DATASET`` sections read below.

    Returns:
        None. The volume is passed to
        ``utils.binvox_visualization.get_volume_views``.
    """

    def strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without a leading ``module.`` on its keys."""
        # Only the leading prefix is removed. Splitting on 'module.' would
        # also cut keys that contain the text later on (for example
        # 'module.submodule.weight') and would raise IndexError for keys
        # that carry no prefix at all.
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' %
          (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(
        strip_module_prefix(checkpoint['encoder_state_dict']))
    decoder.load_state_dict(
        strip_module_prefix(checkpoint['decoder_state_dict']))

    # The optional networks are read from the checkpoint only when enabled, so
    # a checkpoint without them still loads.
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(
            strip_module_prefix(checkpoint['refiner_state_dict']))
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(
            strip_module_prefix(checkpoint['merger_state_dict']))

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/media/caig/FECA2C89CA2C406F/dataset/ShapeNetRendering_copy/03001627/1a74a83fa6d24b3cacd67ce2c72c02e/rendering/00.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    # The transforms operate on an array of views; wrap the single image.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        # The merger and refiner are applied only when enabled and when the
        # checkpoint's epoch has reached the configured start epoch.
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/media/caig/FECA2C89CA2C406F/sketch3D/sketch3D/test_output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)
コード例 #9
0
def test_img(cfg):
    """Reconstruct a volume from one hard-coded image, always applying the refiner.

    The weights file, the input image and the output directory are fixed
    paths inside the function; ``cfg.CONST.WEIGHTS`` is overwritten with the
    weights path as a side effect. The checkpoint's state dicts are loaded
    as stored, without renaming any keys.

    Args:
        cfg: Configuration object providing the ``CONST``, ``NETWORK``,
            ``TEST``, ``TRAIN`` and ``DATASET`` sections read below.

    Returns:
        None. The volume is passed to
        ``utils.binvox_visualization.get_volume_views``.
    """
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = '/Users/pranavpomalapally/Downloads/new-Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    print()

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    # The refiner is loaded and used unconditionally here;
    # cfg.NETWORK.USE_REFINER is not consulted in this function.
    print('Use refiner')
    refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/Users/pranavpomalapally/Downloads/09 copy.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    # The transforms operate on an array of views; wrap the single image.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        # Keep the refiner inside no_grad as well: this is pure inference, so
        # no autograd graph is needed and the result converts to NumPy
        # without an extra detach().
        generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/Users/pranavpomalapally/Downloads/outputs'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)

    # NOTE(review): presumably a workaround for a duplicate OpenMP runtime
    # error raised while rendering; confirm it is still required.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)
コード例 #10
0
ファイル: test.py プロジェクト: chicleee/Pix2Vox
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             merger=None):
    """Evaluate the encoder/decoder(/merger) networks and report IoU.

    Args:
        cfg: Configuration object; read for dataset paths, image sizes,
            network switches and the IoU thresholds (``cfg.TEST.VOXEL_THRESH``).
        epoch_idx: Epoch index used to decide whether the merger is applied
            and as the step for logged images/scalars.
        output_dir: Optional %-format template; ``output_dir % 'images'`` is
            the directory receiving the rendered volume views.
        test_data_loader: Loader yielding ``(rendering_images,
            ground_truth_volume)``. When ``None`` a test loader is built and
            the ``'test'`` split of the taxonomy file is used; otherwise the
            ``'val'`` split is used.
        test_writer: Optional summary writer; when ``None`` nothing is logged.
        encoder, decoder, merger: Networks to evaluate. When ``encoder`` or
            ``decoder`` is ``None`` all three are built and their weights are
            loaded from ``cfg.CONST.WEIGHTS``.

    Returns:
        The maximum over thresholds of the sample-weighted mean IoU.
    """
    # Load taxonomies of dataset
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = paddle.io.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            shuffle=False)
        mode = 'test'
    else:
        mode = 'val'

    # The loader does not yield taxonomy ids / sample names, so they are
    # re-read from the dataset description and matched to batches by index.
    # NOTE(review): this relies on the loader iterating samples in the same
    # order as the JSON file with batch_size=1 -- confirm for external loaders.
    with open('./datasets/ShapeNet.json', encoding='utf-8') as file:
        dataset_taxonomy = json.loads(file.read())

    all_test_taxonomy_id_and_sample_name = [
        [taxonomy['taxonomy_id'], sample]
        for taxonomy in dataset_taxonomy
        for sample in taxonomy[mode]
    ]
    print('[INFO] Collected files of %set' % (mode))

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        merger = Merger(cfg)

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        encoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder.set_state_dict(encoder_state_dict)
        decoder_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder.set_state_dict(decoder_state_dict)

        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger.set_state_dict(merger_state_dict)

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    # `merger` defaults to None and is only needed when the merger is enabled.
    if merger is not None:
        merger.eval()

    for sample_idx, (rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id, sample_name = all_test_taxonomy_id_and_sample_name[sample_idx][:2]

        with paddle.no_grad():
            # Test the encoder, decoder and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = paddle.mean(generated_volume, axis=1)

            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss)

            # IoU per sample, one value per threshold
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = paddle.greater_equal(generated_volume, paddle.to_tensor(th)).astype("float32")
                intersection = paddle.sum(paddle.multiply(_volume, ground_truth_volume))
                # A voxel belongs to the union when prediction + truth >= 1.
                union = paddle.sum(
                    paddle.greater_equal(
                        paddle.add(_volume, ground_truth_volume).astype("float32"),
                        paddle.to_tensor(1., dtype='float32')).astype("float32"))
                sample_iou.append((intersection / union))

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Render the first sample; log it only when a writer was supplied.
            if output_dir and sample_idx < 1:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gv, os.path.join(img_dir, 'Reconstructed'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image(tag='Reconstructed', img=rendering_views, step=epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gtv, os.path.join(img_dir, 'GroundTruth'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image(tag='GroundTruth', img=rendering_views, step=epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f IoU = %s' %
                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss,
                   ['%.4f' % si for si in sample_iou]))

    # Output testing results: per-taxonomy mean IoU, then the overall mean
    # weighted by the number of samples in each taxonomy.
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/a', end='\t\t')

        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to the summary writer
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar(tag='EncoderDecoder/EpochLoss', value=encoder_losses.avg, step=epoch_idx)
        test_writer.add_scalar(tag='EncoderDecoder/IoU', value=max_iou, step=epoch_idx)

    return max_iou
コード例 #11
0
class Trainer:
    """Adversarial trainer: one shared encoder, one decoder per dataset.

    A domain classifier is trained to predict which dataset an encoding came
    from, while the encoder/decoder pair is trained to reconstruct the input
    and to make the classifier wrong. All networks are wrapped in
    ``torch.nn.DataParallel`` and moved to CUDA in ``__init__``.
    """

    def __init__(self, config):
        """Build datasets, meters, networks, optimizers and LR schedulers.

        Args:
            config: Configuration object; ``config.data.n_datasets`` is set
                here from ``len(config.data.datasets)``.

        Raises:
            NotImplementedError: when ``config.env.checkpoint`` is set. The
                checkpoint states are loaded first and the error is raised
                afterwards, so resuming is currently not usable.
        """
        self.config = config
        self.config.data.n_datasets = len(config.data.datasets)

        print("No of datasets used:", self.config.data.n_datasets)

        torch.manual_seed(config.env.seed)
        torch.cuda.manual_seed(config.env.seed)
        self.expPath = self.config.env.expPath

        self.logger = Logger("Training", "logs/training.log")
        self.data = [
            DatasetSet(data_path, config.data.seq_len, config.data)
            for data_path in config.data.datasets
        ]

        n_datasets = self.config.data.n_datasets

        # Training meters: one reconstruction/total meter per dataset and a
        # single shared meter for the classifier loss.
        self.losses_recon = [LossMeter(f'recon {i}') for i in range(n_datasets)]
        self.loss_d_right = LossMeter('d')
        self.loss_total = [LossMeter(f'total {i}') for i in range(n_datasets)]

        # Evaluation meters, mirroring the training ones.
        self.evals_recon = [LossMeter(f'recon {i}') for i in range(n_datasets)]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = [
            LossMeter(f'eval total {i}') for i in range(n_datasets)
        ]

        self.encoder = Encoder(config.encoder)
        self.decoders = torch.nn.ModuleList(
            [Decoder(config.decoder) for _ in range(n_datasets)])
        self.classifier = DomainClassifier(config.domain_classifier,
                                           num_classes=n_datasets)

        states = None
        if config.env.checkpoint:

            checkpoint_args_path = os.path.dirname(
                config.env.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            # args.pth stores [config, epoch]; resume from the next epoch.
            self.start_epoch = checkpoint_args[-1] + 1
            states = [
                torch.load(self.config.env.checkpoint + f'_{i}.pth')
                for i in range(n_datasets)
            ]

            self.encoder.load_state_dict(states[0]['encoder_state'])
            for i in range(n_datasets):
                self.decoders[i].load_state_dict(states[i]['decoder_state'])
            self.classifier.load_state_dict(states[0]['discriminator_state'])
            self.logger.info('Loaded checkpoint parameters')

            # Resuming is deliberately blocked by the original author.
            raise NotImplementedError
        else:
            self.start_epoch = 0

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.classifier = torch.nn.DataParallel(self.classifier).cuda()
        for i, decoder in enumerate(self.decoders):
            self.decoders[i] = torch.nn.DataParallel(decoder).cuda()

        # One optimizer per decoder; each also owns the shared encoder params.
        self.model_optimizers = [
            optim.Adam(chain(self.encoder.parameters(), decoder.parameters()),
                       lr=config.data.lr) for decoder in self.decoders
        ]

        self.classifier_optimizer = optim.Adam(self.classifier.parameters(),
                                               lr=config.data.lr)

        if config.env.checkpoint and config.env.load_optimizer:
            for i in range(n_datasets):
                self.model_optimizers[i].load_state_dict(
                    states[i]['model_optimizer_state'])

            self.classifier_optimizer.load_state_dict(
                states[0]['d_optimizer_state'])

        self.lr_managers = []
        for i in range(n_datasets):
            self.lr_managers.append(
                torch.optim.lr_scheduler.ExponentialLR(
                    self.model_optimizers[i], config.data.lr_decay))
            self.lr_managers[i].last_epoch = self.start_epoch
            self.lr_managers[i].step()

    @staticmethod
    def _domain_targets(dset_num, batch_size):
        """Return a CUDA long tensor of ``batch_size`` copies of ``dset_num``."""
        return torch.tensor([dset_num] * batch_size).long().cuda()

    def eval_batch(self, x, x_aug, dset_num):
        """Compute evaluation losses for one batch and update the eval meters.

        Args:
            x: Input batch; encoded and used as the reconstruction target.
            x_aug: Augmented batch; cast to float but not otherwise used here.
            dset_num: Index of the dataset (and decoder) the batch belongs to.

        Returns:
            Python float: classifier loss * d_lambda + mean reconstruction loss.
        """
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoders[dset_num](x, z)
        z_logits = self.classifier(z)

        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        discriminator_right = F.cross_entropy(
            z_logits, self._domain_targets(dset_num, x.size(0))).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.config.domain_classifier.d_lambda + \
                     recon_loss.mean().data.item()

        self.eval_total[dset_num].add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Run one classifier step and one encoder/decoder step.

        Args:
            x: Clean input batch (classifier input and reconstruction target).
            x_aug: Augmented batch that is encoded for the generator step.
            dset_num: Index of the dataset (and decoder) the batch belongs to.

        Returns:
            Python float with the generator loss of this batch.
        """
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - classifier right
        z = self.encoder(x)
        z_logits = self.classifier(z)
        discriminator_right = F.cross_entropy(
            z_logits, self._domain_targets(dset_num, x.size(0))).mean()
        loss = discriminator_right * self.config.domain_classifier.d_lambda
        self.loss_d_right.add(loss.data.item())

        self.classifier_optimizer.zero_grad()
        loss.backward()
        if self.config.domain_classifier.grad_clip is not None:
            clip_grad_value_(self.classifier.parameters(),
                             self.config.domain_classifier.grad_clip)

        self.classifier_optimizer.step()

        # optimize G - reconstructs well, classifier wrong
        z = self.encoder(x_aug)
        y = self.decoders[dset_num](x, z)
        z_logits = self.classifier(z)

        # Negated so that minimizing it maximizes the classifier's error.
        discriminator_wrong = -F.cross_entropy(
            z_logits, self._domain_targets(dset_num, x.size(0))).mean()

        # Dump diagnostics when the classifier loss leaves (-100, 100) or is NaN.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() +
                self.config.domain_classifier.d_lambda * discriminator_wrong)

        self.model_optimizers[dset_num].zero_grad()
        loss.backward()
        if self.config.domain_classifier.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(),
                             self.config.domain_classifier.grad_clip)
            clip_grad_value_(self.decoders[dset_num].parameters(),
                             self.config.domain_classifier.grad_clip)

        self.model_optimizers[dset_num].step()

        self.loss_total[dset_num].add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Train for ``config.data.epoch_len`` batches, cycling over datasets."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        for meter in self.loss_total:
            meter.reset()

        self.encoder.train()
        self.classifier.train()
        for decoder in self.decoders:
            decoder.train()

        n_batches = self.config.data.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # "short" mode: stop after three batches.
                if self.config.data.short and batch_num == 3:
                    break

                # Round-robin over the datasets.
                dset_num = batch_num % self.config.data.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap_cuda(x)
                x_aug = wrap_cuda(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Evaluate on ``ceil(epoch_len / 10)`` validation batches."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        for meter in self.eval_total:
            meter.reset()

        self.encoder.eval()
        self.classifier.eval()
        for decoder in self.decoders:
            decoder.eval()

        n_batches = int(np.ceil(self.config.data.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                # "short" mode: stop after ten batches.
                if self.config.data.short and batch_num == 10:
                    break

                dset_num = batch_num % self.config.data.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap_cuda(x)
                x_aug = wrap_cuda(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Return the meters' epoch summaries as '0.0000, 0.0000, ...'."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Format the training reconstruction meters followed by the d meter."""
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        """Format the eval reconstruction meters followed by the d meter."""
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        """Run the train/evaluate loop and save checkpoints every epoch.

        Per dataset, the best model (lowest eval total loss so far) and the
        latest model are saved; ``args.pth`` with ``[config, epoch]`` is
        written once per epoch.
        """
        best_eval = [float('inf') for _ in range(self.config.data.n_datasets)]

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.config.env.epochs):
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info('Epoch %s - Train loss: (%s), Test loss (%s)',
                             epoch, self.train_losses(), self.eval_losses())
            for lr_manager in self.lr_managers:
                lr_manager.step()

            for dataset_id in range(self.config.data.n_datasets):
                val_loss = self.eval_total[dataset_id].summarize_epoch()

                if val_loss < best_eval[dataset_id]:
                    self.save_model(f'bestmodel_{dataset_id}.pth', dataset_id)
                    best_eval[dataset_id] = val_loss

                if not self.config.env.save_per_epoch:
                    self.save_model(f'lastmodel_{dataset_id}.pth', dataset_id)
                else:
                    self.save_model(f'lastmodel_{epoch}_rank_{dataset_id}.pth',
                                    dataset_id)

            # Epoch-level bookkeeping: written once, not once per dataset.
            torch.save([self.config, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename, decoder_id):
        """Save encoder, one decoder, classifier and optimizer states.

        Args:
            filename: File name joined onto ``self.expPath`` with ``/``.
            decoder_id: Index of the decoder/optimizer to store.
        """
        save_path = self.expPath / filename

        # `.module` unwraps the DataParallel wrappers added in __init__.
        torch.save(
            {
                'encoder_state':
                self.encoder.module.state_dict(),
                'decoder_state':
                self.decoders[decoder_id].module.state_dict(),
                'discriminator_state':
                self.classifier.module.state_dict(),
                'model_optimizer_state':
                self.model_optimizers[decoder_id].state_dict(),
                'dataset':
                decoder_id,
                'd_optimizer_state':
                self.classifier_optimizer.state_dict()
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
コード例 #12
0
ファイル: demo.py プロジェクト: chicleee/Pix2Vox
def _demo_reconstruct(cfg, encoder, decoder, merger, test_transforms, image_path):
    """Run the networks on one image file and return the fused volume tensor."""
    rendering_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if rendering_image is None:
        # cv2.imread signals an unreadable file by returning None.
        raise IOError("cannot read image: %s" % image_path)
    rendering_image = rendering_image.astype(np.float32) / 255.
    # Add a leading axis so the transforms receive a one-image stack.
    rendering_image = np.asarray(rendering_image)[np.newaxis, :, :, :]

    rendering_image = test_transforms(rendering_image)
    rendering_image = paddle.reshape(rendering_image, [1, 1, 3, 224, 224])
    with paddle.no_grad():
        rendering_image = utils.network_utils.var_or_cuda(rendering_image)

        # Test the encoder, decoder and merger
        image_features = encoder(rendering_image)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = paddle.mean(generated_volume, axis=1)
    return generated_volume


def _demo_export_obj(cfg, generated_volume, pred_file_name):
    """Binarize the volume at each demo threshold and write it as an .obj file."""
    with paddle.no_grad():
        # NOTE: every threshold writes to the same file name, so the file on
        # disk ends up holding the result for the last threshold.
        for th in cfg.TEST.DEMO_VOXEL_THRESH:
            _volume = paddle.greater_equal(
                generated_volume, paddle.to_tensor(th)).astype("float32")
            _volume = paddle.reshape(_volume, [32, 32, 32])
            print("save ", pred_file_name)
            utils.voxel.voxel2obj(pred_file_name, _volume.cpu().numpy())


def demo_net(cfg, imgs_path):
    """Reconstruct voxel volumes from one image or from every PNG in a folder.

    Weights are loaded from ``cfg.CONST.WEIGHTS``. When ``cfg.DIR.OUT_PATH``
    is set, each reconstruction is thresholded with
    ``cfg.TEST.DEMO_VOXEL_THRESH`` and written as an ``.obj`` file.

    Args:
        cfg: Configuration object.
        imgs_path: Path to a single image file, or to a directory whose
            entries containing '.png' are processed one by one.

    Raises:
        Exception: if ``imgs_path`` is neither a file nor a directory.
        IOError: if an image cannot be read.
    """
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' %
          (dt.now(), cfg.CONST.WEIGHTS))
    encoder_state_dict = paddle.load(
        os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
    encoder.set_state_dict(encoder_state_dict)
    decoder_state_dict = paddle.load(
        os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
    decoder.set_state_dict(decoder_state_dict)

    if cfg.NETWORK.USE_MERGER:
        merger_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
        merger.set_state_dict(merger_state_dict)

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    merger.eval()

    # The transform pipeline does not depend on the image, so build it once.
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    if os.path.isfile(imgs_path):
        print("demo img")
        generated_volume = _demo_reconstruct(cfg, encoder, decoder, merger,
                                             test_transforms, imgs_path)
        if cfg.DIR.OUT_PATH:
            pred_file_name = os.path.join(
                cfg.DIR.OUT_PATH,
                imgs_path.split('/')[-1].split('.')[0] + '.obj')
            _demo_export_obj(cfg, generated_volume, pred_file_name)

    elif os.path.isdir(imgs_path):
        print("demo dir")
        for rendering_file_path in os.listdir(imgs_path):
            if '.png' not in rendering_file_path:
                continue
            image_path = os.path.join(imgs_path, rendering_file_path)
            print(image_path)
            generated_volume = _demo_reconstruct(cfg, encoder, decoder, merger,
                                                 test_transforms, image_path)
            if cfg.DIR.OUT_PATH:
                # Exported the same way as the single-file branch: the full
                # 32x32x32 volume, binarized per threshold (previously a 2-D
                # slice was compared against the whole threshold setting).
                pred_file_name = os.path.join(
                    cfg.DIR.OUT_PATH, imgs_path,
                    rendering_file_path.split('.')[0] + '.obj')
                _demo_export_obj(cfg, generated_volume, pred_file_name)
    else:
        raise Exception("error input path")
コード例 #13
0
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None,
             test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    """Run inference over a data loader and render every generated volume.

    Args:
        cfg: Configuration object.
        epoch_idx: Overwritten with the checkpoint's epoch when the networks
            are built here; otherwise unused.
        output_dir: Unused; kept for signature compatibility.
        test_data_loader: Loader yielding ``(taxonomy_id, sample_name,
            rendering_images)``. Built from ``cfg`` when ``None``.
        test_writer: Unused; kept for signature compatibility.
        encoder, decoder, refiner, merger: Networks to run. When ``encoder``
            or ``decoder`` is ``None`` all four are built and restored from
            the checkpoint at ``cfg.CONST.WEIGHTS``.

    Returns:
        Tuple ``(gv, rendering_images)`` for the last sample: the generated
        volume as a numpy array and the input image tensor. Both are ``None``
        when the loader yields no samples.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset.
    # NOTE: the result is not used below; the load is kept so behaviour on a
    # missing or invalid taxonomy file is unchanged.
    with open(
            cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH,
            encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetType.TEST,
                cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))

        if torch.cuda.is_available():
            checkpoint = torch.load(cfg.CONST.WEIGHTS)
        else:
            # Remap tensors stored on GPU so the checkpoint loads on CPU.
            map_location = torch.device('cpu')
            checkpoint = torch.load(cfg.CONST.WEIGHTS,
                                    map_location=map_location)

        epoch_idx = checkpoint['epoch_idx']
        print('Epoch ID of the current model is {}'.format(epoch_idx))
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Switch models to evaluation mode. `refiner` and `merger` default to
    # None and are only needed when the matching cfg.NETWORK switch is on.
    encoder.eval()
    decoder.eval()
    if refiner is not None:
        refiner.eval()
    if merger is not None:
        merger.eval()

    # Defined up front so an empty loader returns (None, None) instead of
    # failing with UnboundLocalError after the loop.
    gv = None
    rendering_images = None

    print("test data loader type is {}".format(type(test_data_loader)))
    for sample_idx, (taxonomy_id, sample_name,
                     rendering_images) in enumerate(test_data_loader):
        # batch_size is 1: unwrap the single id, which may be a str or a tensor.
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        print("sample IDx {}".format(sample_idx))
        print("taxonomy id {}".format(taxonomy_id))
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)

            print("Shape of the loaded images {}".format(
                rendering_images.shape))

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)

            if cfg.NETWORK.USE_REFINER:
                generated_volume = refiner(generated_volume)

            print("vox shape {}".format(generated_volume.shape))

            gv = generated_volume.cpu().numpy()

            # Renders views of the volume; the return value is not used here.
            utils.binvox_visualization.get_volume_views(
                gv,
                os.path.join('./LargeDatasets/inference_images/', 'inference'),
                sample_idx)

    if gv is not None:
        print("gv shape is {}".format(gv.shape))
    return gv, rendering_images
コード例 #14
0
def test_net(cfg,
             model_type,
             dataset_type,
             results_file_name,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None,
             save_results_to_file=False,
             show_voxels=False,
             path_to_times_csv=None):
    """Evaluate the reconstruction networks on a test set.

    Runs encoder -> decoder -> (merger | mean over dim 1) -> (refiner) on
    every sample of the loader, accumulates BCE losses and per-threshold IoU,
    prints a summary table and optionally writes CSV / TensorBoard output.

    Args:
        cfg: Configuration object; the ``CONST``, ``DATASET``, ``TEST``,
            ``TRAIN`` and ``NETWORK`` sections are read.
        model_type: A ``Pix2VoxTypes`` value. The refiner is used only for
            ``Pix2Vox_A`` and ``Pix2Vox_Plus_Plus_A``.
        dataset_type: Forwarded to ``get_dataset`` when the loader is built
            here.
        results_file_name: CSV path handed to ``save_test_results_to_csv``
            when ``save_results_to_file`` is true.
        epoch_idx: Epoch used for the merger/refiner start-epoch checks and as
            the TensorBoard step. Overwritten with the checkpoint's
            ``epoch_idx`` when the networks are built here.
        test_data_loader: Loader yielding ``(taxonomy_id, sample_name,
            rendering_images, ground_truth_volume)``. Built with
            ``batch_size=1`` from ``cfg`` when ``None``.
        test_writer: Optional TensorBoard-style writer exposing
            ``add_scalar``.
        encoder: Pre-built encoder. If either ``encoder`` or ``decoder`` is
            ``None``, all networks are rebuilt and weights are loaded from
            ``cfg.CONST.WEIGHTS``.
        decoder: Pre-built decoder (see ``encoder``).
        refiner: Pre-built refiner; required when the model type uses one and
            the networks are supplied by the caller.
        merger: Pre-built merger; may be ``None`` when
            ``cfg.NETWORK.USE_MERGER`` is false.
        save_results_to_file: Write per-sample losses and IoUs to
            ``results_file_name``.
        show_voxels: For every sample, write ``model.binvox`` in the working
            directory and open it with ``VIEWVOX_EXE``, first for the
            prediction (thresholded at 0.2) and then for the ground truth.
        path_to_times_csv: If not ``None``, the wall-clock time of each
            forward pass and ``rendering_images.size()[1]`` (presumably the
            number of input views -- confirm against the data loader) are
            recorded and saved to this path.

    Returns:
        The maximum, over ``cfg.TEST.VOXEL_THRESH``, of the mean IoU across
        all samples.
    """
    use_refiner = model_type in (Pix2VoxTypes.Pix2Vox_A,
                                 Pix2VoxTypes.Pix2Vox_Plus_Plus_A)

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(dataset_type,
                                               cfg.CONST.N_VIEWS_RENDERING,
                                               test_transforms),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg, model_type)
        decoder = Decoder(cfg, model_type)
        if use_refiner:
            refiner = Refiner(cfg)
        merger = Merger(cfg, model_type)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            if use_refiner:
                refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        logging.info('Loading weights from %s ...', cfg.CONST.WEIGHTS)
        # Without map_location a checkpoint holding CUDA tensors cannot be
        # deserialised on a CPU-only host.
        # NOTE(review): on CPU the models are not wrapped in DataParallel, so
        # a checkpoint saved from wrapped models may still fail on key names
        # -- confirm how the checkpoints are produced.
        map_location = (None if torch.cuda.is_available() else
                        torch.device('cpu'))
        checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=map_location)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    if use_refiner:
        refiner_losses = AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    if use_refiner:
        refiner.eval()
    # The merger is optional: callers that pass their own encoder/decoder may
    # leave it as None when cfg.NETWORK.USE_MERGER is disabled.
    if merger is not None:
        merger.eval()

    samples_names = []
    edlosses = []
    rlosses = []
    ious_dict = {iou_threshold: [] for iou_threshold in cfg.TEST.VOXEL_THRESH}

    record_times = path_to_times_csv is not None
    # CUDA kernels are launched asynchronously; the device must be
    # synchronised around the timed region or only launch overhead is seen.
    sync_cuda = record_times and torch.cuda.is_available()
    if record_times:
        n_view_list = []
        times_list = []

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        # Only the first element of the batch is used for bookkeeping.
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volume = utils.helpers.var_or_cuda(
                ground_truth_volume)

            if record_times:
                if sync_cuda:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume,
                                        ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            if record_times:
                if sync_cuda:
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                n_view_list.append(rendering_images.size()[1])
                times_list.append(end_time - start_time)

            # Pull each loss to the host once; every .item() forces a sync.
            encoder_loss_value = encoder_loss.item()
            refiner_loss_value = refiner_loss.item()

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss_value)
            if use_refiner:
                refiner_losses.update(refiner_loss_value)

            # IoU per sample
            # NOTE: the ratio is nan when prediction and ground truth are both
            # empty at a threshold (0 / 0).
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(
                    _volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume),
                                           1)).float()
                iou = (intersection / union).item()
                sample_iou.append(iou)
                ious_dict[th].append(iou)

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Show the prediction, then the ground truth, in the external
            # voxel viewer. Each subprocess.run blocks until the viewer exits.
            if show_voxels:
                with open("model.binvox", "wb") as f:
                    v = br.Voxels(
                        torch.ge(generated_volume,
                                 0.2).float().cpu().numpy()[0], (32, 32, 32),
                        (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

                with open("model.binvox", "wb") as f:
                    v = br.Voxels(ground_truth_volume.cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

            # Print sample loss and IoU
            logging.info(
                'Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s',
                sample_idx + 1, n_samples, taxonomy_id, sample_name,
                encoder_loss_value, refiner_loss_value,
                ['%.4f' % si for si in sample_iou])

            samples_names.append(sample_name)
            edlosses.append(encoder_loss_value)
            if use_refiner:
                rlosses.append(refiner_loss_value)

    if save_results_to_file:
        save_test_results_to_csv(samples_names,
                                 edlosses,
                                 rlosses,
                                 ious_dict,
                                 path_to_csv=results_file_name)

    if record_times:
        save_times_to_csv(times_list,
                          n_view_list,
                          path_to_csv=path_to_times_csv)

    # Output testing results: average the per-taxonomy means weighted by the
    # number of samples in each taxonomy.
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'],
                                               axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] *
                        test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                               epoch_idx)
        # NOTE(review): the IoU scalar is only written for refiner models;
        # confirm whether non-refiner runs should log it as well.
        if use_refiner:
            test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                   epoch_idx)
            test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou