Exemplo n.º 1
0
    def make_graph(self, data, include_losses=True):
        """Refine the left disparity starting from the previous estimate.

        ``data.disp.L`` is warped with the forward flow ``data.flow[0].fwd``.
        A single Architecture_S network (scope 'refine_disp', learn=True,
        exit_after=0) receives the left image, the warped disparity scaled by
        0.05 and the forward occlusion ``data.occ[0].fwd``, with
        ``data.img.L`` as edge features.  Its one-channel 'disp' prediction
        is passed to ``self.interpolator`` together with the warped
        disparity (via the prediction's ``mod_func``).

        Args:
            data: Input structure; ``data.img.L``, ``data.disp.L``,
                ``data.flow[0].fwd`` and ``data.occ[0].fwd`` are read.
            include_losses: Accepted for interface compatibility; not used
                by this method.

        Returns:
            The output of the Architecture_S ``make_graph`` call.
        """
        pred_config = nd.PredConfig()
        pred_config.add(nd.PredConfigId(type='disp', perspective='L',
                                        channels=1, scale=self._scale))

        left_img = data.img.L
        prev_disp = data.disp.L
        fwd_flow = data.flow[0].fwd
        fwd_occ = data.occ[0].fwd

        warped_disp = nd.ops.warp(prev_disp, fwd_flow)

        def blend_with_previous(x):
            # Hand the raw prediction and the warped previous disparity to
            # the interpolator.
            return self.interpolator(pred=x, prev_disp=warped_disp)

        pred_config[0].mod_func = blend_with_previous

        scaled_disp = nd.ops.scale(warped_disp, 0.05)
        net_input = nd.ops.concat(left_img, scaled_disp, fwd_occ)

        with nd.Scope('refine_disp', learn=True, **self.scope_args()):
            refiner = Architecture_S(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample,
                exit_after=0)
            return refiner.make_graph(net_input, edge_features=left_img)
Exemplo n.º 2
0
    def make_graph(self, data, include_losses=False):
        """Build the hypothesis network followed by the merge network.

        hypNet (Architecture_C) predicts ``self._num_hypotheses`` forward
        'flow_hyp' maps and matching 'iul_b_hyp_log' maps from
        ``data.img[0]`` and ``data.img[1]``.  Every hypothesis and every
        'iul_b_hyp_log' map is resampled (type='LINEAR', antialias=False)
        onto the grid of ``data.img[0]``, and ``data.img[1]`` is warped with
        each resampled hypothesis.  mergeNet (Architecture_S) consumes both
        images, all hypotheses, all 'iul_b_hyp_log' maps and all warped
        images, and predicts a single forward 'flow' and 'iul_b_log'.

        Args:
            data: Input structure; ``data.img[0]`` and ``data.img[1]`` are
                read.
            include_losses: Accepted for interface compatibility; not used
                by this method.

        Returns:
            The output of mergeNet's ``make_graph``.
        """
        # ---- hypNet: several flow hypotheses plus per-hypothesis maps.
        hyp_config = nd.PredConfig()
        hyp_config.add(nd.PredConfigId(
            type='flow_hyp', dir='fwd', offset=0, channels=2,
            scale=self._scale, array_length=self._num_hypotheses))
        hyp_config.add(nd.PredConfigId(
            type='iul_b_hyp_log', dir='fwd', offset=0, channels=2,
            scale=self._scale, array_length=self._num_hypotheses,
            mod_func=self._log_sigmoid))
        nd.log('pred_config:')
        nd.log(hyp_config)

        with nd.Scope('hypNet', shared_batchnorm=False,
                      correlation_leaky_relu=True, **self.scope_args()):
            hyp_arch = Architecture_C(
                num_outputs=hyp_config.total_channels(),
                disassembling_function=hyp_config.disassemble,
                conv_upsample=True,
                loss_function=None,
                channel_factor=self._channel_factor,
                feature_channels=self._feature_channels)
            out_hyp = hyp_arch.make_graph(data.img[0], data.img[1])

        # ---- mergeNet: fuse the hypotheses into one estimate.
        merge_config = nd.PredConfig()
        merge_config.add(nd.PredConfigId(
            type='flow', dir='fwd', offset=0, channels=2,
            scale=self._scale, dist=1))
        # NOTE(review): hypNet's map uses ``self._log_sigmoid`` while this
        # output uses ``self.iul_b_log_sigmoid``; confirm the difference is
        # intentional.
        merge_config.add(nd.PredConfigId(
            type='iul_b_log', dir='fwd', offset=0, channels=2,
            scale=self._scale, dist=1, mod_func=self.iul_b_log_sigmoid))
        nd.log('pred_config:')
        nd.log(merge_config)

        def to_img0_grid(tensor):
            # Bilinear resample onto the grid of data.img[0], no antialiasing.
            return nd.ops.resample(tensor, reference=data.img[0],
                                   antialias=False, type='LINEAR')

        hyp_indices = range(self._num_hypotheses)
        hyps = [to_img0_grid(out_hyp.final.flow_hyp[0].fwd[i])
                for i in hyp_indices]
        uncertainties = [to_img0_grid(out_hyp.final.iul_b_hyp_log[0].fwd[i])
                         for i in hyp_indices]
        img_warped = [nd.ops.warp(data.img[1], hyp) for hyp in hyps]

        with nd.Scope('mergeNet', shared_batchnorm=False, **self.scope_args()):
            # Named ``merge_input`` (not ``input``) so the builtin is not
            # shadowed.
            merge_input = nd.ops.concat(
                [data.img[0], data.img[1]] + hyps + uncertainties + img_warped)
            merge_arch = Architecture_S(
                num_outputs=merge_config.total_channels(),
                disassembling_function=merge_config.disassemble,
                conv_upsample=True,
                loss_function=None,
                channel_factor=self._channel_factor)
            out_merge = merge_arch.make_graph(merge_input)

        return out_merge
Exemplo n.º 3
0
    def make_graph(self, data, include_losses=True):
        """Build a three-stage flow/occlusion network (net1 -> net2 -> net3).

        net1 (Architecture_C) predicts forward 'flow' and 'occ' from
        ``data.img[0]`` and ``data.img[1]``.  net2 and net3 (Architecture_S)
        each refine the previous stage.  Their input is:

        * both images;
        * the previous flow upsampled to ``data.img[0]`` and scaled by 0.05;
        * ``data.img[1]`` warped with that flow;
        * the previous occlusion passed through ``self.resample_occ``.

        Their 'flow' and 'occ' predictions are residuals, added to the
        previous estimates resampled to the prediction's resolution.  net3
        also predicts an 'mb' output and uses ``data.img[0]`` as edge
        features.  Only net3 is built with ``learn=True``.

        Args:
            data: Input structure; ``data.img[0]`` and ``data.img[1]`` are
                read.
            include_losses: Accepted for interface compatibility; not used
                by this method.

        Returns:
            The output of net3's ``make_graph``.
        """

        def residual_on(base):
            # Build a mod_func that adds ``base`` (bilinearly resampled onto
            # x's grid) to x.  ``base`` is bound as an argument so each
            # stage's mod_func keeps referring to that stage's estimate, even
            # if the framework calls it after this method has moved on to
            # the next stage.
            return lambda x: nd.ops.add(
                x,
                nd.ops.resample(base, reference=x, type='LINEAR',
                                antialias=False))

        def refinement_input(prev_out):
            # Return (flow, full-resolution occlusion, stacked input) taken
            # from the previous stage's output, for the next refinement stage.
            flow = prev_out.final.flow[0].fwd
            flow_up = nd.ops.differentiable_resample(flow,
                                                     reference=data.img[0])
            warped = nd.ops.warp(data.img[1], flow_up)
            occ = self.resample_occ(prev_out.final.occ[0].fwd, data.img[0])
            stacked = nd.ops.concat(data.img[0], data.img[1],
                                    nd.ops.scale(flow_up, 0.05), warped, occ)
            return flow, occ, stacked

        pred_config = nd.PredConfig()
        pred_config.add(
            nd.PredConfigId(type='flow',
                            dir='fwd',
                            offset=0,
                            channels=2,
                            scale=self._scale))
        pred_config.add(
            nd.PredConfigId(type='occ',
                            dir='fwd',
                            offset=0,
                            channels=2,
                            scale=self._scale))

        nd.log('pred_config:')
        nd.log(pred_config)

        #### Net 1 ####
        with nd.Scope('net1', learn=False, **self.scope_args()):
            arch1 = Architecture_C(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample)
            out1 = arch1.make_graph(data.img[0], data.img[1])

        #### Net 2 ####
        flow1, occ1, input2 = refinement_input(out1)
        pred_config[0].mod_func = residual_on(flow1)
        pred_config[1].mod_func = residual_on(occ1)

        with nd.Scope('net2', learn=False, **self.scope_args()):
            arch2 = Architecture_S(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample)
            out2 = arch2.make_graph(input2)

        #### Net 3 ####
        flow2, occ2, input3 = refinement_input(out2)

        # NOTE(review): the same pred_config object is extended here after
        # net2 has been built with its ``disassemble``; confirm net2 does not
        # call it again afterwards.
        pred_config.add(
            nd.PredConfigId(type='mb',
                            dir='fwd',
                            offset=0,
                            channels=2,
                            scale=self._scale))
        pred_config[0].mod_func = residual_on(flow2)
        pred_config[1].mod_func = residual_on(occ2)

        with nd.Scope('net3', learn=True, **self.scope_args()):
            arch3 = Architecture_S(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample,
                exit_after=0,
            )
            out3 = arch3.make_graph(input3, edge_features=data.img[0])

        return out3
Exemplo n.º 4
0
    def make_graph(self, data, include_losses=True):
        """Build a three-stage stereo disparity/occlusion network.

        net1 (Architecture_C, learn=False) predicts the left-view 'disp'
        (passed through ``nd.ops.neg_relu``) and 'occ' from ``data.img.L``
        and ``data.img.R``, using 1D correlation with
        ``single_direction=0``.  net2 (learn=False) and net3 (learn=True)
        are Architecture_S refinements.  Each one receives:

        * both images;
        * the previous disparity upsampled to ``data.img.L`` and scaled by
          0.05;
        * ``data.img.R`` warped by that disparity (via
          ``nd.ops.disp_to_flow``);
        * the previous occlusion after ``self.resample_occ``.

        Their 'disp' and 'occ' predictions are added to the previous
        estimates resampled to the prediction's resolution.  net3 also
        predicts a 'db' output and uses ``data.img.L`` as edge features.
        All stages use ``channel_factor=0.375``.

        Args:
            data: Input structure; ``data.img.L`` and ``data.img.R`` are
                read.
            include_losses: Accepted for interface compatibility; not used
                by this method.

        Returns:
            The output of net3's ``make_graph``.
        """
        channel_factor = 0.375
        left, right = data.img.L, data.img.R

        def add_resampled(base):
            # mod_func: add ``base``, bilinearly resampled onto x's grid, to x.
            return lambda x: nd.ops.add(
                x,
                nd.ops.resample(base, reference=x, type='LINEAR',
                                antialias=False))

        def stage_input(disp, occ):
            # Stack both images, the scaled upsampled disparity, the right
            # image warped by it, and the occlusion estimate.
            disp_up = nd.ops.differentiable_resample(disp, reference=left)
            right_warped = nd.ops.warp(right, nd.ops.disp_to_flow(disp_up))
            return nd.ops.concat(left, right, nd.ops.scale(disp_up, 0.05),
                                 right_warped, occ)

        pred_config = nd.PredConfig()
        pred_config.add(
            nd.PredConfigId(type='disp', perspective='L', channels=1,
                            scale=self._scale,
                            mod_func=lambda b: nd.ops.neg_relu(b)))
        pred_config.add(
            nd.PredConfigId(type='occ', perspective='L', channels=2,
                            scale=self._scale))

        # ---- net1: initial estimate.
        with nd.Scope('net1', learn=False, **self.scope_args()):
            net1 = Architecture_C(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample,
                channel_factor=channel_factor)
            out1 = net1.make_graph(left, right, use_1D_corr=True,
                                   single_direction=0)
        disp1 = out1.final.disp.L
        occ1 = self.resample_occ(out1.final.occ.L, left)

        # ---- net2: first refinement.
        pred_config[0].mod_func = add_resampled(disp1)
        pred_config[1].mod_func = add_resampled(occ1)
        input2 = stage_input(disp1, occ1)

        with nd.Scope('net2', learn=False, **self.scope_args()):
            net2 = Architecture_S(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample,
                channel_factor=channel_factor)
            out2 = net2.make_graph(input2)

        # ---- net3: second refinement, additionally predicting 'db'.
        disp2 = out2.final.disp.L
        occ2 = self.resample_occ(out2.final.occ.L, left)
        pred_config.add(
            nd.PredConfigId(type='db', perspective='L', channels=2,
                            scale=self._scale))
        pred_config[0].mod_func = add_resampled(disp2)
        pred_config[1].mod_func = add_resampled(occ2)
        input3 = stage_input(disp2, occ2)

        with nd.Scope('net3', learn=True, **self.scope_args()):
            net3 = Architecture_S(
                num_outputs=pred_config.total_channels(),
                disassembling_function=pred_config.disassemble,
                loss_function=None,
                conv_upsample=self._conv_upsample,
                exit_after=0,
                channel_factor=channel_factor)
            out3 = net3.make_graph(input3, edge_features=left)
        return out3