def log_wishart_prior(p, wishart_gamma, wishart_m, sum_qs, Qdiags, icf):
    """Return the log Wishart-prior term added to the GMM objective.

    Args:
        p: data dimensionality; ``icf[:, :p]`` holds the log-diagonal entries
            and ``icf[:, p:]`` the remaining (off-diagonal) entries of each
            component's inverse-covariance factor.
        wishart_gamma: Wishart prior scale parameter (a Python number; it is
            passed to ``math.log``).
        wishart_m: Wishart prior degrees-of-freedom offset.
        sum_qs: per-component row sums of the log-diagonal entries.
        Qdiags: per-component exponentiated diagonal entries.
        icf: per-component inverse-covariance factor parameters (one row per
            mixture component).

    Returns:
        A scalar tensor with the summed prior over all components, minus the
        per-component normalising constant times the number of components.
    """
    degrees = p + wishart_m + 1
    n_components = shape(icf)[0]

    # Squared Frobenius norm of each component's factor: diagonal part plus
    # off-diagonal part.
    frobenius_sq = (
        tf.reduce_sum(Qdiags**2, 1) + tf.reduce_sum(icf[:, p:]**2, 1)
    )
    half_gamma_sq = 0.5 * wishart_gamma * wishart_gamma
    out = tf.reduce_sum(half_gamma_sq * frobenius_sq - wishart_m * sum_qs)

    log_norm_const = degrees * p * (math.log(wishart_gamma / math.sqrt(2)))
    return out - n_components * (
        log_norm_const - log_gamma_distrib(0.5 * degrees, p)
    )
def gmm_objective(alphas, means, icf, x, wishart_gamma, wishart_m):
    """Return the GMM log-likelihood objective including the Wishart prior.

    Args:
        alphas: per-component unnormalised log mixture weights.
        means: per-component means; broadcast-subtracted from every point.
        icf: per-component inverse-covariance factor parameters. The first
            ``d`` columns are log-diagonal entries; each full row is passed
            to ``constructL``.
        x: data points, indexed as ``x[i]`` for ``i`` in ``range(n)``.
        wishart_gamma: Wishart prior scale parameter.
        wishart_m: Wishart prior degrees-of-freedom offset.

    Returns:
        A scalar float64 tensor (the constant term is created as float64).
    """
    x_shape = shape(x)
    num_points, dim = x_shape[0], x_shape[1]

    log_diag = icf[:, :dim]
    Qdiags = tf.exp(log_diag)
    sum_qs = tf.reduce_sum(log_diag, 1)

    n_components = shape(icf)[0]
    Ls = tf.stack([constructL(dim, icf[k]) for k in range(n_components)])

    # Subtract every component mean from every point, then apply each
    # component's factor.
    xcentered = tf.stack([x[i] - means for i in range(num_points)])
    Lxcentered = Qtimesx(Qdiags, Ls, xcentered)
    sq_norms = tf.reduce_sum(Lxcentered**2, 2)

    inner_term = alphas + sum_qs - 0.5 * sq_norms
    slse = tf.reduce_sum(logsumexpvec(inner_term))

    const = tf.constant(
        -num_points * dim * 0.5 * math.log(2 * math.pi),
        dtype=tf.float64,
    )
    return const + slse - num_points * logsumexp(alphas) + \
        log_wishart_prior(dim, wishart_gamma, wishart_m, sum_qs, Qdiags, icf)
def lstm_objective(main_params, extra_params, state, sequence):
    '''Return the negative average log-likelihood of the LSTM over a sequence.

    Each element of ``sequence`` is fed as input and the model is scored on
    predicting the next element.

    Predictions are normalised with ``ypred - log(sum(exp(ypred)) + 2)``.
    Note the ``+ 2``: this is NOT a plain log-softmax. It is kept as-is,
    presumably to match a reference implementation (confirm before changing).

    If ``sequence`` has fewer than two elements the loop never runs, ``count``
    stays 0, and the final division raises ZeroDivisionError.
    '''
    log_likelihood = 0.0
    count = 0
    current_input = sequence[0]
    current_state = state
    for step in range(shape(sequence)[0] - 1):
        ypred, current_state = predict(
            main_params, extra_params, current_state, current_input
        )
        ynorm = ypred - tf.math.log(tf.reduce_sum(tf.exp(ypred)) + 2)
        target = sequence[step + 1]
        log_likelihood += tf.reduce_sum(target * ynorm)
        count += shape(target)[0]
        # Teacher forcing: the true next element becomes the next input.
        current_input = target

    return -log_likelihood / count
def predict(w, w2, state, x):
    '''Run one step of the stacked LSTM and return (output, new_state).

    Args:
        w: per-layer parameters laid out as [weight_0, bias_0, weight_1,
            bias_1, ...].
        w2: extra parameters. ``w2[0]`` scales the input, ``w2[1]`` scales the
            final hidden output, and ``w2[2]`` is added to it.
        state: per-layer states laid out as [hidden_0, cell_0, hidden_1,
            cell_1, ...].
        x: the input for this step.

    Returns:
        A tuple ``(output, new_state)``. ``new_state`` is the per-layer
        hidden/cell states stacked along axis 0, in the same layout as
        ``state``.
    '''
    layer_input = x * w2[0]
    stacked = []
    for layer in range(0, shape(state)[0], 2):
        hidden, cell = lstm_model(
            w[layer], w[layer + 1], state[layer], state[layer + 1], layer_input
        )
        stacked.extend((hidden, cell))
        # Each layer's hidden output feeds the next layer.
        layer_input = hidden

    new_state = tf.stack(stacked, 0)
    return (layer_input * w2[1] + w2[2], new_state)
def lstm_model(weight, bias, hidden, cell, inp):
    '''Apply one LSTM cell update and return (hidden, cell).

    The gate pre-activations are an ELEMENTWISE product (not a matrix
    product) of ``concat(inp, hidden, inp, hidden)`` with ``weight``, plus
    ``bias``. The result is split into four consecutive slices of length
    ``len(hidden)``: forget, input, output and candidate ("change").

    NOTE(review): the four equal slices only line up if ``inp`` has the same
    length as ``hidden``. That is assumed here, not checked — confirm.
    '''
    h = shape(hidden)[0]
    gates = tf.concat((inp, hidden, inp, hidden), 0) * weight + bias

    forget = tf.math.sigmoid(gates[:h])
    ingate = tf.math.sigmoid(gates[h:2 * h])
    outgate = tf.math.sigmoid(gates[2 * h:3 * h])
    change = tf.math.tanh(gates[3 * h:])

    new_cell = cell * forget + ingate * change
    new_hidden = outgate * tf.math.tanh(new_cell)
    return (new_hidden, new_cell)
def body(t, state, inp, total, count):
    '''One LSTM-objective step; returns the updated loop variables.

    The returned tuple ``(t + 1, state, next_input, total, count)`` has the
    same layout as the arguments, so this looks like a loop body (e.g. for
    ``tf.while_loop``). That is an assumption — confirm against the caller.

    NOTE(review): ``main_params``, ``extra_params`` and ``sequence`` are not
    parameters. They must be supplied by an enclosing or module scope that
    is not visible here.
    '''
    ypred, next_state = predict(main_params, extra_params, state, inp)
    # Same "+ 2" normalisation as lstm_objective.
    ynorm = ypred - tf.math.log(tf.reduce_sum(tf.exp(ypred)) + 2)
    target = sequence[t + 1]
    total += tf.reduce_sum(target * ynorm)
    count += shape(target)[0]
    return (t + 1, next_state, target, total, count)
def get_skinned_vertex_positions(
    pose_params,
    base_relatives,
    parents,
    inverse_base_absolutes,
    base_positions,
    weights,
    mirror_factor
):
    """Return posed vertex positions for the hand model.

    Each bone's posed absolute transform is composed with its inverse base
    absolute transform and applied to every row of ``base_positions``. The
    per-bone results are then blended with ``weights`` (summed over the bone
    axis, i.e. linear blend skinning). The first three columns are kept,
    presumably dropping a homogeneous coordinate (confirm), and the global
    transform from ``pose_params`` is applied.

    NOTE(review): ``mirror_factor`` is accepted but never used in this body.
    If mirrored models are supported, the mirroring step (likely a sign flip
    on one axis before the global transform) appears to be missing. Confirm
    the intended shape and semantics of ``mirror_factor`` with the caller.
    """
    posed_relatives = get_posed_relatives(pose_params, base_relatives)
    posed_absolutes = relatives_to_absolutes(posed_relatives, parents)
    bone_transforms = posed_absolutes @ inverse_base_absolutes

    # Row-vector convention: points @ T^T for each bone (perm swaps the last
    # two axes of the rank-3 transform stack).
    per_bone = base_positions @ tf.transpose(bone_transforms, perm=[0, 2, 1])
    # Append a trailing unit axis so the weights broadcast over coordinates.
    weight_shape = shape(weights) + [1]
    blended = tf.reduce_sum(per_bone * tf.reshape(weights, weight_shape), 0)

    return apply_global_transform(pose_params, blended[:, :3])