def train_outputs(self, in_obj):
    """
    Build the lookup-table graph for a tensor of lookup indices.

    Arguments:
        in_obj (Tensor): object that provides the lookup indices

    Returns:
        The ``ng.lookuptable`` result after ``ng.unflatten``, reordered
        to ``self.o_axes`` with ``ng.axes_with_order``.
    """
    # The axis whose short name is 'time' is tagged with the time role and
    # flagged as recurrent before the role-based reordering below.
    time_axis = in_obj.axes.find_by_short_name('time')[0]
    time_axis.add_role(ar.time)
    time_axis.is_recurrent = True

    role_ordered = ng.axes_with_role_order(in_obj, self.role_order)
    flat_idx = ng.flatten(role_ordered)
    idx_axes = flat_idx.axes

    # Table axes: 'V' has length vocab_size, 'F' has length embed_dim.
    self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
    self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')
    self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])

    # Axes of the raw lookup result (flattened index axes, then F) and of
    # the value handed back to the caller (F, then the axes wrapped inside
    # the first flattened index axis).
    self.lut_o_axes = idx_axes + ng.make_axes([self.lut_f_axis])
    self.o_axes = ng.make_axes([self.lut_f_axis]) + idx_axes[0].axes

    initial_table = self.lut_init(self.w_axes, self.lut_v_axis, self.pad_idx)
    self.W = ng.variable(axes=self.w_axes,
                         initial_value=initial_table).named('W')

    looked_up = ng.lookuptable(
        self.W,
        flat_idx,
        self.lut_o_axes,
        update=self.update,
        pad_idx=self.pad_idx,
    )
    return ng.axes_with_order(ng.unflatten(looked_up), self.o_axes)
def test_lut(lut_args):
    """
    Check the lookup-table fprop and update (bprop) ops against the
    reference implementations.

    Arguments:
        lut_args: 5-tuple unpacked as
            (vocab_size, embed_dim, bsz, seq_len, mem_size)
    """
    pad_idx = 0
    # Both comparisons use the same absolute-only tolerance.
    tolerance = dict(rtol=0.0, atol=1.0e-5)

    with ExecutorFactory() as ex:
        vocab_size, embed_dim, batch_size, seq_len, mem_size = lut_args

        vocab_axis = ng.make_axis(vocab_size)
        feature_axis = ng.make_axis(embed_dim)
        mem_axis = ng.make_axis(mem_size)
        ax.N.length = batch_size
        ax.REC.length = seq_len

        # Multi-axis input to LUT
        idx_axes = ng.make_axes([mem_axis, ax.REC, ax.N])
        table_axes = ng.make_axes([vocab_axis, feature_axis])

        lut = ng.placeholder(table_axes)
        idx = ng.placeholder(idx_axes)
        idx_flat = ng.flatten(idx)
        out_axes = idx_flat.axes | ng.make_axes([feature_axis])

        # fprop graph
        lut_out_ng = ng.lookuptable(lut, idx_flat, out_axes, pad_idx=pad_idx)
        run_fprop = ex.executor(lut_out_ng, lut, idx)

        # bprop graph: the update op is driven by an error placeholder
        update_error = ng.placeholder(out_axes)
        update_out_ng = lookuptable_update(update_error, lut, idx, lut_out_ng)
        run_update = ex.executor(update_out_ng, update_error, lut, idx)

        # Draw concrete inputs and run the forward graph.
        lut_value = rng.uniform(-1, 1, lut.axes)
        idx_value = rng.random_integers(0, vocab_size - 1, idx.axes)
        fprop_actual = run_fprop(lut_value, idx_value).copy()

        # fprop vs. reference
        fprop_expected = lut_fprop_ref(lut_value, idx_value)
        ng.testing.assert_allclose(fprop_actual, fprop_expected, **tolerance)

        # Draw a concrete delta and run the update op.
        update_value = rng.uniform(-1, 1, update_error.axes)
        update_actual = run_update(update_value, lut_value, idx_value).copy()

        # bprop (update) vs. reference
        update_expected = lut_update_ref(update_value, lut_value, idx_value,
                                         pad_idx=pad_idx)
        ng.testing.assert_allclose(update_actual, update_expected, **tolerance)
def test_lut(lut_args):
    """
    Exercise ``ng.lookuptable`` (fprop) and ``lookuptable_update`` (bprop)
    and compare both results with ``lut_fprop_ref`` / ``lut_update_ref``.

    Arguments:
        lut_args: sequence of five values, in order: vocab_size, embed_dim,
            bsz, seq_len, mem_size
    """
    pad_idx = 0
    with ExecutorFactory() as ex:
        vocab_size, embed_dim, bsz, seq_len, mem_size = lut_args

        # Fresh axes for the table and the extra index dimension; the
        # shared batch / recurrent axes are resized in place.
        ax_v = ng.make_axis(vocab_size)
        ax_f = ng.make_axis(embed_dim)
        ax_m = ng.make_axis(mem_size)
        ax.N.length = bsz
        ax.REC.length = seq_len

        # Multi-axis input to LUT
        ax_idx = ng.make_axes([ax_m, ax.REC, ax.N])
        ax_lut = ng.make_axes([ax_v, ax_f])

        lut = ng.placeholder(ax_lut)
        idx = ng.placeholder(ax_idx)
        flat_idx = ng.flatten(idx)
        ax_out = flat_idx.axes | ng.make_axes([ax_f])

        # ---- graph construction -------------------------------------
        # fprop
        fprop_op = ng.lookuptable(lut, flat_idx, ax_out, pad_idx=pad_idx)
        fprop_fun = ex.executor(fprop_op, lut, idx)

        # bprop
        update_error = ng.placeholder(ax_out)
        update_op = lookuptable_update(update_error, lut, idx, fprop_op)
        update_fun = ex.executor(update_op, update_error, lut, idx)

        # ---- forward pass -------------------------------------------
        lut_value = rng.uniform(-1, 1, lut.axes)
        idx_value = rng.random_integers(0, vocab_size - 1, idx.axes)
        got_fprop = fprop_fun(lut_value, idx_value).copy()
        want_fprop = lut_fprop_ref(lut_value, idx_value)
        ng.testing.assert_allclose(
            got_fprop, want_fprop, rtol=0.0, atol=1.0e-5)

        # ---- update pass --------------------------------------------
        update_value = rng.uniform(-1, 1, update_error.axes)
        got_update = update_fun(update_value, lut_value, idx_value).copy()
        want_update = lut_update_ref(
            update_value, lut_value, idx_value, pad_idx=pad_idx)
        ng.testing.assert_allclose(
            got_update, want_update, rtol=0.0, atol=1.0e-5)
def __call__(self, in_obj, **kwargs):
    """
    Apply the lookup table to a tensor of lookup indices.

    Arguments:
        in_obj (Tensor): object that provides the lookup indices
        **kwargs: accepted but not read by this method

    Returns:
        The ``ng.lookuptable`` result after ``ng.unflatten`` and
        ``ng.map_roles(..., self.axes_map)``, reordered to ``self.o_axes``.
    """
    LABELS = {"weight": "weight", "bias": "bias"}

    # Put the recurrent axis first and the batch axis second, then flatten.
    rec_batch_axes = ng.make_axes([in_obj.axes.recurrent_axis(),
                                   in_obj.axes.batch_axis()])
    flat_idx = ng.flatten(ng.axes_with_order(in_obj, rec_batch_axes))
    idx_axes = flat_idx.axes

    # label lut_v_axis as shadow axis for initializers ... once #1158 is
    # in, shadow axis will do more than just determine fan in/out for
    # initializers.
    self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
    self.axes_map = shadow_axes_map([self.lut_v_axis])
    shadow_axes = list(self.axes_map.values())
    self.lut_v_axis = shadow_axes[0]

    self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')

    self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
    self.lut_o_axes = idx_axes | ng.make_axes([self.lut_f_axis])
    self.o_axes = ng.make_axes([self.lut_f_axis]) | idx_axes[0].axes

    # The weight variable is only created while the layer reports itself
    # as not yet initialized.
    if not self.initialized:
        initial_table = self.lut_init(self.w_axes, self.lut_v_axis,
                                      self.pad_idx)
        weight = ng.variable(
            axes=self.w_axes,
            initial_value=initial_table,
            metadata={"label": LABELS["weight"]},
        )
        self.W = weight.named('LutW')

    looked_up = ng.lookuptable(
        self.W,
        flat_idx,
        self.lut_o_axes,
        update=self.update,
        pad_idx=self.pad_idx,
    )
    remapped = ng.map_roles(ng.unflatten(looked_up), self.axes_map)
    return ng.axes_with_order(remapped, self.o_axes)
def __call__(self, in_obj, **kwargs):
    """
    Apply the lookup table to a tensor of lookup indices.

    Arguments:
        in_obj (Tensor): object that provides the lookup indices
        **kwargs: accepted but not read by this method

    Returns:
        The ``ng.lookuptable`` result after ``ng.unflatten`` and
        ``ng.map_roles(..., self.axes_map)``.
    """
    flat_idx = ng.flatten(in_obj)
    idx_axes = flat_idx.axes

    # label lut_v_axis as shadow axis for initializers ... once #1158 is
    # in, shadow axis will do more than just determine fan in/out for
    # initializers.
    self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
    self.axes_map = shadow_axes_map([self.lut_v_axis])
    shadow_axes = list(self.axes_map.values())
    self.lut_v_axis = shadow_axes[0]

    self.lut_f_axis = ng.make_axis(self.embed_dim).named('F')

    self.w_axes = ng.make_axes([self.lut_v_axis, self.lut_f_axis])
    self.lut_o_axes = idx_axes | ng.make_axes([self.lut_f_axis])
    # Stored on the instance; not used for the return value here.
    self.o_axes = ng.make_axes([self.lut_f_axis]) | idx_axes[0].axes

    # The weight variable is only created while the layer reports itself
    # as not yet initialized.
    if not self.initialized:
        initial_table = self.lut_init(self.w_axes, self.lut_v_axis,
                                      self.pad_idx)
        weight = ng.variable(
            axes=self.w_axes,
            initial_value=initial_table,
            metadata={"label": LABELS["weight"]},
        )
        self.W = weight.named('LutW')

    looked_up = ng.lookuptable(
        self.W,
        flat_idx,
        self.lut_o_axes,
        update=self.update,
        pad_idx=self.pad_idx,
    )
    return ng.map_roles(ng.unflatten(looked_up), self.axes_map)