示例#1
0
文件: layer.py 项目: kkasravi/ngraph
    def train_outputs(self, in_obj):
        """
        Build the lookup-table graph for ``in_obj`` and return its output.

        Arguments:
            in_obj (Tensor): object that provides the lookup indices

        Returns:
            The ``ng.lookuptable`` result after ``ng.unflatten``, reordered
            to ``self.o_axes``.
        """
        # Give the axis whose short name is 'time' the time role and flag it
        # as recurrent before any reordering happens.
        in_obj.axes.find_by_short_name('time')[0].add_role(ar.time)
        in_obj.axes.find_by_short_name('time')[0].is_recurrent = True

        # Reorder by role, then flatten the indices for the lookup.
        flat_idx = ng.flatten(ng.axes_with_role_order(in_obj, self.role_order))
        idx_axes = flat_idx.axes

        # 'V' has length vocab_size and 'F' has length embed_dim.
        self.lut_v_axis = vocab_axis = ng.make_axis(self.vocab_size).named('V')
        self.lut_f_axis = feature_axis = ng.make_axis(self.embed_dim).named('F')

        # Weight axes are (V, F); the raw lookup output appends F to the
        # flattened index axes, while the final output leads with F.
        self.w_axes = table_axes = ng.make_axes([vocab_axis, feature_axis])
        self.lut_o_axes = idx_axes + ng.make_axes([feature_axis])
        self.o_axes = ng.make_axes([feature_axis]) + idx_axes[0].axes

        table_init = self.lut_init(table_axes, vocab_axis, self.pad_idx)
        weights = ng.variable(axes=table_axes, initial_value=table_init)
        self.W = weights.named('W')

        looked_up = ng.lookuptable(
            self.W,
            flat_idx,
            self.lut_o_axes,
            update=self.update,
            pad_idx=self.pad_idx,
        )
        return ng.axes_with_order(ng.unflatten(looked_up), self.o_axes)
示例#2
0
def test_lut(lut_args):
    """
    Check lookuptable fprop and the lookuptable update (bprop) against the
    reference helpers ``lut_fprop_ref`` and ``lut_update_ref``.

    Arguments:
        lut_args: five values unpacked as
            (vocab_size, embed_dim, bsz, seq_len, mem_size)
    """
    pad_idx = 0
    # Both comparisons use the same absolute-only tolerance.
    tolerance = dict(rtol=0.0, atol=1.0e-5)

    with ExecutorFactory() as ex:
        vocab_size, embed_dim, bsz, seq_len, mem_size = lut_args

        vocab_axis = ng.make_axis(vocab_size)
        embed_axis = ng.make_axis(embed_dim)
        mem_axis = ng.make_axis(mem_size)

        ax.N.length = bsz
        ax.REC.length = seq_len

        # The index tensor has several axes; the table is (vocab, embed).
        idx_axes = ng.make_axes([mem_axis, ax.REC, ax.N])
        table_axes = ng.make_axes([vocab_axis, embed_axis])

        table = ng.placeholder(table_axes)
        indices = ng.placeholder(idx_axes)
        flat_indices = ng.flatten(indices)
        out_axes = flat_indices.axes | ng.make_axes([embed_axis])

        # Forward graph.
        lookup_op = ng.lookuptable(table, flat_indices, out_axes,
                                   pad_idx=pad_idx)
        run_fprop = ex.executor(lookup_op, table, indices)

        # Update (backward) graph.
        delta = ng.placeholder(out_axes)
        update_op = lookuptable_update(delta, table, indices, lookup_op)
        run_update = ex.executor(update_op, delta, table, indices)

        # Draw concrete inputs and run the forward pass.
        table_value = rng.uniform(-1, 1, table.axes)
        indices_value = rng.random_integers(0, vocab_size - 1, indices.axes)
        fprop_actual = run_fprop(table_value, indices_value).copy()

        # Forward result must match the reference.
        fprop_expected = lut_fprop_ref(table_value, indices_value)
        ng.testing.assert_allclose(fprop_actual, fprop_expected, **tolerance)

        # Draw a concrete delta and run the update op.
        delta_value = rng.uniform(-1, 1, delta.axes)
        update_actual = run_update(
            delta_value, table_value, indices_value).copy()

        # Update result must match the reference.
        update_expected = lut_update_ref(
            delta_value, table_value, indices_value, pad_idx=pad_idx)
        ng.testing.assert_allclose(update_actual, update_expected, **tolerance)
def test_lut(lut_args):
    """
    Run the lookuptable forward op and its update op through an executor
    and compare both outputs with ``lut_fprop_ref`` / ``lut_update_ref``.

    Arguments:
        lut_args: five values unpacked as
            (vocab_size, embed_dim, bsz, seq_len, mem_size)
    """
    pad_idx = 0
    with ExecutorFactory() as ex:

        vocab_size, embed_dim, bsz, seq_len, mem_size = lut_args

        # One axis per size, created in the order V, F, M.
        axis_v, axis_f, axis_m = (
            ng.make_axis(size) for size in (vocab_size, embed_dim, mem_size)
        )

        ax.N.length = bsz
        ax.REC.length = seq_len

        # --- graph construction -------------------------------------------
        # Multi-axis index input, flattened before it reaches the lookup.
        index_axes = ng.make_axes([axis_m, ax.REC, ax.N])
        weight_axes = ng.make_axes([axis_v, axis_f])

        weights = ng.placeholder(weight_axes)
        index = ng.placeholder(index_axes)
        index_flat = ng.flatten(index)
        result_axes = index_flat.axes | ng.make_axes([axis_f])

        fprop_op = ng.lookuptable(
            weights, index_flat, result_axes, pad_idx=pad_idx)
        fprop = ex.executor(fprop_op, weights, index)

        error = ng.placeholder(result_axes)
        bprop_op = lookuptable_update(error, weights, index, fprop_op)
        bprop = ex.executor(bprop_op, error, weights, index)

        # --- forward pass --------------------------------------------------
        weights_np = rng.uniform(-1, 1, weights.axes)
        index_np = rng.random_integers(0, vocab_size - 1, index.axes)
        got_fprop = fprop(weights_np, index_np).copy()

        want_fprop = lut_fprop_ref(weights_np, index_np)
        ng.testing.assert_allclose(
            got_fprop, want_fprop, rtol=0.0, atol=1.0e-5)

        # --- update (bprop) pass -------------------------------------------
        error_np = rng.uniform(-1, 1, error.axes)
        got_update = bprop(error_np, weights_np, index_np).copy()

        want_update = lut_update_ref(
            error_np, weights_np, index_np, pad_idx=pad_idx)
        ng.testing.assert_allclose(
            got_update, want_update, rtol=0.0, atol=1.0e-5)
示例#4
0
    def __call__(self, in_obj, **kwargs):
        """
        Build the lookup-table graph for ``in_obj`` and return its output.

        Arguments:
            in_obj (Tensor): object that provides the lookup indices
            **kwargs: accepted but not read by this method

        Returns:
            The ``ng.lookuptable`` result after ``ng.unflatten`` and
            ``ng.map_roles`` with ``self.axes_map``, reordered to
            ``self.o_axes``.
        """
        labels = {"weight": "weight", "bias": "bias"}

        # Order the input as (recurrent axis, batch axis), then flatten it.
        idx_order = ng.make_axes([in_obj.axes.recurrent_axis(),
                                  in_obj.axes.batch_axis()])
        flat_idx = ng.flatten(ng.axes_with_order(in_obj, idx_order))
        idx_axes = flat_idx.axes

        # The vocabulary axis is swapped for its shadow axis for the benefit
        # of initializers ... once #1158 is in, shadow axis will do more than
        # just determine fan in/out for initializers.
        self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
        self.axes_map = shadow_axes_map([self.lut_v_axis])
        self.lut_v_axis = list(self.axes_map.values())[0]

        self.lut_f_axis = feature_axis = ng.make_axis(self.embed_dim).named('F')

        # Weight axes are (V, F); the raw lookup output joins F onto the
        # flattened index axes, while the final output leads with F.
        self.w_axes = ng.make_axes([self.lut_v_axis, feature_axis])
        self.lut_o_axes = idx_axes | ng.make_axes([feature_axis])
        self.o_axes = ng.make_axes([feature_axis]) | idx_axes[0].axes

        # The weight variable is created only while the layer is not yet
        # flagged as initialized.
        if not self.initialized:
            table_init = self.lut_init(
                self.w_axes, self.lut_v_axis, self.pad_idx)
            weights = ng.variable(
                axes=self.w_axes,
                initial_value=table_init,
                metadata={"label": labels["weight"]},
            )
            self.W = weights.named('LutW')

        looked_up = ng.lookuptable(
            self.W,
            flat_idx,
            self.lut_o_axes,
            update=self.update,
            pad_idx=self.pad_idx,
        )
        remapped = ng.map_roles(ng.unflatten(looked_up), self.axes_map)
        return ng.axes_with_order(remapped, self.o_axes)
    def __call__(self, in_obj, **kwargs):
        """
        Build the lookup-table graph for ``in_obj`` and return its output.

        Arguments:
            in_obj (Tensor): object that provides the lookup indices
            **kwargs: accepted but not read by this method

        Returns:
            The ``ng.lookuptable`` result after ``ng.unflatten`` and
            ``ng.map_roles`` with ``self.axes_map``.
        """
        flat_idx = ng.flatten(in_obj)
        idx_axes = flat_idx.axes

        # The vocabulary axis is swapped for its shadow axis for the benefit
        # of initializers ... once #1158 is in, shadow axis will do more than
        # just determine fan in/out for initializers.
        self.lut_v_axis = ng.make_axis(self.vocab_size).named('V')
        self.axes_map = shadow_axes_map([self.lut_v_axis])
        self.lut_v_axis = list(self.axes_map.values())[0]

        self.lut_f_axis = feature_axis = ng.make_axis(self.embed_dim).named('F')

        # Weight axes are (V, F); the lookup output joins F onto the
        # flattened index axes. o_axes is stored on the instance but is not
        # used by the return value below.
        self.w_axes = ng.make_axes([self.lut_v_axis, feature_axis])
        self.lut_o_axes = idx_axes | ng.make_axes([feature_axis])
        self.o_axes = ng.make_axes([feature_axis]) | idx_axes[0].axes

        # The weight variable is created only while the layer is not yet
        # flagged as initialized.
        if not self.initialized:
            table_init = self.lut_init(
                self.w_axes, self.lut_v_axis, self.pad_idx)
            # NOTE(review): LABELS is not defined in this method; presumably
            # it lives at module or class scope -- confirm.
            weights = ng.variable(
                axes=self.w_axes,
                initial_value=table_init,
                metadata={"label": LABELS["weight"]},
            )
            self.W = weights.named('LutW')

        looked_up = ng.lookuptable(self.W,
                                   flat_idx,
                                   self.lut_o_axes,
                                   update=self.update,
                                   pad_idx=self.pad_idx)
        unflattened = ng.unflatten(looked_up)
        return ng.map_roles(unflattened, self.axes_map)