Code example #1
def evaluate_hand_draw_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    eval_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    azi_classes = int(360 / cfg.CONST.BIN_SIZE)
    ele_classes = int(180 / cfg.CONST.BIN_SIZE)
    view_estimater = ViewEstimater(cfg,
                                   azi_classes=azi_classes,
                                   ele_classes=ele_classes)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        view_estimater = torch.nn.DataParallel(view_estimater).cuda()

    # Load weight
    # Load weight for encoder, decoder
    print('[INFO] %s Loading reconstruction weights from %s ...' %
          (dt.now(), cfg.EVALUATE_HAND_DRAW.RECONSTRUCTION_WEIGHTS))
    rec_checkpoint = torch.load(cfg.EVALUATE_HAND_DRAW.RECONSTRUCTION_WEIGHTS)
    encoder.load_state_dict(rec_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(rec_checkpoint['decoder_state_dict'])
    print('[INFO] Best reconstruction result at epoch %d ...' %
          rec_checkpoint['epoch_idx'])

    # Load weight for view estimater
    print('[INFO] %s Loading view estimation weights from %s ...' %
          (dt.now(), cfg.EVALUATE_HAND_DRAW.VIEW_ESTIMATION_WEIGHTS))
    view_checkpoint = torch.load(
        cfg.EVALUATE_HAND_DRAW.VIEW_ESTIMATION_WEIGHTS)
    view_estimater.load_state_dict(
        view_checkpoint['view_estimator_state_dict'])
    print('[INFO] Best view estimation result at epoch %d ...' %
          view_checkpoint['epoch_idx'])

    for img_path in os.listdir(cfg.EVALUATE_HAND_DRAW.INPUT_IMAGE_FOLDER):
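        # Assumes numeric file names such as '12.png'; strip the 4-character extension to get an id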
        eval_id = int(img_path[:-4])
        input_img_path = os.path.join(
            cfg.EVALUATE_HAND_DRAW.INPUT_IMAGE_FOLDER, img_path)
        print(input_img_path)
        evaluate_hand_draw_img(cfg, encoder, decoder, view_estimater,
                               input_img_path, eval_transforms, eval_id)
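For reference, the loader above expects a plain dict checkpoint containing 'epoch_idx', 'encoder_state_dict' and 'decoder_state_dict' (plus 'view_estimator_state_dict' for the view estimator). A minimal sketch of writing a compatible file; the function name is hypothetical, since the training code is not part of this listing:

import torch

def save_reconstruction_checkpoint(path, epoch_idx, encoder, decoder):
    # Keys must match those read back in evaluate_hand_draw_net() above
    torch.save({
        'epoch_idx': epoch_idx,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
    }, path)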
Code example #2
def run(ckpt_fpath):
    checkpoint = torch.load(ckpt_fpath)
    """ Load Config """
    config = dict_to_cls(checkpoint['config'])
    """ Build Data Loader """
    if config.corpus == "MSVD":
        corpus = MSVD(config)
    elif config.corpus == "MSR-VTT":
        corpus = MSRVTT(config)
    train_iter, val_iter, test_iter, vocab = \
        corpus.train_data_loader, corpus.val_data_loader, corpus.test_data_loader, corpus.vocab
    print(
        '#vocabs: {} ({}), #words: {} ({}). Trim words which appear less than {} times.'
        .format(vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words,
                vocab.n_words_untrimmed, config.loader.min_count))
    """ Build Models """
    decoder = Decoder(rnn_type=config.decoder.rnn_type,
                      num_layers=config.decoder.rnn_num_layers,
                      num_directions=config.decoder.rnn_num_directions,
                      feat_size=config.feat.size,
                      feat_len=config.loader.frame_sample_len,
                      embedding_size=config.vocab.embedding_size,
                      hidden_size=config.decoder.rnn_hidden_size,
                      attn_size=config.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=config.decoder.rnn_dropout)
    decoder.load_state_dict(checkpoint['decoder'])
    model = CaptionGenerator(decoder, config.loader.max_caption_len, vocab)
    model = model.cuda()
    """ Train Set """
    """
    train_vid2pred = get_predicted_captions(train_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    train_vid2GTs = get_groundtruth_captions(train_iter, model.vocab)
    train_scores = score(train_vid2pred, train_vid2GTs)
    print("[TRAIN] {}".format(train_scores))
    """
    """ Validation Set """
    """
    val_vid2pred = get_predicted_captions(val_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    val_vid2GTs = get_groundtruth_captions(val_iter, model.vocab)
    val_scores = score(val_vid2pred, val_vid2GTs)
    print("[VAL] scores: {}".format(val_scores))
    """
    """ Test Set """
    test_vid2pred = get_predicted_captions(test_iter,
                                           model,
                                           model.vocab,
                                           beam_width=5,
                                           beam_alpha=0.)
    test_vid2GTs = get_groundtruth_captions(test_iter, model.vocab)
    test_scores = score(test_vid2pred, test_vid2GTs)
    print("[TEST] {}".format(test_scores))

    test_save_fpath = os.path.join(C.result_dpath,
                                   "{}_{}.csv".format(config.corpus, 'test'))
    save_result(test_vid2pred, test_vid2GTs, test_save_fpath)
Code example #3
class Visualization_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(
                checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(
                checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.output_dir = output_dir

    def run_on_images(self, imgs, sid, mid, iid, sampled_idx):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        deprocess = imagenet_deprocess(rescale_image=False)
        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)

        mesh = cubify(generated_volume, 0.3)
        #         mesh = voxel_to_world(meshes)
        save_mesh = os.path.join(dir1, "%s_%s.obj" % (iid, sampled_idx))
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save_mesh, verts, faces)

        generated_volume = generated_volume.squeeze()
        img = image_to_numpy(deprocess(imgs[0][0]))
        save_img = os.path.join(dir1, "%02d.png" % (iid))
        #         cv2.imwrite(save_img, img[:, :, ::-1])
        cv2.imwrite(save_img, img)
        img1 = image_to_numpy(deprocess(imgs[0][1]))
        save_img1 = os.path.join(dir1, "%02d.png" % (sampled_idx))
        cv2.imwrite(save_img1, img1)
        #         cv2.imwrite(save_img1, img1[:, :, ::-1])
        get_volume_views(generated_volume, dir1, iid, sampled_idx)
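clean_state_dict() is not part of this listing; a plausible implementation, assuming it strips the 'module.' prefix that torch.nn.DataParallel prepends to parameter names (examples #9 and #11 below do the same inline with OrderedDict):

from collections import OrderedDict

def clean_state_dict(state_dict):
    # Remove the 'module.' prefix added by torch.nn.DataParallel, if present
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[len('module.'):] if k.startswith('module.') else k] = v
    return cleaned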
Code example #4
def build_model(C, vocab):

    decoder = Decoder(rnn_type=C.decoder.rnn_type,
                      num_layers=C.decoder.rnn_num_layers,
                      num_directions=C.decoder.rnn_num_directions,
                      feat_size=C.feat.size,
                      feat_len=C.loader.frame_sample_len,
                      embedding_size=C.vocab.embedding_size,
                      hidden_size=C.decoder.rnn_hidden_size,
                      attn_size=C.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=C.decoder.rnn_dropout)
    if C.pretrained_decoder_fpath is not None:
        decoder.load_state_dict(
            torch.load(C.pretrained_decoder_fpath)['decoder'])
        print("Pretrained decoder is loaded from {}".format(
            C.pretrained_decoder_fpath))
    # Global and local reconstructors
    if C.reconstructor is None:
        reconstructor = None
    elif C.reconstructor.type == 'global':
        reconstructor = GlobalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            rnn_dropout=C.reconstructor.rnn_dropout)
    else:
        reconstructor = LocalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            attn_size=C.reconstructor.rnn_attn_size,
            rnn_dropout=C.reconstructor.rnn_dropout)
    if C.pretrained_reconstructor_fpath is not None:
        reconstructor.load_state_dict(
            torch.load(C.pretrained_reconstructor_fpath)['reconstructor'])
        print("Pretrained reconstructor is loaded from {}".format(
            C.pretrained_reconstructor_fpath))

    model = CaptionGenerator(decoder, reconstructor, C.loader.max_caption_len,
                             vocab)
    model.cuda()
    return model
class Quantitative_analysis_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)
        #         self.thresh = cfg.VOXEL_THRESH
        self.th = cfg.TEST.VOXEL_THRESH

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(
                checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(
                checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        self.output_dir = output_dir

    def calculate_iou(self, imgs, GT_voxels, sid, mid, iid):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)
        generated_volume = generated_volume.squeeze()

        sample_iou = []
        for th in self.th:
            _volume = torch.ge(generated_volume, th).float()
            intersection = torch.sum(_volume.mul(GT_voxels)).float()
            union = torch.sum(torch.ge(_volume.add(GT_voxels), 1)).float()
            sample_iou.append((intersection / union).item())
        return sample_iou
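The thresholded voxel IoU in calculate_iou() can be sanity-checked on a toy volume; a minimal self-contained sketch:

import torch

pred = torch.tensor([[0.9, 0.1], [0.6, 0.4]])    # predicted occupancy probabilities
gt = torch.tensor([[1.0, 0.0], [1.0, 1.0]])      # ground-truth occupancy

_volume = torch.ge(pred, 0.5).float()            # binarize at threshold 0.5 -> [[1, 0], [1, 0]]
intersection = torch.sum(_volume.mul(gt))        # 2 cells occupied in both volumes
union = torch.sum(torch.ge(_volume.add(gt), 1))  # 3 cells occupied in either volume
print((intersection / union).item())             # IoU = 2/3 ~= 0.6667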
Code example #6
File: train.py Project: ChangZhou94/SA-LSTM
def build_model(vocab):
    decoder = Decoder(rnn_type=C.decoder.rnn_type,
                      num_layers=C.decoder.rnn_num_layers,
                      num_directions=C.decoder.rnn_num_directions,
                      feat_size=C.feat.size,
                      feat_len=C.loader.frame_sample_len,
                      embedding_size=C.vocab.embedding_size,
                      hidden_size=C.decoder.rnn_hidden_size,
                      attn_size=C.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=C.decoder.rnn_dropout)
    if C.pretrained_decoder_fpath is not None:
        decoder.load_state_dict(
            torch.load(C.pretrained_decoder_fpath)['decoder'])
        print("Pretrained decoder is loaded from {}".format(
            C.pretrained_decoder_fpath))

    model = CaptionGenerator(decoder, C.loader.max_caption_len, vocab)
    model.cuda()
    return model
Code example #7
File: eval.py Project: AcodeC/video
def main():
    checkpoint = torch.load(C.model_fpath)
    TC = MockConfig()
    TC_dict = dict(checkpoint['config'].__dict__)
    for key, val in TC_dict.items():
        setattr(TC, key, val)
    TC.build_train_data_loader = False
    TC.build_val_data_loader = False
    TC.build_test_data_loader = True
    TC.build_score_data_loader = True
    TC.test_video_fpath = C.test_video_fpath
    TC.test_caption_fpath = C.test_caption_fpath

    MSVD = _MSVD(TC)
    vocab = MSVD.vocab
    score_data_loader = MSVD.score_data_loader

    decoder = Decoder(
        model_name=TC.decoder_model,
        n_layers=TC.decoder_n_layers,
        encoder_size=TC.encoder_output_size,
        embedding_size=TC.embedding_size,
        embedding_scale=TC.embedding_scale,
        hidden_size=TC.decoder_hidden_size,
        attn_size=TC.decoder_attn_size,
        output_size=vocab.n_vocabs,
        embedding_dropout=TC.embedding_dropout,
        dropout=TC.decoder_dropout,
        out_dropout=TC.decoder_out_dropout,
    )
    decoder = decoder.to(C.device)

    decoder.load_state_dict(checkpoint['dec'])
    decoder.eval()

    scores = evaluate(TC, MSVD, score_data_loader, decoder, ("beam", 5))
    print(scores)
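MockConfig only needs to be an attribute container that main() can populate via setattr; a minimal stand-in, assuming no validation is required:

class MockConfig:
    # Empty namespace; attributes are attached dynamically with setattr() above
    pass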
Code example #8
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None, \
        test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    for sample_idx, (taxonomy_id, sample_name, rendering_images, ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            print("vox shape {}".format(generated_volume.shape))

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx, rendering_views, epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(gtv, os.path.join(img_dir, 'test'),
                                                                              epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx, rendering_views, epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' % \
                (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(), \
                    refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\t\t')
        else:
            print('N/A', end='\t\t')

        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
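utils.network_utils.AverageMeter is not included in this listing; a typical implementation of such a running-average helper (an assumption; the project's version may differ):

class AverageMeter(object):
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0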
Code example #9
def test_single_img_net(cfg):

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' %
          (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    fix_checkpoint = {}
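    # Strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names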
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v)
        for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])

    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/media/caig/FECA2C89CA2C406F/dataset/ShapeNetRendering_copy/03001627/1a74a83fa6d24b3cacd67ce2c72c02e/rendering/00.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/media/caig/FECA2C89CA2C406F/sketch3D/sketch3D/test_output'
    gv = generated_volume.cpu().numpy()
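    # np.swapaxes exchanges axes 1 and 2 of the voxel grid (presumably to match the viewer's orientation)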
    gv_new = np.swapaxes(gv, 2, 1)
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, os.path.join(img_dir), epoch_idx)
Code example #10
def main():

    #parsing the arguments
    args, _ = parse_arguments()

    #setup logging
    #output_dir = Path('/content/drive/My Drive/image-captioning/output')
    output_dir = Path(args.output_directory)
    output_dir.mkdir(parents=True, exist_ok=True)
    logfile_path = Path(output_dir / "output.log")
    setup_logging(logfile=logfile_path)

    #setup and read config.ini
    #config_file = Path('/content/drive/My Drive/image-captioning/config.ini')
    config_file = Path('../config.ini')
    reading_config(config_file)

    #tensorboard
    tensorboard_logfile = Path(output_dir / 'tensorboard')
    tensorboard_writer = SummaryWriter(tensorboard_logfile)

    #load dataset
    #dataset_dir = Path('/content/drive/My Drive/Flickr8k_Dataset')
    dataset_dir = Path(args.dataset)
    images_path = Path(dataset_dir / Config.get("images_dir"))
    captions_path = Path(dataset_dir / Config.get("captions_dir"))
    training_loader, validation_loader, testing_loader = data_loaders(
        images_path, captions_path)

    #load the model (encoder, decoder, optimizer)
    embed_size = Config.get("encoder_embed_size")
    hidden_size = Config.get("decoder_hidden_size")
    batch_size = Config.get("training_batch_size")
    epochs = Config.get("epochs")
    feature_extraction = Config.get("feature_extraction")
    raw_captions = read_captions(captions_path)
    id_to_word, word_to_id = dictionary(raw_captions, threshold=5)
    vocab_size = len(id_to_word)
    encoder = Encoder(embed_size, feature_extraction)
    decoder = Decoder(embed_size, hidden_size, vocab_size, batch_size)

    #load pretrained embeddings
    #pretrained_emb_dir = Path('/content/drive/My Drive/word2vec')
    pretrained_emb_dir = Path(args.pretrained_embeddings)
    pretrained_emb_file = Path(pretrained_emb_dir /
                               Config.get("pretrained_emb_path"))
    pretrained_embeddings = load_pretrained_embeddings(pretrained_emb_file,
                                                       id_to_word)

    #load the optimizer
    learning_rate = Config.get("learning_rate")
    optimizer = adam_optimizer(encoder, decoder, learning_rate)

    #loss function
    criterion = cross_entropy

    #load checkpoint
    checkpoint_file = Path(output_dir / Config.get("checkpoint_file"))
    checkpoint_captioning = load_checkpoint(checkpoint_file)

    #using available device(gpu/cpu)
    encoder = encoder.to(Config.get("device"))
    decoder = decoder.to(Config.get("device"))
    pretrained_embeddings = pretrained_embeddings.to(Config.get("device"))

    start_epoch = 1
    if checkpoint_captioning is not None:
        start_epoch = checkpoint_captioning['epoch'] + 1
        encoder.load_state_dict(checkpoint_captioning['encoder'])
        decoder.load_state_dict(checkpoint_captioning['decoder'])
        optimizer.load_state_dict(checkpoint_captioning['optimizer'])
        logger.info(
            'Initialized encoder, decoder and optimizer from loaded checkpoint'
        )

    del checkpoint_captioning

    #image captioning model
    model = ImageCaptioning(encoder, decoder, optimizer, criterion,
                            training_loader, validation_loader, testing_loader,
                            pretrained_embeddings, output_dir,
                            tensorboard_writer)

    #training and testing the model
    if args.training:
        validate_every = Config.get("validate_every")
        model.train(epochs, validate_every, start_epoch)
    elif args.testing:
        model.testing(id_to_word, images_path)
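adam_optimizer() is not shown in this listing; a plausible helper, assuming it optimizes the encoder and decoder parameters jointly:

import itertools
import torch

def adam_optimizer(encoder, decoder, learning_rate):
    # Assumed helper: a single Adam optimizer over both networks' parameters
    params = itertools.chain(encoder.parameters(), decoder.parameters())
    return torch.optim.Adam(params, lr=learning_rate)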
Code example #11
def test_single_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = 'D:/Pix2Vox/Pix2Vox/pretrained/Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

    fix_checkpoint = {}
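    # Strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names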
    fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])

    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])


    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = 'D:/Pix2Vox/Pix2Vox/rand/minecraft.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = 'D:/Pix2Vox/Pix2Vox/output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    print(gv_new)
    rendering_views = utils.binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir),
                                                                                        epoch_idx)
Code example #12
    vocab = Vocabulary("./captions.json",
                       args.NUM_CAPTIONS,
                       num_fre=args.NUM_FRE)
    VOCAB_SIZE = vocab.num_words
    SEQ_LEN = vocab.max_sentence_len

    encoder = Encoder(args.ENCODER_OUTPUT_SIZE)
    decoder = Decoder(embed_size=args.EMBED_SIZE,
                      hidden_size=args.HIDDEN_SIZE,
                      attention_size=args.ATTENTION_SIZE,
                      vocab_size=VOCAB_SIZE,
                      encoder_size=2048,
                      device=device,
                      seq_len=SEQ_LEN + 2)
    encoder.load_state_dict(torch.load(args.ENCODER_MODEL_LOAD_PATH))
    decoder.load_state_dict(torch.load(args.DECODER_MODEL_LOAD_PATH))

    encoder.to(device)
    decoder.to(device)
    encoder.eval()
    decoder.eval()
    result_json = {"images": []}
    for path in IMG_PATH:
        img_name = path.split("/")[-1]
        img = Image.open(path)
        img = transform(img).unsqueeze(0).to(
            device)  # [BATCH_SIZE(1) * CHANNEL * INPUT_SIZE * INPUT_SIZE]

        num_sentence = args.NUM_TOP_PROB
        top_prev_prob = torch.zeros((num_sentence, 1)).to(device)
        words = torch.Tensor([vocab.SOS_token]).long().expand(
Code example #13
def test_net(cfg,
             model_type,
             dataset_type,
             results_file_name,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None,
             save_results_to_file=False,
             show_voxels=False,
             path_to_times_csv=None):
    if model_type == Pix2VoxTypes.Pix2Vox_A or model_type == Pix2VoxTypes.Pix2Vox_Plus_Plus_A:
        use_refiner = True
    else:
        use_refiner = False

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(dataset_type,
                                               cfg.CONST.N_VIEWS_RENDERING,
                                               test_transforms),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg, model_type)
        decoder = Decoder(cfg, model_type)
        if use_refiner:
            refiner = Refiner(cfg)
        merger = Merger(cfg, model_type)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            if use_refiner:
                refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        logging.info('Loading weights from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    if use_refiner:
        refiner_losses = AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    if use_refiner:
        refiner.eval()
    merger.eval()

    samples_names = []
    edlosses = []
    rlosses = []
    ious_dict = {}
    for iou_threshold in cfg.TEST.VOXEL_THRESH:
        ious_dict[iou_threshold] = []

    if path_to_times_csv is not None:
        n_view_list = []
        times_list = []

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volume = utils.helpers.var_or_cuda(
                ground_truth_volume)

            if path_to_times_csv is not None:
                start_time = time.time()

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume,
                                        ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            if path_to_times_csv is not None:
                end_time = time.time()
                n_view_list.append(rendering_images.size()[1])
                times_list.append(end_time - start_time)

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            if use_refiner:
                refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(
                    _volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume),
                                           1)).float()
                sample_iou.append((intersection / union).item())

                ious_dict[th].append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if show_voxels:
                with open("model.binvox", "wb") as f:
                    v = br.Voxels(
                        torch.ge(generated_volume,
                                 0.2).float().cpu().numpy()[0], (32, 32, 32),
                        (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

                with open("model.binvox", "wb") as f:
                    v = br.Voxels(ground_truth_volume.cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)

                subprocess.run([VIEWVOX_EXE, "model.binvox"])

            # Print sample loss and IoU
            logging.info(
                'Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s'
                % (sample_idx + 1, n_samples, taxonomy_id, sample_name,
                   encoder_loss.item(), refiner_loss.item(),
                   ['%.4f' % si for si in sample_iou]))

            samples_names.append(sample_name)
            edlosses.append(encoder_loss.item())
            if use_refiner:
                rlosses.append(refiner_loss.item())

    if save_results_to_file:
        save_test_results_to_csv(samples_names,
                                 edlosses,
                                 rlosses,
                                 ious_dict,
                                 path_to_csv=results_file_name)

    if path_to_times_csv is not None:
        save_times_to_csv(times_list,
                          n_view_list,
                          path_to_csv=path_to_times_csv)

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'],
                                               axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] *
                        test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print(
        '============================ TEST RESULTS ============================'
    )
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()
    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                               epoch_idx)
        if use_refiner:
            test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                   epoch_idx)
            test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
Code example #14
File: train.py Project: Doragd/NIC_model
def train():
    opt = parse_opt()
    train_mode = opt.train_mode
    idx2word = json.load(open(opt.idx2word, 'r'))
    captions = json.load(open(opt.captions, 'r'))

    # Model
    decoder = Decoder(idx2word, opt.settings)
    decoder.to(opt.device)
    lr = opt.learning_rate
    optimizer, xe_criterion = decoder.get_optim_and_crit(lr)
    if opt.resume:
        print("====> loading checkpoint '{}'".format(opt.resume))
        chkpoint = torch.load(opt.resume, map_location=lambda s, l: s)
        assert opt.settings == chkpoint['settings'], \
            'opt.settings and resume model settings are different'
        assert idx2word == chkpoint['idx2word'], \
            'idx2word and resume model idx2word are different'
        decoder.load_state_dict(chkpoint['model'])
        if chkpoint['train_mode'] == train_mode:
            optimizer.load_state_dict(chkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
        print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
            opt.resume, chkpoint['epoch'], chkpoint['train_mode']))
    elif train_mode == 'rl':
        raise Exception('"rl" mode requires resuming from a checkpoint')

    print('====> process image captions begin')
    word2idx = {}
    for i, w in enumerate(idx2word):
        word2idx[w] = i
    captions_id = {}
    for split, caps in captions.items():
        print('convert %s captions to index' % split)
        captions_id[split] = {}
        for fn, seqs in tqdm.tqdm(caps.items(), ncols=100):
            tmp = []
            for seq in seqs:
                tmp.append(
                    [decoder.sos_id] +
                    [word2idx.get(w, word2idx['<UNK>'])
                     for w in seq] + [decoder.eos_id])
            captions_id[split][fn] = tmp
    captions = captions_id
    print('====> process image captions end')

    train_data = get_dataloader(opt.img_feats, captions['train'],
                                decoder.pad_id, opt.max_seq_len,
                                opt.batch_size, opt.num_workers)
    val_data = get_dataloader(opt.img_feats,
                              captions['val'],
                              decoder.pad_id,
                              opt.max_seq_len,
                              opt.batch_size,
                              opt.num_workers,
                              shuffle=False)
    test_captions = {}
    for fn in captions['test']:
        test_captions[fn] = [[]]
    test_data = get_dataloader(opt.img_feats,
                               test_captions,
                               decoder.pad_id,
                               opt.max_seq_len,
                               opt.batch_size,
                               opt.num_workers,
                               shuffle=False)

    if train_mode == 'rl':
        rl_criterion = RewardCriterion()
        ciderd_scorer = get_ciderd_scorer(captions, decoder.sos_id,
                                          decoder.eos_id)

    def forward(data, training=True, ss_prob=0.0):
        decoder.train(training)
        loss_val = 0.0
        reward_val = 0.0
        for fns, fc_feats, (caps_tensor,
                            lengths), ground_truth in tqdm.tqdm(data,
                                                                ncols=100):
            fc_feats = fc_feats.to(opt.device)
            caps_tensor = caps_tensor.to(opt.device)

            if training and train_mode == 'rl':
                sample_captions, sample_logprobs, seq_masks = decoder(
                    fc_feats,
                    sample_max=0,
                    max_seq_len=opt.max_seq_len,
                    mode=train_mode)
                decoder.eval()
                with torch.no_grad():
                    greedy_captions, _, _ = decoder(
                        fc_feats,
                        sample_max=1,
                        max_seq_len=opt.max_seq_len,
                        mode=train_mode)
                decoder.train(training)
                reward = get_self_critical_reward(sample_captions,
                                                  greedy_captions, fns,
                                                  ground_truth, decoder.sos_id,
                                                  decoder.eos_id,
                                                  ciderd_scorer)
                loss = rl_criterion(
                    sample_logprobs, seq_masks,
                    torch.from_numpy(reward).float().to(opt.device))
                reward_val += float(np.mean(reward[:, 0]))
            else:
                pred = decoder(fc_feats, caps_tensor, ss_prob=ss_prob)
                loss = xe_criterion(pred, caps_tensor[:, 1:], lengths)

            loss_val += float(loss)
            if training:
                optimizer.zero_grad()
                loss.backward()
                clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()

        return loss_val / len(data), reward_val / len(data)

    checkpoint_dir = os.path.join(opt.checkpoint, train_mode)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    result_dir = os.path.join(opt.result, train_mode)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    previous_loss = None
    for epoch in range(opt.max_epochs):
        print('--------------------epoch: %d' % epoch)
        ss_prob = 0.0
        if epoch > opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start
                    ) // opt.scheduled_sampling_increase_every
            ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                          opt.scheduled_sampling_max_prob)
        train_loss, train_reward = forward(train_data, ss_prob=ss_prob)
        with torch.no_grad():
            val_loss, _ = forward(val_data, training=False)

        if train_mode == 'xe' and previous_loss is not None and val_loss >= previous_loss:
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = val_loss

        if epoch in [0, 5, 10, 15, 20, 25, 29, 30, 35, 39, 40, 45, 49]:
            # test
            results = []
            for fns, fc_feats, _, _ in tqdm.tqdm(test_data, ncols=100):
                fc_feats = fc_feats.to(opt.device)
                for i, fn in enumerate(fns):
                    fc_feat = fc_feats[i]
                    with torch.no_grad():
                        rest, _ = decoder.sample(fc_feat,
                                                 beam_size=opt.beam_size,
                                                 max_seq_len=opt.max_seq_len)
                    results.append({'image_id': fn, 'caption': rest[0]})
            json.dump(
                results,
                open(os.path.join(result_dir, 'result_%d.json' % epoch), 'w'))

            chkpoint = {
                'epoch': epoch,
                'model': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'settings': opt.settings,
                'idx2word': idx2word,
                'train_mode': train_mode,
            }
            checkpoint_path = os.path.join(
                checkpoint_dir, 'model_%d_%.4f_%s.pth' %
                (epoch, val_loss, time.strftime('%m%d-%H%M')))
            torch.save(chkpoint, checkpoint_path)

        print('train_loss: %.4f, train_reward: %.4f, val_loss: %.4f' %
              (train_loss, train_reward, val_loss))
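The scheduled-sampling probability above ramps up stepwise during training; a small sketch of the schedule with assumed option values (start=0, increase_every=5, increase_prob=0.05, max_prob=0.25):

def ss_prob_at(epoch, start=0, every=5, step=0.05, max_prob=0.25):
    # Mirrors train(): 0 until `start`, then +`step` every `every` epochs, capped at `max_prob`
    if epoch > start >= 0:
        frac = (epoch - start) // every
        return min(step * frac, max_prob)
    return 0.0

# e.g. epochs 1-4 -> 0.0, epochs 5-9 -> 0.05, ..., capped at 0.25 from epoch 25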
Code example #15
File: run.py Project: AlineAlly/RecNet
def run(ckpt_fpath):
    checkpoint = torch.load(ckpt_fpath)
    """ Load Config """
    config = dict_to_cls(checkpoint['config'])
    """ Build Data Loader """
    if config.corpus == "MSVD":
        corpus = MSVD(config)
    elif config.corpus == "MSR-VTT":
        corpus = MSRVTT(config)
    train_iter, val_iter, test_iter, vocab = \
        corpus.train_data_loader, corpus.val_data_loader, corpus.test_data_loader, corpus.vocab
    print(
        '#vocabs: {} ({}), #words: {} ({}). Trim words which appear less than {} times.'
        .format(vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words,
                vocab.n_words_untrimmed, config.loader.min_count))
    """ Build Models """
    decoder = Decoder(rnn_type=config.decoder.rnn_type,
                      num_layers=config.decoder.rnn_num_layers,
                      num_directions=config.decoder.rnn_num_directions,
                      feat_size=config.feat.size,
                      feat_len=config.loader.frame_sample_len,
                      embedding_size=config.vocab.embedding_size,
                      hidden_size=config.decoder.rnn_hidden_size,
                      attn_size=config.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=config.decoder.rnn_dropout)
    decoder.load_state_dict(checkpoint['decoder'])

    if config.reconstructor.type == 'global':
        reconstructor = GlobalReconstructor(
            rnn_type=config.reconstructor.rnn_type,
            num_layers=config.reconstructor.rnn_num_layers,
            num_directions=config.reconstructor.rnn_num_directions,
            decoder_size=config.decoder.rnn_hidden_size,
            hidden_size=config.reconstructor.rnn_hidden_size,
            rnn_dropout=config.reconstructor.rnn_dropout)
    else:
        reconstructor = LocalReconstructor(
            rnn_type=config.reconstructor.rnn_type,
            num_layers=config.reconstructor.rnn_num_layers,
            num_directions=config.reconstructor.rnn_num_directions,
            decoder_size=config.decoder.rnn_hidden_size,
            hidden_size=config.reconstructor.rnn_hidden_size,
            attn_size=config.reconstructor.rnn_attn_size,
            rnn_dropout=config.reconstructor.rnn_dropout)
    reconstructor.load_state_dict(checkpoint['reconstructor'])

    model = CaptionGenerator(decoder, reconstructor,
                             config.loader.max_caption_len, vocab)
    model = model.cuda()
    '''
    """ Train Set """
    train_scores, train_refs, train_hypos = score(model, train_iter, vocab)
    save_result(train_refs, train_hypos, C.result_dpath, config.corpus, 'train')
    print("[TRAIN] {}".format(train_scores))

    """ Validation Set """
    val_scores, val_refs, val_hypos = score(model, val_iter, vocab)
    save_result(val_refs, val_hypos, C.result_dpath, config.corpus, 'val')
    print("[VAL] scores: {}".format(val_scores))
    '''
    """ Test Set """
    test_scores, test_refs, test_hypos = score(model, test_iter, vocab)
    save_result(test_refs, test_hypos, C.result_dpath, config.corpus, 'test')
    print("[TEST] {}".format(test_scores))
Code example #16
File: server.py Project: peternara/speech2face-1
encoder.load_weights()
encoder.eval()

for p in encoder.parameters():
    p.requires_grad = False
"""

Load Facial Decoder 

"""

batchSize = 110
net = Decoder(batchSize)
checkpoint = torch.load("./weights/decoder-iter-4449.pt",
                        map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['net_state_dict'])
net.eval()
"""

Load Voice Encoder 

"""

x = Speaker()

net2 = VoiceEncoder(1)
checkpoint = torch.load("./weights/voice-encoder-epoch-16.pt",
                        map_location=torch.device('cpu'))
net2.load_state_dict(checkpoint['net_state_dict'])
net2.eval()
"""
Code example #17
def test_net(cfg, epoch_idx=-1, output_dir=None, test_data_loader=None, \
        test_writer=None, encoder=None, decoder=None, refiner=None, merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(
            cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH,
            encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(
                cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                            std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
            cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(
                utils.data_loaders.DatasetType.TEST,
                cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))

        if torch.cuda.is_available():
            checkpoint = torch.load(cfg.CONST.WEIGHTS)
        else:
            map_location = torch.device('cpu')
            checkpoint = torch.load(cfg.CONST.WEIGHTS,
                                    map_location=map_location)

        epoch_idx = checkpoint['epoch_idx']
        print('Epoch ID of the current model is {}'.format(epoch_idx))
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    print("test data loader type is {}".format(type(test_data_loader)))
    for sample_idx, (taxonomy_id, sample_name,
                     rendering_images) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(
            taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        print("sample IDx {}".format(sample_idx))
        print("taxonomy id {}".format(taxonomy_id))
        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)

            print("Shape of the loaded images {}".format(
                rendering_images.shape))

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)

            if cfg.NETWORK.USE_REFINER:
                generated_volume = refiner(generated_volume)

            print("vox shape {}".format(generated_volume.shape))

            gv = generated_volume.cpu().numpy()

            rendering_views = utils.binvox_visualization.get_volume_views(
                gv,
                os.path.join('./LargeDatasets/inference_images/', 'inference'),
                sample_idx)
    print("gv shape is {}".format(gv.shape))
    return gv, rendering_images
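This variant of test_net declares `test_iou` but never fills it, since no ground-truth volumes are loaded. For reference, a sketch of how Pix2Vox-style evaluation typically scores a generated volume against a ground-truth voxel grid at several binarization thresholds; `ground_truth_volume` and the threshold values are assumptions here:

import torch

def volume_iou(generated_volume, ground_truth_volume,
               thresholds=(0.2, 0.3, 0.4, 0.5)):
    # Binarize the predicted occupancy grid at each threshold and compute IoU.
    ious = []
    for th in thresholds:
        binary = torch.ge(generated_volume, th).float()
        intersection = torch.sum(binary * ground_truth_volume)
        union = torch.sum(torch.ge(binary + ground_truth_volume, 1).float())
        ious.append((intersection / union).item())
    return ious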
Code example #18
0
def main(config):
    print('Starting')

    checkpoints = config.checkpoint.parent.glob(config.checkpoint.name +
                                                '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in config.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_config = torch.load(config.checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_config.encoder)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    generators = []
    generator_ids = []
    for checkpoint in checkpoints:
        decoder = Decoder(model_config.decoder)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()

        generator = SampleGenerator(decoder,
                                    config.batch_size,
                                    wav_freq=config.rate)

        generators.append(generator)
        generator_ids.append(extract_id(checkpoint))

    xs = []
    assert config.out_dir is not None

    if len(config.sample_dir) == 1 and config.sample_dir[0].is_dir():
        top = config.sample_dir[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = config.sample_dir

    print("File paths to be used:", file_paths)
    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=config.rate)
            data = helper_functions.mu_law(data)
        elif file_path.suffix == '.h5':
            data = helper_functions.mu_law(
                h5py.File(file_path, 'r')['wav'][:] / (2**15))
            if data.shape[-1] % config.rate != 0:
                data = data[:-(data.shape[-1] % config.rate)]
            assert data.shape[-1] % config.rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if config.sample_len:
            data = data[:config.sample_len]
        else:
            config.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_idx, filepath):
        wav = helper_functions.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        save_audio(wav.squeeze(),
                   config.out_dir / str(decoder_idx) /
                   filepath.with_suffix('.wav').name,
                   rate=config.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, config.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        for i, generator_id in enumerate(generator_ids):
            yy[generator_id] = []
            generator = generators[i]
            for zz_batch in torch.split(zz, config.batch_size):
                print("Batch shape:", zz_batch.shape)
                splits = torch.split(zz_batch, config.split_size, -1)
                audio_data = []
                generator.reset()
                for cond in tqdm.tqdm(splits):
                    audio_data += [generator.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[generator_id] += [audio_data]
            yy[generator_id] = torch.cat(yy[generator_id], dim=0)

            for sample_result, filepath in zip(yy[generator_id], file_paths):
                save(sample_result, generator_id, filepath)

            del generator
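helper_functions.mu_law and inv_mu_law are not shown in this example. A standard mu-law companding pair consistent with how the snippet uses them (a float waveform in [-1, 1] in, integer bins out, and back) might look like this sketch:

import numpy as np

def mu_law(x, mu=255):
    # Compand a [-1, 1] float signal and quantize it to integer bins 0..mu.
    x = np.clip(x, -1.0, 1.0)
    fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)

def inv_mu_law(y, mu=255):
    # Map integer bins back to an approximate [-1, 1] float signal.
    fx = 2.0 * np.asarray(y, dtype=np.float64) / mu - 1.0
    return np.sign(fx) * np.expm1(np.abs(fx) * np.log1p(mu)) / mu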
Code example #19
0
def main(_run, _config, _log):
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'],
                               device)  #.to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'],
                             device)  #.to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Switch the loaded networks to evaluation mode for testing
    s_encoder.eval()
    q_encoder.eval()
    decoder.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1

    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False)
    # all_samples = test_loader_Spleen()
    # all_samples = test_loader_Prostate()
    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []

    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []

    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing

        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []

            fnames = sample_test['q_fname']
            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            s_xi = s_x[:, :, 0, :, :, :]  # [B, Support, 1, 256, 256]
            s_yi = s_y[:, :, 0, :, :, :]

            for s_idx in range(_config["n_shot"]):
                s_x_merge = s_xi.view(s_xi.size(0) * s_xi.size(1), 1, 256, 256)
                s_y_merge = s_yi.view(s_yi.size(0) * s_yi.size(1), 1, 256, 256)
                s_xi_encode_merge, _ = s_encoder(s_x_merge,
                                                 s_y_merge)  # [B*S, 512, w, h]

            s_xi_encode = s_xi_encode_merge.view(s_yi.size(0), s_yi.size(1),
                                                 512, encoded_w, encoded_h)
            s_xi_encode_avg = torch.mean(s_xi_encode, dim=1)
            # s_xi_encode, _ = s_encoder(s_xi, s_yi)  # [B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode_avg, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]

            preds.append(yhati.round())
            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())

            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

    print("start computing dice similarities ... total ", len(saves))
    dice_similarities = []
    for subj_idx in range(len(saves)):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            # print(len(save_subj), len(save_subj)-q_slice_n+1, q_slice_n, i)
            subj_idx, idx, img_list, pred_list, label_list, fnames = save_subj[
                i]
            # print(subj_idx, idx, is_reverse, len(img_list))
            # print(i, is_reverse, is_reverse_next, is_flip)

            for j in range(len(img_list)):
                imgs.append(img_list[j])
                preds.append(pred_list[j])
                labels.append(label_list[j])

        # pdb.set_trace()
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # print(ts_dataset.slice_cnts[subj_idx] , len(imgs))
        # pdb.set_trace()
        dice = np.sum([label_arr * pred_arr
                       ]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        dice_similarities.append(dice)
        print(f"computing dice scores {subj_idx}/{10}", end='\n')

        if _config['record']:
            frames = []
            for frame_id in range(0, len(save_subj)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]),
                                        torch.tensor(labels[frame_id]),
                                        scale=_config['scale'])
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            writer.add_scalar(f'dice_score/{i}', dice)

        if _config['save_sample']:
            ## only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            dirs = ["gt", "pred", "input"]
            save_dir = f"/user/home2/soopil/tmp/PANet/MICCAI2015/sample/fss1000_organ{target}_sup{sup_idx}_{save_name}"

            for d in dirs:
                os.makedirs(os.path.join(save_dir, d), exist_ok=True)

            subj_name = fnames[0][0].split("/")[-2]
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"

            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()

            label_arr = label_arr * 2.0
            # label_arr = np.concatenate([np.zeros([1,256,256]), label_arr,np.zeros([1,256,256])])
            # pred_arr = np.concatenate([np.zeros([1,256,256]), pred_arr,np.zeros([1,256,256])])
            # img_arr = np.concatenate([np.zeros([1,256,256]), img_arr,np.zeros([1,256,256])])
            # pdb.set_trace()
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(img_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
    {np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")

    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
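The per-subject score computed above is the standard Dice coefficient, 2|A∩B| / (|A| + |B|), applied to the concatenated prediction and label volumes. Factored into a helper (a sketch; the small epsilon guarding empty masks is an addition the original does not have):

import numpy as np

def dice_score(pred, label, eps=1e-8):
    # Dice similarity between two binary masks of the same shape.
    pred = np.asarray(pred).astype(bool)
    label = np.asarray(label).astype(bool)
    intersection = np.logical_and(pred, label).sum()
    return 2.0 * intersection / (pred.sum() + label.sum() + eps)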
Code example #20
0
File: train.py  Project: sushantmakadia/Pix2Vox-1
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solver
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                 encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' %
                        (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']

        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' \
                 % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss',
                                    encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(),
                                    n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' % \
                (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, \
                    batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                                epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
            (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, \
                encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' % \
                (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'best-ckpt.pth'), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
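One caveat about the loop above: the four LR schedulers are stepped at the top of each epoch, before any optimizer update. Since PyTorch 1.1 the documented order is the reverse, and stepping the scheduler first skips the initial learning-rate value and raises a warning. A self-contained toy sketch of the recommended ordering (the model and data here are stand-ins):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for encoder/decoder/refiner/merger
solver = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(solver, milestones=[150], gamma=0.5)

for epoch_idx in range(2):
    for _ in range(10):  # stand-in for the batch loop
        solver.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        solver.step()    # optimizer first ...
    scheduler.step()     # ... scheduler once per epoch, afterwards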
Code example #21
0
def main(_run, _config, _log):
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'],
                               device)  #.to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'],
                             device)  #.to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Switch the loaded networks to evaluation mode for testing
    s_encoder.eval()
    q_encoder.eval()
    decoder.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1

    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []

    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []

    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing

        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []

            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            s_fname = sample_test['s_fname']
            q_fname = sample_test['q_fname']

            s_xi = s_x[:, 0, :, :, :]  #[B, 1, 256, 256]
            s_yi = s_y[:, 0, :, :, :]
            s_xi_encode, _ = s_encoder(s_xi, s_yi)  #[B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]

            preds.append(yhati.round())
            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())

            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

            q_fname_split = q_fname[0][0].split("/")
            q_fname_split[-6] = "Training_2d_2_pred"
            try_mkdirs("/".join(q_fname_split[:-1]))
            o_q_fname = "/".join(q_fname_split)
            np.save(o_q_fname, yhati.round().cpu().numpy())
            # print(q_fname[0][0])
            # print(o_q_fname)

    try_mkdirs("figure")
    print("start computing dice similarities ... total ", len(saves))
    for subj_idx in range(len(saves)):
        save_subj = saves[subj_idx]
        dices = []

        for slice_idx in range(len(save_subj)):
            subj_idx, idx, img_list, pred_list, label_list = save_subj[
                slice_idx]

            for j in range(len(img_list)):
                dice = np.sum([label_list[j] * pred_list[j]]) * 2.0 / (
                    np.sum(pred_list[j]) + np.sum(label_list[j]))
                dices.append(dice)

        plt.clf()
        plt.bar([k for k in range(len(dices))], dices)
        plt.savefig(f"figure/bar_{_config['target']}_{subj_idx}.png")
Code example #22
0
from models.encoder import Encoder

opt = parse_opt()
assert opt.test_model, 'please input test_model'
assert opt.image_file, 'please input image_file'

encoder = Encoder(opt.resnet101_file)
encoder.to(opt.device)
encoder.eval()

img = skimage.io.imread(opt.image_file)
with torch.no_grad():
    img = encoder.preprocess(img)
    img = img.to(opt.device)
    fc_feat, att_feat = encoder(img)

print("====> loading checkpoint '{}'".format(opt.test_model))
chkpoint = torch.load(opt.test_model, map_location=lambda s, l: s)  # keep tensors on CPU (same as map_location='cpu')
decoder = Decoder(chkpoint['idx2word'], chkpoint['settings'])
decoder.load_state_dict(chkpoint['model'])
print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
    opt.test_model, chkpoint['epoch'], chkpoint['train_mode']))
decoder.to(opt.device)
decoder.eval()

rest, _ = decoder.sample(fc_feat,
                         att_feat,
                         beam_size=opt.beam_size,
                         max_seq_len=opt.max_seq_len)
print('generated captions:\n' + '\n'.join(rest))
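decoder.sample takes a beam_size, so it presumably performs beam search; with beam_size=1 this degenerates to greedy decoding. A self-contained sketch of that degenerate case over a toy next-token function (the names here are illustrative, not this repo's API):

import torch

def greedy_decode(step_fn, start_id, end_id, max_len=20):
    # Beam search with beam width 1: take the argmax token at every step.
    tokens = [start_id]
    for _ in range(max_len):
        logits = step_fn(torch.tensor(tokens))  # logits over the vocabulary
        next_id = int(logits.argmax())
        tokens.append(next_id)
        if next_id == end_id:
            break
    return tokens

# Toy step function: a fixed random projection of the last token id.
vocab = torch.randn(10, 10)
print(greedy_decode(lambda t: vocab[t[-1] % 10], start_id=1, end_id=0))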
Code example #23
0
def test_img(cfg):

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = '/Users/pranavpomalapally/Downloads/new-Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))

    print()
    # fix_checkpoint = {}
    # fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    # fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    # fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    # fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    # fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['encoder_state_dict'].items())
    # fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['decoder_state_dict'].items())
    # fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['refiner_state_dict'].items())
    # fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    # encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    # decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    # if cfg.NETWORK.USE_REFINER:
    #  print('Use refiner')
    #  refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])

    print('Use refiner')
    refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        # merger.load_state_dict(fix_checkpoint['merger_state_dict'])
        merger.load_state_dict(checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    #img1_path = '/Users/pranavpomalapally/Downloads/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
    img1_path = '/Users/pranavpomalapally/Downloads/09 copy.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(
        np.float32) / 255.

    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        # if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        #     generated_volume = refiner(generated_volume)
    generated_volume = refiner(generated_volume)  # runs outside torch.no_grad(), hence the .detach() below
    generated_volume = generated_volume.squeeze(0)

    img_dir = '/Users/pranavpomalapally/Downloads/outputs'
    # gv = generated_volume.cpu().numpy()
    gv = generated_volume.cpu().detach().numpy()
    gv_new = np.swapaxes(gv, 2, 1)

    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, img_dir, epoch_idx)
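The commented-out fix_checkpoint blocks earlier in this example were wrestling with a common mismatch: checkpoints saved from torch.nn.DataParallel-wrapped models prefix every state-dict key with 'module.', which must be stripped before loading into a bare (non-DataParallel) model. A sketch of the usual fix:

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # Remove a leading 'module.' from every key, if present.
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())

# e.g. encoder.load_state_dict(strip_module_prefix(checkpoint['encoder_state_dict']))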