def gradient_calc(seq_len, input_size, hidden_size, batch_size, epsilon=None,
                  rand_scale=None, inp_bl=None):
    """Return a finite-difference estimate of the LSTM input gradient and the
    deltas neon's bprop produces for the same pseudo-loss.

    The pseudo-loss is ``sum(rand_scale * lstm_output)``.  bprop is fed
    ``rand_scale`` as the error signal; the numeric estimate uses central
    differences, perturbing each input element by +/- ``epsilon``.

    Args:
        seq_len, input_size, hidden_size, batch_size: LSTM problem size.
        epsilon: central-difference step.  Must be a number; the ``None``
            default would fail with a TypeError when perturbing.
        rand_scale: pseudo-loss weights; when omitted, drawn uniformly from
            [-1, 1) with the shape of the baseline output.
        inp_bl: baseline input of shape (input_size, seq_len * batch_size);
            standard-normal random when omitted.

    Returns:
        tuple (grads_est, deltas_neon).
    """
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size

    if inp_bl is None:
        inp_bl = np.random.randn(input_size, seq_len * batch_size)

    lstm = LSTM(hidden_size, Gaussian(), Tanh(), Logistic())
    baseline = lstm.be.array(np.copy(inp_bl))

    # baseline forward pass; its output only fixes the shape of rand_scale
    out_bl = lstm.fprop(baseline).get()
    if rand_scale is None:
        rand_scale = np.random.random(out_bl.shape) * 2.0 - 1.0

    # analytic side: back-propagate the loss weights as the errors
    # (copied to avoid any interaction with the caller's array)
    deltas_neon = lstm.bprop(lstm.be.array(np.copy(rand_scale))).get()

    grads_est = np.zeros(baseline.shape)
    inp_pert = inp_bl.copy()
    for idx in range(baseline.size):
        orig_val = inp_pert.flat[idx]
        outs = []
        # first the +epsilon run, then the -epsilon run
        for shift in (epsilon, -epsilon):
            inp_pert.flat[idx] = orig_val + shift
            reset_lstm(lstm)
            outs.append(lstm.fprop(lstm.be.array(inp_pert)).get())
        out_pos, out_neg = outs

        loss_pos = np.sum(rand_scale * out_pos)
        loss_neg = np.sum(rand_scale * out_neg)
        grads_est.flat[idx] = 0.5 * (loss_pos - loss_neg) / epsilon

        # restore the element before perturbing the next one
        inp_pert.flat[idx] = orig_val

    del lstm
    return (grads_est, deltas_neon)
def _fprop_perturbed(lstm, inp, flat_idx, value):
    """Write ``value`` into ``inp.flat[flat_idx]``, call ``reset_lstm(lstm)``
    and return the ``.get()`` of ``lstm.fprop`` on the modified input."""
    inp.flat[flat_idx] = value
    reset_lstm(lstm)
    return lstm.fprop(lstm.be.array(inp)).get()


def gradient_calc(seq_len, input_size, hidden_size, batch_size, epsilon=None,
                  rand_scale=None, inp_bl=None):
    """Numerically estimate d(loss)/d(input) for a neon LSTM and return it
    together with the input deltas computed by ``lstm.bprop``.

    loss = sum(rand_scale * lstm.fprop(input)); the estimate is the central
    difference 0.5 * (loss(x + eps) - loss(x - eps)) / eps for every element.

    Args:
        seq_len, input_size, hidden_size, batch_size: LSTM problem size.
        epsilon: perturbation size; a number is required (None cannot be
            added to an input element).
        rand_scale: loss weights, uniform in [-1, 1) if not provided.
        inp_bl: baseline input (input_size, seq_len * batch_size); drawn
            from a standard normal if not provided.

    Returns:
        (grads_est, deltas_neon)
    """
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size
    shape_in = (input_size, seq_len * batch_size)
    if inp_bl is None:
        inp_bl = np.random.randn(*shape_in)

    lstm = LSTM(hidden_size, Gaussian(), Tanh(), Logistic())
    dev_in = lstm.be.array(np.copy(inp_bl))

    baseline_out = lstm.fprop(dev_in).get()
    if rand_scale is None:
        rand_scale = np.random.random(baseline_out.shape) * 2.0 - 1.0

    # bprop with the loss weights as errors; copy keeps rand_scale untouched
    deltas_neon = lstm.bprop(lstm.be.array(np.copy(rand_scale))).get()

    grads_est = np.zeros(dev_in.shape)
    inp_pert = inp_bl.copy()
    n_elems = dev_in.size
    for flat_idx in range(n_elems):
        saved = inp_pert.flat[flat_idx]
        out_pos = _fprop_perturbed(lstm, inp_pert, flat_idx, saved + epsilon)
        out_neg = _fprop_perturbed(lstm, inp_pert, flat_idx, saved - epsilon)
        loss_pos, loss_neg = [np.sum(rand_scale * out)
                              for out in (out_pos, out_neg)]
        grads_est.flat[flat_idx] = 0.5 * (loss_pos - loss_neg) / epsilon
        inp_pert.flat[flat_idx] = saved

    del lstm
    return (grads_est, deltas_neon)
def test_biLSTM_fprop_rnn(backend_default, fargs):
    """Sanity-check BiLSTM fprop against a plain LSTM with tied weights.

    The BiLSTM's backward-direction weights are overwritten with its
    forward-direction (Glorot-initialized) weights, and a unidirectional LSTM
    receives the same weights.  Running the BiLSTM on a random sequence and
    on its time-reversed copy must then give mirrored forward/backward
    outputs, and the plain LSTM must match both the forward output on the
    original sequence and the backward output on the reversed one, per time
    step, within atol=1e-5.
    """
    seq_len, input_size, hidden_size, batch_size = fargs
    in_shape = (input_size, seq_len)
    out_shape = (hidden_size, seq_len)
    NervanaObject.be.bsz = batch_size

    # bi-directional LSTM under test
    glorot = GlorotUniform()
    bilstm = BiLSTM(hidden_size, gate_activation=Logistic(), activation=Tanh(),
                    init=glorot, reset_cells=True)
    bilstm.configure(in_shape)
    bilstm.prev_layer = True
    bilstm.allocate()

    # uni-directional reference LSTM
    glorot = GlorotUniform()
    rnn = LSTM(hidden_size, gate_activation=Logistic(), activation=Tanh(),
               init=glorot, reset_cells=True)
    rnn.configure(in_shape)
    rnn.prev_layer = True
    rnn.allocate()

    # tie weights: BiLSTM backward <- BiLSTM forward, then LSTM <- BiLSTM forward
    nout = hidden_size
    for dst, src in ((bilstm.W_input_b, bilstm.W_input_f),
                     (bilstm.W_recur_b, bilstm.W_recur_f),
                     (bilstm.b_b, bilstm.b_f)):
        dst[:] = src
    bilstm.dW[:] = 0
    for dst, src in ((rnn.W_input, bilstm.W_input_f),
                     (rnn.W_recur, bilstm.W_recur_f),
                     (rnn.b, bilstm.b_f)):
        dst[:] = src
    rnn.dW[:] = 0

    # random input and its time-reversed counterpart
    seq_in = np.random.random((input_size, seq_len * batch_size))
    seq_rev = con(list(reversed(get_steps(seq_in.copy(), in_shape))), axis=1)

    dev_lr = bilstm.be.array(seq_in)
    dev_rl = bilstm.be.array(seq_rev)
    dev_rnn = rnn.be.array(seq_in)

    # the first BiLSTM output is copied before the BiLSTM is run a second
    # time; h_buffer is cleared between the two runs
    out_lr = bilstm.fprop(dev_lr).get().copy()
    bilstm.h_buffer[:] = 0
    out_rl = bilstm.fprop(dev_rl).get()
    out_rnn = rnn.fprop(dev_rnn).get().copy()

    # per-time-step views: first nout rows = forward half, rest = backward half
    lr_fwd, lr_bwd, rl_fwd, rl_bwd = [get_steps(half, out_shape) for half in
                                      (out_lr[:nout], out_lr[nout:],
                                       out_rl[:nout], out_rl[nout:])]
    rnn_steps = get_steps(out_rnn, out_shape)

    tol = dict(rtol=0.0, atol=1.0e-5)
    for step_rnn, f_lr, b_lr, f_rl, b_rl in zip(rnn_steps, lr_fwd, lr_bwd,
                                                 reversed(rl_fwd),
                                                 reversed(rl_bwd)):
        assert allclose_with_out(f_lr, b_rl, **tol)
        assert allclose_with_out(b_lr, f_rl, **tol)
        assert allclose_with_out(step_rnn, f_lr, **tol)
        assert allclose_with_out(step_rnn, b_rl, **tol)
def test_beamsearch(backend_default):
    """
    Simulated beam search on a minibatch of 2 for 4 time steps.

    The LSTM layers are real, but the decoder's "softmax outputs" z are
    hard-coded as exp() of small integer scores instead of coming from the
    network.  There are 6 output tokens.  The test checks that the sequences
    produced by BeamSearch.beamsearch, as extracted by reformat_samples,
    equal the hand-computed best beams ex0 and ex1.
    """
    be = backend_default

    batch_size = 2
    be.bsz = batch_size
    time_steps = 4
    nout = 6
    num_beams = 3

    # encoder/decoder LSTMs are needed to build the Seq2Seq container; the
    # decoder's fprop is replaced below by one that ignores its inputs
    activation = Tanh()
    gate_activation = Logistic()
    init = Array(np.eye(nout))
    encoder, decoder = [LSTM(nout, init, activation=activation,
                             gate_activation=gate_activation, name=layer_name)
                        for layer_name in ("Enc", "Dec")]

    # One table of scores per decoder fprop call, consumed in order
    # (row 0 = example 0, row 1 = example 1).
    score_tables = [
        # t=0: top-3 tokens are 1, 4, 5 for both examples
        [[1, 6, 2, 1, 5, 5], [1, 5, 2, 2, 4, 7]],
        # t=1, given we picked 4: new winners 2, 3, 4
        [[1, 1, 2, 3, 2, 1], [2, 1, 2, 3, 2, 1]],
        # t=1, given we picked 5: new winners 0, 3, [5]; score 12
        [[4, 1, 2, 3, 1, 7], [2, 6, 5, 7, 2, 4]],
        # t=1, given we picked 1: new winners 0, [2], [5]; scores 12, 11
        [[4, 1, 6, 3, 1, 5], [1, 4, 6, 3, 2, 1]],
        # t=2, example 0 given (1, 5), score 11: 3, 4; scores 21, 20
        [[1, 1, 2, 10, 9, 1], [2, 10, 2, 3, 9, 1]],
        # t=2, example 0 given (5, 5), score 12: none selected from this beam
        [[4, 1, 2, 3, 1, 7], [2, 6, 5, 7, 2, 4]],
        # t=2, example 0 given (1, 2), score 12: 1; score 20
        [[4, 8, 6, 3, 1, 5], [8, 4, 6, 3, 1, 1]],
        # t=3, example 0 given (1, 5, 4), score 20: 1; score 30
        [[1, 10, 2, 1, 1, 1], [2, 1, 2, 3, 10, 1]],
        # t=3, example 0 given (1, 2, 1), score 20: 5; score 30
        [[4, 1, 2, 3, 1, 10], [2, 10, 5, 7, 2, 4]],
        # t=3, example 0 given (1, 5, 3), score 21: 4; score 31
        [[4, 8, 6, 3, 10, 5], [8, 4, 6, 3, 10, 1]],
    ]

    class DummyFProp():
        """
        Stand-in for ``decoder.fprop`` that returns the hard-coded z tables,
        one per call, in order.

        It is called inside BeamSearch's nested loop over time steps and live
        beams: one live beam at the first time step, then 3 live beams per
        step, for 4 time steps (1 + 3 * 3 = 10 calls).

        At t=0 the winners are tokens 1, 4, 5 (indexed by position) for both
        ex0 and ex1.  Hand trace from the original author:

        ex0:
            12, 13, 14                6+2=8  6+3=9   6+2=8
            40, 43, 45  with scores   5+4=9  5+3=8   5+7=12   three new winners 45, 52, 55
            50, 52, 55                5+4=9  5+6=11  5+5=10

        ex1: tokens 1 4 5 with scores 5 4 7 give the three winners 1, 4, 5;
        continue (just taking the 3 in order, no sorting)
            10 12 13 14 (not unique!)   5+2=7   5+2=7  5+3=8
            41 42 43      with scores   4+6=10  4+5=9  4+7=11   winners 43 51 52
            51 52 53                    7+4=11  7+6=13 7+3=10   scores  11 11 13
        continue from the three winners 43 51 52
            431 433 434               11+10=21 11+3=14 11+9=20
            511 512 513  with scores  11+6=17  11+5=16 11+7=18  winners 431 434 520
            520 521 522               13+8=21  13+4=17 13+6=19  scores  21  20  21
        continue from three winners 431 511 513 (going along beams, the matches in a beam)
            4310 4312 4313 4314          21+2=23  21+2=23 21+3=24 21+10=31 (not unique!)
            4341 4342 4343  with scores  20+10=30 20+5=25 20+7=27  winners 4314 4341 5204
            5200 5202 5204               21+8=29  21+6=27 21+10=31 scores  31   30   31
        overall winners are 4314 4341 5204

        NOTE(review): this trace is not obviously self-consistent (winners
        431 434 520 vs. "continue from 431 511 513") nor aligned with the
        asserted ex0/ex1 below; the assertions are authoritative.
        """

        def __init__(self):
            self.i = -1
            self.z_list = [be.array(np.exp(np.array(table))).T
                           for table in score_tables]

        def fprop(self, z, inference=True, init_state=None):
            self.i += 1
            return self.z_list[self.i]

    # final decoder state: zeros shaped like the decoder's last hidden state
    def final_state():
        return be.zeros_like(decoder.h[-1])

    class InObj(NervanaObject):
        def __init__(self):
            self.shape = (nout, time_steps)
            self.decoder_shape = (nout, time_steps)

    decoder.fprop = DummyFProp().fprop
    layers = Seq2Seq([encoder, decoder], decoder_connections=[0])
    layers.decoder._recurrent[0].final_state = final_state

    in_obj = InObj()
    layers.configure(in_obj)
    layers.allocate()
    layers.allocate_deltas(None)
    beamsearch = BeamSearch(layers)
    inputs = be.iobuf(in_obj.shape)
    beamsearch.beamsearch(inputs, num_beams=num_beams)

    expected = (np.array([[1, 5, 4, 1], [1, 2, 1, 5], [1, 5, 3, 4]]),
                np.array([[5, 1, 4, 4], [5, 1, 1, 1], [5, 2, 0, 4]]))

    # extract all candidates and compare example by example
    examples = reformat_samples(beamsearch, num_beams, batch_size)
    for ex_idx, want in enumerate(expected):
        assert allclose_with_out(examples[ex_idx], want)
def check_lstm(seq_len, input_size, hidden_size, batch_size, init_func,
               inp_moms=(0.0, 1.0)):
    """Compare a neon LSTM's fprop and bprop against the numpy RefLSTM.

    Args:
        seq_len, input_size, hidden_size, batch_size: problem size.
        init_func: initializer for the neon LSTM parameters.  The reference
            model is then loaded with the same weights and bias.
        inp_moms: ``(offset, scale)`` of the random input.  The input is drawn
            uniformly from ``[offset, offset + scale)``.  Any indexable pair
            works; a tuple default avoids a shared mutable default.

    Raises:
        AssertionError: if an fprop buffer (IFOG gates, cell states, hidden
            states) or a bprop result (dW_recur, dW_input, db, output deltas)
            differs from the reference by more than atol=1e-5.
    """
    input_shape = (input_size, seq_len * batch_size)
    hidden_shape = (hidden_size, seq_len * batch_size)
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size

    # neon LSTM
    lstm = LSTM(hidden_size, init_func, activation=Tanh(),
                gate_activation=Logistic())

    inp = np.random.rand(*input_shape) * inp_moms[1] + inp_moms[0]
    inpa = lstm.be.array(inp)
    # run neon fprop
    lstm.fprop(inpa)

    # reference numpy LSTM loaded with the neon parameters: row 0 holds the
    # bias, the next input_size rows the input weights, the rest the
    # recurrent weights
    lstm_ref = RefLSTM()
    WLSTM = lstm_ref.init(input_size, hidden_size)
    WLSTM[0, :] = lstm.b.get().T
    WLSTM[1:input_size + 1, :] = lstm.W_input.get().T
    WLSTM[input_size + 1:] = lstm.W_recur.get().T

    # the reference takes (seq_len, batch_size, features) input
    inp_ref = inp.copy().T.reshape(seq_len, batch_size, input_size)
    Hout_ref, _, _, batch_cache = lstm_ref.forward(inp_ref, WLSTM)

    # reference results back to neon's (features, seq_len * batch_size) layout
    Hout_ref = Hout_ref.reshape(seq_len * batch_size, hidden_size).T
    IFOGf_ref = batch_cache['IFOGf'].reshape(seq_len * batch_size,
                                             hidden_size * 4).T
    Ct_ref = batch_cache['Ct'].reshape(seq_len * batch_size, hidden_size).T

    # compare fprop results; asserted so a mismatch fails the test instead
    # of only being printed before "fprop is verified"
    print('====Verifying IFOG====')
    assert allclose_with_out(lstm.ifog_buffer.get(), IFOGf_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying cell states====')
    assert allclose_with_out(lstm.c_act_buffer.get(), Ct_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying hidden states====')
    assert allclose_with_out(lstm.h_buffer.get(), Hout_ref,
                             rtol=0.0, atol=1.0e-5)
    print('fprop is verified')

    # now test the bprop with a random deltas tensor
    deltas = np.random.randn(*hidden_shape)
    lstm.bprop(lstm.be.array(deltas))

    # grab the delta W from the gradient buffers
    dWinput_neon = lstm.dW_input.get()
    dWrecur_neon = lstm.dW_recur.get()
    db_neon = lstm.db.get()

    deltas_ref = deltas.copy().T.reshape(seq_len, batch_size, hidden_size)
    dX_ref, dWLSTM_ref, _, _ = lstm_ref.backward(deltas_ref, batch_cache)
    dWrecur_ref = dWLSTM_ref[-hidden_size:, :]
    dWinput_ref = dWLSTM_ref[1:input_size + 1, :]
    db_ref = dWLSTM_ref[0, :]
    dX_ref = dX_ref.reshape(seq_len * batch_size, input_size).T

    # compare bprop results
    print('Making sure neon LSTM match numpy LSTM in bprop')
    print('====Verifying update on W_recur====')
    assert allclose_with_out(dWrecur_neon, dWrecur_ref.T,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying update on W_input====')
    assert allclose_with_out(dWinput_neon, dWinput_ref.T,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying update on bias====')
    assert allclose_with_out(db_neon.flatten(), db_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying output delta====')
    assert allclose_with_out(lstm.out_deltas_buffer.get(), dX_ref,
                             rtol=0.0, atol=1.0e-5)
    print('bprop is verified')
def check_lstm(seq_len, input_size, hidden_size, batch_size, init_func,
               inp_moms=(0.0, 1.0)):
    """Compare a neon LSTM's fprop and bprop against the numpy RefLSTM.

    Args:
        seq_len, input_size, hidden_size, batch_size: problem size.
        init_func: initializer for the neon LSTM parameters.  The reference
            model is then loaded with the same weights and bias.
        inp_moms: ``(offset, scale)`` of the random input.  The input is drawn
            uniformly from ``[offset, offset + scale)``.  Any indexable pair
            works; a tuple default avoids a shared mutable default.

    Raises:
        AssertionError: if an fprop buffer (IFOG gates, cell states, hidden
            states) or a bprop result (dW_recur, dW_input, db, output deltas)
            differs from the reference by more than atol=1e-5.
    """
    input_shape = (input_size, seq_len * batch_size)
    hidden_shape = (hidden_size, seq_len * batch_size)
    NervanaObject.be.bsz = NervanaObject.be.batch_size = batch_size

    # neon LSTM
    lstm = LSTM(hidden_size, init_func, activation=Tanh(),
                gate_activation=Logistic())

    inp = np.random.rand(*input_shape) * inp_moms[1] + inp_moms[0]
    inpa = lstm.be.array(inp)

    # configure/allocate, then run neon fprop
    lstm.configure((input_size, seq_len))
    lstm.prev_layer = True  # Hack to force allocating a delta buffer
    lstm.allocate()
    lstm.set_deltas([lstm.be.iobuf(lstm.in_shape)])
    lstm.fprop(inpa)

    # reference numpy LSTM loaded with the neon parameters: row 0 holds the
    # bias, the next input_size rows the input weights, the rest the
    # recurrent weights
    lstm_ref = RefLSTM()
    WLSTM = lstm_ref.init(input_size, hidden_size)
    WLSTM[0, :] = lstm.b.get().T
    WLSTM[1:input_size + 1, :] = lstm.W_input.get().T
    WLSTM[input_size + 1:] = lstm.W_recur.get().T

    # the reference takes (seq_len, batch_size, features) input
    inp_ref = inp.copy().T.reshape(seq_len, batch_size, input_size)
    Hout_ref, _, _, batch_cache = lstm_ref.forward(inp_ref, WLSTM)

    # reference results back to neon's (features, seq_len * batch_size) layout
    Hout_ref = Hout_ref.reshape(seq_len * batch_size, hidden_size).T
    IFOGf_ref = batch_cache['IFOGf'].reshape(seq_len * batch_size,
                                             hidden_size * 4).T
    Ct_ref = batch_cache['Ct'].reshape(seq_len * batch_size, hidden_size).T

    # compare fprop results; asserted so a mismatch fails the test instead
    # of only being printed before "fprop is verified"
    print('====Verifying IFOG====')
    assert allclose_with_out(lstm.ifog_buffer.get(), IFOGf_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying cell states====')
    assert allclose_with_out(lstm.c_act_buffer.get(), Ct_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying hidden states====')
    assert allclose_with_out(lstm.outputs.get(), Hout_ref,
                             rtol=0.0, atol=1.0e-5)
    print('fprop is verified')

    # now test the bprop with a random deltas tensor
    deltas = np.random.randn(*hidden_shape)
    lstm.bprop(lstm.be.array(deltas))

    # grab the delta W from the gradient buffers
    dWinput_neon = lstm.dW_input.get()
    dWrecur_neon = lstm.dW_recur.get()
    db_neon = lstm.db.get()

    deltas_ref = deltas.copy().T.reshape(seq_len, batch_size, hidden_size)
    dX_ref, dWLSTM_ref, _, _ = lstm_ref.backward(deltas_ref, batch_cache)
    dWrecur_ref = dWLSTM_ref[-hidden_size:, :]
    dWinput_ref = dWLSTM_ref[1:input_size + 1, :]
    db_ref = dWLSTM_ref[0, :]
    dX_ref = dX_ref.reshape(seq_len * batch_size, input_size).T

    # compare bprop results
    print('Making sure neon LSTM match numpy LSTM in bprop')
    print('====Verifying update on W_recur====')
    assert allclose_with_out(dWrecur_neon, dWrecur_ref.T,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying update on W_input====')
    assert allclose_with_out(dWinput_neon, dWinput_ref.T,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying update on bias====')
    assert allclose_with_out(db_neon.flatten(), db_ref,
                             rtol=0.0, atol=1.0e-5)
    print('====Verifying output delta====')
    assert allclose_with_out(lstm.out_deltas_buffer.get(), dX_ref,
                             rtol=0.0, atol=1.0e-5)
    print('bprop is verified')
def test_beamsearch(backend_default):
    """
    Simulated beam search on a minibatch of 2 for 4 time steps.

    The LSTM layers are real, but the decoder's "softmax outputs" z are
    hard-coded as exp() of small integer scores instead of coming from the
    network.  There are 6 output tokens.  The test checks that the sequences
    produced by BeamSearch.beamsearch, as extracted by reformat_samples,
    equal the hand-computed best beams ex0 and ex1.
    """
    be = backend_default

    batch_size = 2
    be.bsz = batch_size
    time_steps = 4
    nout = 6
    num_beams = 3

    # encoder/decoder LSTMs are needed to build the Seq2Seq container; the
    # decoder's fprop is replaced below by one that ignores its inputs
    activation = Tanh()
    gate_activation = Logistic()
    init = Array(np.eye(nout))
    encoder, decoder = [LSTM(nout, init, activation=activation,
                             gate_activation=gate_activation, name=layer_name)
                        for layer_name in ("Enc", "Dec")]

    # One table of scores per decoder fprop call, consumed in order
    # (row 0 = example 0, row 1 = example 1).
    score_tables = [
        # t=0: top-3 tokens are 1, 4, 5 for both examples
        [[1, 6, 2, 1, 5, 5], [1, 5, 2, 2, 4, 7]],
        # t=1, given we picked 4: new winners 2, 3, 4
        [[1, 1, 2, 3, 2, 1], [2, 1, 2, 3, 2, 1]],
        # t=1, given we picked 5: new winners 0, 3, [5]; score 12
        [[4, 1, 2, 3, 1, 7], [2, 6, 5, 7, 2, 4]],
        # t=1, given we picked 1: new winners 0, [2], [5]; scores 12, 11
        [[4, 1, 6, 3, 1, 5], [1, 4, 6, 3, 2, 1]],
        # t=2, example 0 given (1, 5), score 11: 3, 4; scores 21, 20
        [[1, 1, 2, 10, 9, 1], [2, 10, 2, 3, 9, 1]],
        # t=2, example 0 given (5, 5), score 12: none selected from this beam
        [[4, 1, 2, 3, 1, 7], [2, 6, 5, 7, 2, 4]],
        # t=2, example 0 given (1, 2), score 12: 1; score 20
        [[4, 8, 6, 3, 1, 5], [8, 4, 6, 3, 1, 1]],
        # t=3, example 0 given (1, 5, 4), score 20: 1; score 30
        [[1, 10, 2, 1, 1, 1], [2, 1, 2, 3, 10, 1]],
        # t=3, example 0 given (1, 2, 1), score 20: 5; score 30
        [[4, 1, 2, 3, 1, 10], [2, 10, 5, 7, 2, 4]],
        # t=3, example 0 given (1, 5, 3), score 21: 4; score 31
        [[4, 8, 6, 3, 10, 5], [8, 4, 6, 3, 10, 1]],
    ]

    class DummyFProp():
        """
        Stand-in for ``decoder.fprop`` that returns the hard-coded z tables,
        one per call, in order.

        It is called inside BeamSearch's nested loop over time steps and live
        beams: one live beam at the first time step, then 3 live beams per
        step, for 4 time steps (1 + 3 * 3 = 10 calls).

        At t=0 the winners are tokens 1, 4, 5 (indexed by position) for both
        ex0 and ex1.  Hand trace from the original author:

        ex0:
            12, 13, 14                6+2=8  6+3=9   6+2=8
            40, 43, 45  with scores   5+4=9  5+3=8   5+7=12   three new winners 45, 52, 55
            50, 52, 55                5+4=9  5+6=11  5+5=10

        ex1: tokens 1 4 5 with scores 5 4 7 give the three winners 1, 4, 5;
        continue (just taking the 3 in order, no sorting)
            10 12 13 14 (not unique!)   5+2=7   5+2=7  5+3=8
            41 42 43      with scores   4+6=10  4+5=9  4+7=11   winners 43 51 52
            51 52 53                    7+4=11  7+6=13 7+3=10   scores  11 11 13
        continue from the three winners 43 51 52
            431 433 434               11+10=21 11+3=14 11+9=20
            511 512 513  with scores  11+6=17  11+5=16 11+7=18  winners 431 434 520
            520 521 522               13+8=21  13+4=17 13+6=19  scores  21  20  21
        continue from three winners 431 511 513 (going along beams, the matches in a beam)
            4310 4312 4313 4314          21+2=23  21+2=23 21+3=24 21+10=31 (not unique!)
            4341 4342 4343  with scores  20+10=30 20+5=25 20+7=27  winners 4314 4341 5204
            5200 5202 5204               21+8=29  21+6=27 21+10=31 scores  31   30   31
        overall winners are 4314 4341 5204

        NOTE(review): this trace is not obviously self-consistent (winners
        431 434 520 vs. "continue from 431 511 513") nor aligned with the
        asserted ex0/ex1 below; the assertions are authoritative.
        """

        def __init__(self):
            self.i = -1
            self.z_list = [be.array(np.exp(np.array(table))).T
                           for table in score_tables]

        def fprop(self, z, inference=True, init_state=None):
            self.i += 1
            return self.z_list[self.i]

    # final decoder state: zeros shaped like the decoder's last hidden state
    def final_state():
        return be.zeros_like(decoder.h[-1])

    class InObj(NervanaObject):
        def __init__(self):
            self.shape = (nout, time_steps)
            self.decoder_shape = (nout, time_steps)

    decoder.fprop = DummyFProp().fprop
    layers = Seq2Seq([encoder, decoder], decoder_connections=[0])
    layers.decoder._recurrent[0].final_state = final_state

    in_obj = InObj()
    layers.configure(in_obj)
    layers.allocate()
    layers.allocate_deltas(None)
    beamsearch = BeamSearch(layers)
    inputs = be.iobuf(in_obj.shape)
    beamsearch.beamsearch(inputs, num_beams=num_beams)

    expected = (np.array([[1, 5, 4, 1], [1, 2, 1, 5], [1, 5, 3, 4]]),
                np.array([[5, 1, 4, 4], [5, 1, 1, 1], [5, 2, 0, 4]]))

    # extract all candidates and compare example by example
    examples = reformat_samples(beamsearch, num_beams, batch_size)
    for ex_idx, want in enumerate(expected):
        assert allclose_with_out(examples[ex_idx], want)