def gradient_check_ref(seq_len, input_size, hidden_size, batch_size,
                       epsilon=1.0e-5, dtypeu=np.float64, threshold=1e-4):
    """Sanity-check the reference LSTM's input gradient by finite differences.

    The reference LSTM is run forward on a random input, its outputs are
    weighted by a random matrix (so the implied scalar loss is
    ``np.sum(rand_scale * Hout)``), and ``backward`` is asked for the
    gradients.  Every input element is then perturbed by ``+/- epsilon`` and
    the central-difference estimate of the loss gradient is compared with the
    input gradient returned by ``backward``.

    Only the gradient with respect to the input is compared; the weight and
    initial-state gradients returned by ``backward`` are not checked here.

    NOTE(review): another definition of ``gradient_check_ref`` appears later
    in this file and will shadow this one at import time.

    Args:
        seq_len: First dimension of the generated input.
        input_size: Second dimension of the generated input; also passed to
            ``RefLSTM.init``.
        hidden_size: Passed to ``RefLSTM.init``.
        batch_size: Third dimension of the generated input; also written to
            the backend's ``bsz`` and ``batch_size`` attributes.
        epsilon: Size of the perturbation applied to each input element.
        dtypeu: Type the input and the output weighting are converted to.
        threshold: Relative tolerance handed to ``allclose_with_out``
            (absolute tolerance is fixed at 0.0).

    Raises:
        AssertionError: If ``allclose_with_out`` reports a mismatch between
            the estimated and back-propagated input gradients.
    """
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size

    input_shape = (seq_len, input_size, batch_size)

    # NOTE(review): the sparse draw is thrown away -- the dense draw on the
    # next statement replaces it.  The call is retained so the sequence of
    # calls made by this check stays exactly as before.
    _sparse_inp, _nz_inds = sparse_rand(input_shape, frac=1.0 / input_size)
    dense_inp = np.random.randn(*input_shape)

    # Reference code expects the last two axes in the opposite order to neon.
    ref_inp = dense_inp.swapaxes(1, 2).astype(dtypeu)

    # Build the reference LSTM; ``init`` is only used for the weight shape.
    ref_lstm = RefLSTM()
    init_weights = ref_lstm.init(input_size, hidden_size).astype(dtypeu)

    # Weights actually used: plain randn draw (not cast to ``dtypeu``).
    weights = np.random.randn(*init_weights.shape)

    base_out, _, _, fprop_cache = ref_lstm.forward(ref_inp, weights)

    # Random weighting in [-1, 1) applied to the outputs; it doubles as the
    # incoming deltas for bprop because loss = sum(out_scale * Hout).
    out_scale = dtypeu(np.random.random(base_out.shape) * 2.0 - 1.0)

    bprop_dx, _, _, _ = ref_lstm.backward(out_scale, fprop_cache)

    numeric_dx = np.zeros(bprop_dx.shape)
    work_inp = ref_inp.copy()

    for flat_idx in range(ref_inp.size):
        original = work_inp.flat[flat_idx]

        # fprop with this one element nudged up, then nudged down
        work_inp.flat[flat_idx] = original + epsilon
        out_plus, _, _, _ = ref_lstm.forward(work_inp, weights)

        work_inp.flat[flat_idx] = original - epsilon
        out_minus, _, _, _ = ref_lstm.forward(work_inp, weights)

        # central-difference estimate of d(loss)/d(input element)
        loss_plus = np.sum(out_scale * out_plus)
        loss_minus = np.sum(out_scale * out_minus)
        numeric_dx.flat[flat_idx] = 0.5 * (loss_plus - loss_minus) / epsilon

        # put the element back before moving on
        work_inp.flat[flat_idx] = original

    # estimates must agree with bprop to within the relative threshold
    assert allclose_with_out(numeric_dx, bprop_dx, rtol=threshold, atol=0.0)
def gradient_check_ref(seq_len, input_size, hidden_size, batch_size,
                       epsilon=1.0e-5, dtypeu=np.float64, threshold=1e-4):
    """Verify the reference LSTM's bprop input deltas numerically.

    Steps performed:

    1. Set the backend batch size and draw a random input of shape
       ``(seq_len, input_size, batch_size)``, then swap its last two axes
       for the reference code.
    2. Run ``RefLSTM.forward`` / ``RefLSTM.backward`` once, using a random
       matrix as the output deltas (equivalent to the scalar loss
       ``np.sum(rand_scale * Hout)``).
    3. For each input element, re-run ``forward`` with the element shifted
       by ``+epsilon`` and ``-epsilon`` and form the central-difference
       gradient estimate.
    4. Assert that the estimates match the back-propagated input gradient.

    Weight gradients are produced by ``backward`` but are not compared.

    NOTE(review): this definition reuses the name of an earlier
    ``gradient_check_ref`` in this file and replaces it at import time.

    Args:
        seq_len: Number of entries along the input's first axis.
        input_size: Number of entries along the input's second axis; also
            given to ``RefLSTM.init``.
        hidden_size: Given to ``RefLSTM.init``.
        batch_size: Number of entries along the input's third axis; also
            stored on the backend as ``bsz`` and ``batch_size``.
        epsilon: Finite-difference step.
        dtypeu: Type used for the input and for the random output scaling.
        threshold: ``rtol`` for ``allclose_with_out``; ``atol`` is 0.0.

    Raises:
        AssertionError: When the numerical and analytical input gradients
            disagree according to ``allclose_with_out``.
    """
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size

    shape_in = (seq_len, input_size, batch_size)

    # NOTE(review): result of sparse_rand is unused (overwritten by the
    # dense randn input below); the call itself is preserved as-is so the
    # order of calls in this check is unchanged.
    _discarded_inp, _discarded_inds = sparse_rand(shape_in,
                                                  frac=1.0 / shape_in[1])
    x_neon = np.random.randn(*shape_in)

    # neon layout -> reference-code layout (axes 1 and 2 exchanged)
    x_ref = x_neon.swapaxes(1, 2).astype(dtypeu)

    lstm = RefLSTM()
    w_template = lstm.init(input_size, hidden_size).astype(dtypeu)
    # Only the template's shape survives; values are redrawn with randn and
    # are left in randn's native dtype.
    w_lstm = np.random.randn(*w_template.shape)

    h_out, _, _, cache = lstm.forward(x_ref, w_lstm)

    # Random per-output scaling; serves as the deltas fed into bprop.
    deltas = np.random.random(h_out.shape) * 2.0 - 1.0
    deltas = dtypeu(deltas)

    dx_bprop, _, _, _ = lstm.backward(deltas, cache)

    dx_numeric = np.zeros(dx_bprop.shape)
    x_work = x_ref.copy()

    n_elements = x_ref.size
    for k in range(n_elements):
        saved = x_work.flat[k]

        # forward pass at x + epsilon
        x_work.flat[k] = saved + epsilon
        h_pos, _, _, _ = lstm.forward(x_work, w_lstm)

        # forward pass at x - epsilon
        x_work.flat[k] = saved - epsilon
        h_neg, _, _, _ = lstm.forward(x_work, w_lstm)

        # losses for the two perturbed runs and their central difference
        pos_loss = np.sum(deltas * h_pos)
        neg_loss = np.sum(deltas * h_neg)
        dx_numeric.flat[k] = 0.5 * (pos_loss - neg_loss) / epsilon

        # restore the unperturbed value
        x_work.flat[k] = saved

    # relative-tolerance comparison against the bprop result
    assert allclose_with_out(dx_numeric, dx_bprop, rtol=threshold, atol=0.0)