def compute_shader_multiply(
    index=("input", "GlobalInvocationId", ps.ivec3),
    data1=("buffer", 0, ps.Array(ps.f32)),
    data2=("buffer", 1, ps.Array(ps.f32)),
    data3=("buffer", 2, ps.Array(ps.f32)),
):
    # Element-wise product kernel: data3 = data1 * data2.
    # Each invocation handles exactly one element, chosen by the x component
    # of its global invocation id.
    # The parameter defaults are (kind, location/builtin, type) binding specs
    # for the shader compiler rather than ordinary Python defaults.
    # NOTE(review): presumably consumed by pyshader's python2shader — confirm.
    gid = index.x
    data3[gid] = data1[gid] * data2[gid]
def compute_shader_wg(
    gl_idx=("input", "GlobalInvocationId", ps.ivec3),
    gl_wg_id=("input", "WorkgroupId", ps.ivec3),
    gl_wg_num=("input", "NumWorkgroups", ps.ivec3),
    data1=("buffer", 0, ps.Array(ps.f32)),
    data2=("buffer", 1, ps.Array(ps.f32)),
):
    # Record the global invocation id of the dispatch, one slot per workgroup:
    # the x component goes to data1 and the y component goes to data2, both
    # converted to f32.
    #
    # The slot flattens the 2-D workgroup id in x-major order, where each
    # "row" has NumWorkgroups.y entries.
    # Every invocation within the same workgroup computes the same slot. If a
    # workgroup holds more than one invocation, this code does not determine
    # which invocation's id is stored.
    # NOTE(review): presumably dispatched with a local size of 1 — confirm.
    slot = gl_wg_id.y + gl_wg_id.x * gl_wg_num.y
    data1[slot] = f32(gl_idx.x)
    data2[slot] = f32(gl_idx.y)
def compute_shader(
    index=("input", "GlobalInvocationId", ps.ivec3),
    x_i=("buffer", 0, ps.Array(ps.f32)),
    x_j=("buffer", 1, ps.Array(ps.f32)),
    y=("buffer", 2, ps.Array(ps.f32)),
    w_in=("buffer", 3, ps.Array(ps.f32)),
    w_out_i=("buffer", 4, ps.Array(ps.f32)),
    w_out_j=("buffer", 5, ps.Array(ps.f32)),
    b_in=("buffer", 6, ps.Array(ps.f32)),
    b_out=("buffer", 7, ps.Array(ps.f32)),
    l_out=("buffer", 8, ps.Array(ps.f32)),
    M=("buffer", 9, ps.Array(ps.f32)),
):
    # One logistic-regression evaluation for a single 2-feature sample.
    #
    # Inputs:
    #   x_i[i], x_j[i] -- the two features of sample i
    #   y[i]           -- its label (expected to be 0 or 1 for cross-entropy)
    #   w_in[0..1]     -- current weights
    #   b_in[0]        -- current bias
    #   M[0]           -- scale factor m; the gradients are multiplied by 1/m
    #
    # Outputs, all per sample at index i:
    #   w_out_i[i], w_out_j[i] -- weight gradient
    #   b_out[i]               -- bias gradient
    #   l_out[i]               -- binary cross-entropy loss
    #
    # NOTE(review): the host presumably reduces these per-sample values and
    # uses m as the sample count — confirm at the call site.
    i = index.x
    m = M[0]
    inv_m = 1.0 / m

    w_curr = vec2(w_in[0], w_in[1])
    b_curr = b_in[0]
    x_curr = vec2(x_i[i], x_j[i])
    y_curr = y[i]

    # Forward pass: y_hat = sigmoid(w . x + b). The vec2 @ vec2 is used here
    # as the dot product (cf. the name z_dot).
    z_dot = w_curr @ x_curr
    z = z_dot + b_curr
    y_hat = 1.0 / (1.0 + exp(-z))

    # Gradients of the cross-entropy w.r.t. z (y_hat - y), w and b,
    # each scaled by 1/m.
    d_z = y_hat - y_curr
    d_w = inv_m * x_curr * d_z
    d_b = inv_m * d_z

    # Binary cross-entropy: -(y*log(y_hat) + (1 - y)*log(1 - y_hat)).
    # The negative-class term must be weighted by (1 - y). Weighting it by
    # (1 + y) is not the cross-entropy and inflates the loss for y = 1.
    # NOTE: if y_hat saturates to exactly 0 or 1 in f32, a log() argument is 0
    # and the loss is infinite; no clamping is done here.
    loss = -((y_curr * log(y_hat)) + ((1.0 - y_curr) * log(1.0 - y_hat)))

    w_out_i[i] = d_w.x
    w_out_j[i] = d_w.y
    b_out[i] = d_b
    l_out[i] = loss