Code Example #1
File: lenet.py  Project: zizai/deep-weight-prior
    def set_prior(self, prior_list, dwp_samples, vae_list, flow_list=None):
        # Attach a prior / KL term to each Bayesian conv layer.
        convs = [self.features.conv1, self.features.conv2]
        for i, m in enumerate(convs):
            if not isinstance(m, bayes._Bayes):
                continue

            if prior_list[i] == 'vae':
                vae = utils.load_vae(vae_list[i], self.device)
                for p in vae.parameters():
                    p.requires_grad = False
                m.kl_function = utils.kl_dwp(vae, n_tries=dwp_samples)
            elif prior_list[i] == 'flow':
                flow = utils.load_flow(flow_list[i], self.device)
                for p in flow.parameters():
                    p.requires_grad = False
                m.kl_function = utils.kl_flow(flow, n_tries=dwp_samples)
            elif prior_list[i] == 'sn':
                m.kl_function = utils.kl_normal
                m.prior = dist.Normal(
                    torch.FloatTensor([0.]).to(self.device),
                    torch.FloatTensor([1.]).to(self.device))
            elif prior_list[i] == 'loguniform':
                # kl_function is only overridden for the 'bayes-mtrunca'
                # config; other configs keep the layer's default.
                if self.cfg == 'bayes-mtrunca':
                    m.kl_function = utils.kl_loguniform_with_trunc_alpha
            else:
                raise NotImplementedError
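
Note: `utils.kl_dwp`, `utils.kl_flow`, and `utils.kl_normal` are defined elsewhere in the project and are not shown here. For the 'sn' (standard normal) branch the KL term has a closed form; below is a minimal sketch of the quantity involved, assuming a fully factorized Gaussian posterior N(mu, sigma^2) parameterized by `mu` and `logvar`. The helper name and signature are hypothetical, not the project's actual API.

import torch

def kl_normal_sketch(mu, logvar):
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over all weights:
    # KL = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
    return 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1.0 - logvar)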
Code Example #2
File: demo_alignments.py  Project: wjgaas/DeepDeform
def main():
    dataset_dir = cfg.DATA_ROOT_DIR
    alignments_path = os.path.join(
        dataset_dir, "{0}_selfsupervised.json".format(cfg.DATA_TYPE))
    flow_normalization = 100.0  # in pixels

    with open(alignments_path, "r") as f:
        pairs = json.load(f)

    for pair in pairs:
        source_color = cv2.imread(
            os.path.join(dataset_dir, pair["source_color"]))
        target_color = cv2.imread(
            os.path.join(dataset_dir, pair["target_color"]))
        optical_flow = utils.load_flow(
            os.path.join(dataset_dir, pair["optical_flow"]))

        optical_flow = np.moveaxis(optical_flow, 0, -1)  # (h, w, 2)

        # Invalid pixels are marked with -inf; zero them before visualizing.
        invalid_flow = optical_flow == -np.inf
        optical_flow[invalid_flow] = 0.0

        flow_color = flow_vis.flow_to_color(optical_flow, convert_to_bgr=False)

        cv2.imshow("Source", source_color)
        cv2.imshow("Target", target_color)
        cv2.imshow("Flow", flow_color)

        cv2.waitKey(0)
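
Both this demo and the evaluation script in Example #4 call `utils.load_flow` and then `np.moveaxis(flow, 0, -1)`, so the loader evidently returns a channel-first (2, h, w) float array with -inf marking invalid pixels. The actual .oflow binary layout is not shown in these examples; the sketch below assumes a simple header of two int32 values (height, width) followed by raw float32 data, and the function name is hypothetical.

import numpy as np

def load_flow_sketch(path):
    # Assumed layout: int32 height, int32 width, then 2*h*w float32 values.
    with open(path, "rb") as f:
        h, w = np.fromfile(f, dtype=np.int32, count=2)
        data = np.fromfile(f, dtype=np.float32, count=2 * int(h) * int(w))
    # Channel-first (2, h, w), matching what the callers above expect.
    return data.reshape(2, int(h), int(w))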
Code Example #3
File: lenet.py  Project: zizai/deep-weight-prior
    def weights_init(self,
                     init_list,
                     vae_list,
                     flow_list=None,
                     pretrained=None,
                     filters_list=None,
                     logvar=-10.):
        self.apply(
            utils.weight_init(module=nn.Conv2d, initf=nn.init.xavier_normal_))
        self.apply(
            utils.weight_init(module=nn.Linear, initf=nn.init.xavier_normal_))
        self.apply(
            utils.weight_init(module=bayes.LogScaleConv2d,
                              initf=utils.const_init(logvar)))
        self.apply(
            utils.weight_init(module=bayes.LogScaleLinear,
                              initf=utils.const_init(logvar)))

        if len(init_list) > 0 and init_list[0] == 'pretrained':
            assert len(init_list) == 1
            w_pretrained = torch.load(pretrained)
            for k, v in w_pretrained.items():
                if k in self.state_dict():
                    self.state_dict()[k].data.copy_(v)
                else:
                    # Remap a plain weight key onto the Bayesian layer's
                    # mean parameter (e.g. features.conv1.weight ->
                    # features.conv1.mean.weight).
                    tokens = k.split('.')
                    self.state_dict()['.'.join(tokens[:2] + ['mean'] +
                                               tokens[-1:])].data.copy_(v)
            return

        convs = [self.features.conv1, self.features.conv2]
        for i, m in enumerate(convs):
            init = init_list[i] if i < len(init_list) else 'xavier'
            w = m.mean.weight if isinstance(m, bayes._Bayes) else m.weight
            if init == 'vae':
                vae_path = vae_list[i]
                vae = utils.load_vae(vae_path, device=self.device)
                # Sample one kernel per (out_ch, in_ch) slice from the VAE.
                z = torch.randn(
                    w.size(0) * w.size(1), vae.encoder.z_dim, 1,
                    1).to(vae.device)
                x = vae.decode(z)[0]
                w.data = x.reshape(w.shape)
            elif init == 'flow':
                flow_path = flow_list[i]
                flow = utils.load_flow(flow_path, device=self.device)
                utils.flow_init(flow)(w)
            elif init == 'xavier':
                pass
            elif init == 'filters':
                filters = np.load(filters_list[i])
                N = np.prod(w.shape[:2])
                filters = filters[np.random.permutation(len(filters))[:N]]
                w.data = torch.from_numpy(filters.reshape(*w.shape)).to(
                    self.device)
            else:
                raise NotImplementedError
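
`utils.weight_init(module=..., initf=...)` is passed to `nn.Module.apply`, which visits every submodule, so it presumably returns a callable that applies `initf` only to modules of the requested type. Below is a minimal sketch of that pattern; the helper names are hypothetical, not the project's actual implementation.

import torch.nn as nn

def weight_init_sketch(module, initf):
    # For nn.Module.apply(): run `initf` on the weight of every submodule
    # of the given class.
    def init_fn(m):
        if isinstance(m, module):
            initf(m.weight)
    return init_fn

def const_init_sketch(value):
    # Counterpart of utils.const_init: fill a parameter with a constant,
    # e.g. initializing log-variances to -10.
    def initf(w):
        nn.init.constant_(w, value)
    return initf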
Code Example #4
def evaluate(groundtruth_matches_json_path, predictions_dir):
    print("#"*66)
    print("FLOW EVALUATION")
    print("#"*66)

    # Load groundtruth matches.
    with open(groundtruth_matches_json_path, 'r') as f:
        matches = json.load(f)

    # Compute pixel distances for every annotated keypoint,
    # for every sequence.
    pixel_dist_sum_per_seq = {}
    valid_pixel_num_per_seq = {}
    accurate_pixel_num_per_seq = {}

    pixel_threshold = 20.0

    print()
    print("Groundtruth matches: {}".format(groundtruth_matches_json_path))
    print("Predictions: {}".format(predictions_dir))
    print()

    for frame_pair in tqdm(matches):
        seq_id = frame_pair["seq_id"]
        object_id = frame_pair["object_id"]
        source_id = frame_pair["source_id"]
        target_id = frame_pair["target_id"]

        flow_id = "{0}_{1}_{2}_{3}.oflow".format(seq_id, object_id, source_id, target_id)
        flow_pred_path = os.path.join(predictions_dir, flow_id)
        if not os.path.exists(flow_pred_path):
            print("Flow prediction missing: {}".format(flow_pred_path))
            return { "status": Status.FLOW_NOT_FOUND }

        flow_image_pred = utils.load_flow(flow_pred_path)
        if flow_image_pred.shape[1] != IMAGE_HEIGHT or flow_image_pred.shape[2] != IMAGE_WIDTH:
            print("Invalid flow dimesions:", flow_image_pred.shape)
            return { "status": Status.INVALID_FLOW_DIMENSIONS }

        flow_image_pred = np.moveaxis(flow_image_pred, 0, -1)  # (h, w, 2)

        for match in frame_pair["matches"]:
            # Read keypoint and match.
            source_kp = np.array([match["source_x"], match["source_y"]])
            target_kp = np.array([match["target_x"], match["target_y"]])

            source_kp_rounded = np.round(source_kp)
            target_kp_rounded = np.round(target_kp)

            # Make sure both keypoints are in bounds.
            assert in_bounds(source_kp_rounded, IMAGE_WIDTH, IMAGE_HEIGHT) and \
                    in_bounds(target_kp_rounded, IMAGE_WIDTH, IMAGE_HEIGHT)

            source_v, source_u = source_kp_rounded[1].astype(np.int64), source_kp_rounded[0].astype(np.int64)

            flow_pred = flow_image_pred[source_v, source_u]
            flow_gt = (target_kp - source_kp_rounded).astype(np.float32)

            diff = flow_pred - flow_gt
            pixel_dist = float(np.sqrt(np.sum(diff * diff)))  # end-point error

            pixel_dist_sum_per_seq[seq_id] = pixel_dist_sum_per_seq.get(seq_id, 0.0) + pixel_dist
            valid_pixel_num_per_seq[seq_id] = valid_pixel_num_per_seq.get(seq_id, 0) + 1
            
            if pixel_dist <= pixel_threshold:
                accurate_pixel_num_per_seq[seq_id] = accurate_pixel_num_per_seq.get(seq_id, 0) + 1

    # Compute total statistics.
    pixel_dist_sum = 0.0
    valid_pixel_num = 0
    accurate_pixel_num = 0

    print()
    print("{0:<20s} | {1:^20s} | {2:^20s}".format("Sequence ID", "Accuracy (<20px)", "EPE (pixel)"))
    print("-"*66)

    pixel_dist_per_seq = {}
    pixel_acc_per_seq = {}
    for seq_id in sorted(pixel_dist_sum_per_seq.keys()):
        pixel_dist_sum_seq = pixel_dist_sum_per_seq[seq_id]
        valid_pixel_num_seq = valid_pixel_num_per_seq[seq_id]
        accurate_pixel_num_seq = accurate_pixel_num_per_seq.get(seq_id, 0)

        pixel_dist_sum += pixel_dist_sum_seq
        valid_pixel_num += valid_pixel_num_seq
        accurate_pixel_num += accurate_pixel_num_seq

        pixel_dist_seq = pixel_dist_sum_seq / valid_pixel_num_seq if valid_pixel_num_seq > 0 else -1.0
        pixel_acc_seq = accurate_pixel_num_seq / valid_pixel_num_seq if valid_pixel_num_seq > 0 else -1.0
        
        pixel_dist_per_seq[seq_id] = pixel_dist_seq
        pixel_acc_per_seq[seq_id] = pixel_acc_seq

        print("{0:<20s} | {1:^20.4f} | {2:^20.3f}".format(seq_id, pixel_acc_seq, pixel_dist_seq))

    pixel_dist = pixel_dist_sum / valid_pixel_num if valid_pixel_num > 0 else -1.0
    pixel_acc = accurate_pixel_num / valid_pixel_num if valid_pixel_num > 0 else -1.0
    print("-"*66)
    print("{0:<20s} | {1:^20.4f} | {2:^20.3f}".format("Total", pixel_acc, pixel_dist))

    return {
        "status": Status.SUCCESS, 
        "pixel_accuracy": pixel_acc, 
        "pixel_distance": pixel_dist, 
        "pixel_accuracy_per_seq": pixel_acc_per_seq, 
        "pixel_distance_per_seq": pixel_dist_per_seq
    }
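
A hypothetical invocation, assuming `IMAGE_WIDTH`, `IMAGE_HEIGHT`, `in_bounds`, and the `Status` enum are defined alongside `evaluate()`; the paths below are placeholders.

result = evaluate("val_matches.json", "predictions/")
if result["status"] == Status.SUCCESS:
    print("Total EPE: {:.3f} px, accuracy (<20px): {:.4f}".format(
        result["pixel_distance"], result["pixel_accuracy"]))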