class Rain(scene.visuals.Visual, SlideVisualMixin):
    """Expanding, fading rings ("rain drops") spawned at the mouse position.

    Rings live in a fixed-size circular buffer of ``self._n`` slots: each
    mouse move overwrites the slot at ``self.index`` and advances it, so the
    oldest ring is recycled first. Every timer tick grows all rings by one
    pixel and lowers their alpha by 0.01; rings whose alpha is <= 0 are
    culled by both shaders.
    """

    # Vertex shader: forwards size/colour to the fragment stage and sizes the
    # point sprite so the ring, its outline width and the antialiasing
    # falloff all fit inside it (3.0 == 2 sides * 1.5 * u_antialias).
    # Rings with non-positive alpha are collapsed to a zero-sized point.
    vertex = """
        #version 120
        
        uniform float u_linewidth;
        uniform float u_antialias;
        
        attribute vec2  a_position;
        attribute vec4  a_fg_color;
        attribute float a_size;
        
        varying vec4  v_fg_color;
        varying float v_size;
        
        void main (void)
        {
            v_size = a_size;
            v_fg_color = a_fg_color;
            if( a_fg_color.a > 0.0)
            {
                gl_Position = $transform(vec4(a_position, 0.0, 1.0));
                gl_PointSize = v_size + u_linewidth + 3.0*u_antialias;
            }
            else
            {
                gl_Position = $transform(vec4(-1.0, -1.0, 0.0, 1.0));
                gl_PointSize = 0.0;
            }
        }
        """

    # Fragment shader: draws a ring of diameter v_size and thickness
    # u_linewidth. Pixels inside the stroke get the full colour; pixels
    # within 2.5*u_antialias of its edge fade with exp(-d^2); everything else
    # is discarded. The sprite size must match the vertex shader's formula.
    fragment = """
        #version 120
        
        uniform float u_linewidth;
        uniform float u_antialias;
        varying vec4  v_fg_color;
        varying float v_size;
        float disc(vec2 P, float size)
        {
            return length((P.xy - vec2(0.5,0.5))*size);
        }
        void main()
        {
            if( v_fg_color.a <= 0.0)
                discard;
            float actual_size = v_size + u_linewidth + 3.0*u_antialias;
            float t = u_linewidth/2.0 - u_antialias;
            float r = disc(gl_PointCoord, actual_size);
            float d = abs(r - v_size/2.0) - t;
            if( d < 0.0 )
            {
                gl_FragColor = v_fg_color;
            }
            else if( abs(d) > 2.5*u_antialias )
            {
                discard;
            }
            else
            {
                d /= u_antialias;
                gl_FragColor = vec4(v_fg_color.rgb, exp(-d*d)*v_fg_color.a);
            }
        }
        """

    def __init__(self, **kwargs):
        """Allocate the ring buffer, shader program and 60 Hz timer.

        All keyword arguments are forwarded to ``scene.visuals.Visual``.
        """
        scene.visuals.Visual.__init__(self, **kwargs)

        # Number of ring slots in the circular buffer.
        self._n = 250
        # 'a_size' is a plain scalar field; the older ('name', type, 1) form
        # is deprecated by NumPy and may later mean a shape-(1,) subarray.
        self.data = np.zeros(self._n, [('a_position', np.float32, 2),
                                       ('a_fg_color', np.float32, 4),
                                       ('a_size',     np.float32)])
        # Slot that the next mouse move will overwrite.
        self.index = 0
        self.program = ModularProgram(self.vertex, self.fragment)
        self.vdata = gloo.VertexBuffer(self.data)
        # NOTE(review): the timer is only constructed here; starting it is
        # presumably handled elsewhere (e.g. SlideVisualMixin) — confirm.
        self._timer = app.Timer(1. / 60., self.on_timer)

    def draw(self, event):
        """Bind the current transform, buffer and uniforms, then draw points."""
        xform = event.render_transform.shader_map()
        self.program.vert['transform'] = xform

        self.program.prepare()
        self.program.bind(self.vdata)
        self.program['u_antialias'] = 1.00
        self.program['u_linewidth'] = 2.00

        self.program.draw('points')

    def on_timer(self, event):
        """Fade and grow every ring by one step, then re-upload the buffer."""
        self.data['a_fg_color'][..., 3] -= 0.01
        self.data['a_size'] += 1.0
        self.vdata.set_data(self.data)

    def on_mouse_move(self, event):
        """Spawn a new opaque black ring of size 5 at the mouse position.

        The write goes to the CPU-side array only; it reaches the GPU on the
        next ``on_timer`` upload.
        """
        # NOTE(review): event.pos is used as-is as a_position; presumably it
        # is already in the space the render transform expects — confirm.
        x, y = event.pos[:2]
        self.data['a_position'][self.index] = x, y
        self.data['a_size'][self.index] = 5
        self.data['a_fg_color'][self.index] = 0, 0, 0, 1
        # Advance the circular-buffer cursor, wrapping to recycle old rings.
        self.index = (self.index + 1) % self._n
class Atom(scene.visuals.Visual, SlideVisualMixin):
    """Animated "atom": fading particle trails orbiting a common centre.

    ``p`` trails of ``n`` point sprites each travel along an elliptical
    orbit. Each trail orbit is rotated about its own random axis/angle and
    has its own random colour. The orbit phase advances with ``self.clock``,
    which the timer increments.
    """

    # Vertex shader: a_position.x is the particle's phase along the orbit
    # (offset by u_clock). The orbit is an ellipse in the xz-plane (x radius
    # 1.5, z radius 0.75) and is rotated by the axis a_rotation.xyz through
    # the angle a_rotation.w. The result is scaled by 0.13 and offset to
    # (0.5, 0.22). Sprite size scales with sqrt(alpha), so faded particles
    # are smaller.
    vert = """
        #version 120
        uniform float u_size;
        uniform float u_clock;
        
        attribute vec2 a_position;
        attribute vec4 a_color;
        attribute vec4 a_rotation;
        varying vec4 v_color;
        
        mat4 build_rotation(vec3 axis, float angle)
        {
            axis = normalize(axis);
            float s = sin(angle);
            float c = cos(angle);
            float oc = 1.0 - c;
            return mat4(oc * axis.x * axis.x + c,
                        oc * axis.x * axis.y - axis.z * s,
                        oc * axis.z * axis.x + axis.y * s,
                        0.0,
                        oc * axis.x * axis.y + axis.z * s,
                        oc * axis.y * axis.y + c,
                        oc * axis.y * axis.z - axis.x * s,
                        0.0,
                        oc * axis.z * axis.x - axis.y * s,
                        oc * axis.y * axis.z + axis.x * s,
                        oc * axis.z * axis.z + c,
                        0.0,
                        0.0, 0.0, 0.0, 1.0);
        }
        
        
        void main (void) {
            v_color = a_color;
        
            float x0 = 1.5;
            float z0 = 0.0;
        
            float theta = a_position.x + u_clock;
            float x1 = x0*cos(theta) + z0*sin(theta);
            float y1 = 0.0;
            float z1 = (z0*cos(theta) - x0*sin(theta))/2.0;
            
        
            mat4 R = build_rotation(a_rotation.xyz, a_rotation.w);
            vec4 pos = R * vec4(x1,y1,z1,1);
            pos.x = pos.x * 0.13 + 0.5;
            pos.y = pos.y * 0.13 + 0.22;
            gl_Position = $transform(pos);
            gl_PointSize = 12.0 * u_size * sqrt(v_color.a);
        }
        """

    # Fragment shader: round soft sprite whose alpha falls off linearly from
    # the centre. Alpha is clamped at 0, so the sprite corners (d > 1) cannot
    # produce negative alpha.
    frag = """
        #version 120
        varying vec4 v_color;
        void main()
        {
            float d = 2.0*(length(gl_PointCoord.xy - vec2(0.5,0.5)));
            gl_FragColor = vec4(v_color.rgb, v_color.a*max(0.0, 1.0-d));
        }
        """

    def __init__(self, **kwargs):
        """Build the per-particle vertex data, shader program and 30 Hz timer.

        All keyword arguments are forwarded to ``scene.visuals.Visual``.
        """
        scene.visuals.Visual.__init__(self, **kwargs)

        # n particles per trail, p trails.
        n, p = 150, 32
        data = np.zeros(p * n, [('a_position', np.float32, 2),
                                ('a_color',    np.float32, 4),
                                ('a_rotation', np.float32, 4)])
        # Angular length of each trail, in radians.
        trail = .5 * np.pi
        # Phase along the orbit: each trail spans [0, trail], shifted by a
        # random per-trail start phase.
        data['a_position'][:, 0] = np.resize(np.linspace(0, trail, n), p * n)
        data['a_position'][:, 0] += np.repeat(np.random.uniform(0, 2 * np.pi, p), n)
        # Per-trail value; note the vertex shader above only reads .x.
        data['a_position'][:, 1] = np.repeat(np.linspace(0, 2 * np.pi, p), n)

        # Random per-trail colour; alpha ramps from 0 to 1 along each trail.
        data['a_color'] = np.repeat(
            np.random.uniform(0.5, 1.00, (p, 4)).astype(np.float32), n, axis=0)
        data['a_color'][:, 3] = np.resize(np.linspace(0, 1, n), p * n)

        # Random per-trail rotation: xyz is the axis, w the angle (radians).
        data['a_rotation'] = np.repeat(
            np.random.uniform(0, 2 * np.pi, (p, 4)).astype(np.float32), n, axis=0)

        self.program = ModularProgram(self.vert, self.frag)
        self._vbo = gloo.VertexBuffer(data)

        # NOTE(review): theta, phi and stop_rotation are not read anywhere in
        # this class; they are kept in case other code uses them.
        self.theta = 0
        self.phi = 0
        self.clock = 0
        self.stop_rotation = False

        self.transform = vispy.scene.transforms.AffineTransform()

        # NOTE(review): constructed but not started here; presumably started
        # elsewhere (e.g. SlideVisualMixin) — confirm.
        self._timer = app.Timer(1.0 / 30, self.on_timer)

    def on_timer(self, event):
        """Advance the orbit phase by pi/100 radians per tick."""
        self.clock += np.pi / 100

    def draw(self, event):
        """Bind the current transform, buffer and uniforms, then draw points."""
        # Set transform
        xform = event.render_transform.shader_map()
        self.program.vert['transform'] = xform
        self.program.prepare()

        # Bind variables
        self.program.bind(self._vbo)
        # Float literals: under Python 2, 5 / 6 == 0, which would give every
        # sprite a point size of zero.
        self.program['u_size'] = 5.0 / 6.0
        self.program['u_clock'] = self.clock

        self.program.draw('points')