def prepare_testing_evolve(self, output, h, w):
    """Clamp the points stored in ``output['ex']`` to the image and build evolve inputs.

    Args:
        output: dict of model outputs. The tensor under ``'ex'`` is clamped in
            place, and ``'it_py'`` is added from the evolve result.
        h: image height; last-axis index 1 is clamped to ``[0, h - 1]``.
        w: image width; last-axis index 0 is clamped to ``[0, w - 1]``.

    Returns:
        The dict returned by ``snake_gcn_utils.prepare_testing_evolve``.
    """
    points = output['ex']
    # Index 0 of the last axis is bounded by the width, index 1 by the height.
    # The assignments write back into the tensor held by ``output``.
    for axis, upper in ((0, w - 1), (1, h - 1)):
        points[..., axis] = torch.clamp(points[..., axis], min=0, max=upper)

    evolve = snake_gcn_utils.prepare_testing_evolve(points)
    output['it_py'] = evolve['i_it_py']
    return evolve
Beispiel #2
0
    def prepare_testing_evolve(self, output, h, w):
        """Clamp the initial contour to the image bounds and build evolve inputs.

        The initial points come from ``output['init_poly_pred']`` when
        ``cfg.evolve_init == 'poly'``, and from ``output['ex']`` otherwise.

        Args:
            output: dict of model outputs. The selected point tensor is
                clamped in place, and ``'it_py'`` is added from the evolve
                result.
            h: image height; last-axis index 1 is clamped to ``[0, h - 1]``.
            w: image width; last-axis index 0 is clamped to ``[0, w - 1]``.

        Returns:
            The dict returned by ``snake_gcn_utils.prepare_testing_evolve``.
        """
        key = 'init_poly_pred' if cfg.evolve_init == 'poly' else 'ex'
        ex = output[key]
        # These slice assignments modify the tensor stored in ``output[key]``.
        ex[..., 0] = torch.clamp(ex[..., 0], min=0, max=w - 1)
        ex[..., 1] = torch.clamp(ex[..., 1], min=0, max=h - 1)

        evolve = snake_gcn_utils.prepare_testing_evolve(ex)
        output.update({'it_py': evolve['i_it_py']})
        return evolve