Example #1
0
    def init_output_module(self):
        """Build the output module and store it on ``self.mod_out``.

        Side effects:
            * ``self.emb_out`` -- a new ``LookupTable(self.voc_sz, self.out_dim)``.
            * ``self.mod_out`` -- ``Sequential`` of
              ``[Parallel(emb_out, Identity()), MatVecProd(False)]``.

        NOTE(review): assumes ``Parallel`` feeds its i-th input to its i-th
        child (embedding for the first, pass-through for the second) --
        confirm against the ``Parallel`` implementation.
        """
        self.emb_out = LookupTable(self.voc_sz, self.out_dim)

        # Two side-by-side branches: the embedding, and an untouched pass-through.
        branches = Parallel()
        for branch in (self.emb_out, Identity()):
            branches.add(branch)

        # Branch outputs are combined by a non-transposed mat-vec product.
        pipeline = Sequential()
        pipeline.add(branches)
        pipeline.add(MatVecProd(False))
        self.mod_out = pipeline
Example #2
0
    def init_query_module(self):
        """Build the query module and store it on ``self.mod_query``.

        Side effects:
            * ``self.emb_query`` -- a new ``LookupTable(self.voc_sz, self.in_dim)``.
            * ``self.mod_query`` -- ``Sequential`` of
              ``[Parallel(emb_query, Identity()), MatVecProd(True), Softmax()]``.

        NOTE(review): assumes ``Parallel`` feeds its i-th input to its i-th
        child -- confirm against the ``Parallel`` implementation.
        """
        self.emb_query = LookupTable(self.voc_sz, self.in_dim)

        # Embed the first input; forward the second one unchanged.
        branches = Parallel()
        for branch in (self.emb_query, Identity()):
            branches.add(branch)

        # Transposed mat-vec product of the branch outputs, then a softmax.
        pipeline = Sequential()
        for stage in (branches, MatVecProd(True), Softmax()):
            pipeline.add(stage)
        self.mod_query = pipeline
Example #3
0
    def init_output_module(self):
        """Build the weighted-sum output module and store it on ``self.mod_out``.

        The first input is embedded with ``self.emb_out``. The result is
        multiplied element-wise by ``self.config["weight"]`` and summed over
        dim 1. The second input passes through unchanged. The two branch
        outputs are then combined with ``MatVecProd(False)``.

        NOTE(review): the meaning of ``self.config["weight"]`` (e.g. a
        positional-encoding matrix) is not visible here -- confirm.
        """
        def chain(container, *children):
            # Add each child to the container in order, then hand it back.
            for child in children:
                container.add(child)
            return container

        self.emb_out = LookupTable(self.voc_sz, self.out_dim)
        weighted_embedding = chain(
            Sequential(),
            self.emb_out,
            ElemMult(self.config["weight"]),
            Sum(dim=1),
        )
        branches = chain(Parallel(), weighted_embedding, Identity())
        self.mod_out = chain(Sequential(), branches, MatVecProd(False))
Example #4
0
    def init_query_module(self):
        """Build the weighted-sum query module and store it on ``self.mod_query``.

        The first input is embedded with ``self.emb_query``. The result is
        multiplied element-wise by ``self.config["weight"]`` and summed over
        dim 1. The second input passes through unchanged. The branch outputs
        go through ``MatVecProd(True)`` followed by ``Softmax()``.

        NOTE(review): the meaning of ``self.config["weight"]`` is not visible
        here -- confirm against the model configuration.
        """
        def chain(container, *children):
            # Add each child to the container in order, then hand it back.
            for child in children:
                container.add(child)
            return container

        self.emb_query = LookupTable(self.voc_sz, self.in_dim)
        weighted_embedding = chain(
            Sequential(),
            self.emb_query,
            ElemMult(self.config["weight"]),
            Sum(dim=1),
        )
        branches = chain(Parallel(), weighted_embedding, Identity())
        self.mod_query = chain(
            Sequential(), branches, MatVecProd(True), Softmax()
        )
        except AssertionError:
            tests[1] = False

if TEST2:
    # Cross-check the NumPy MatVecProd against MatVecProdPytorch on random
    # inputs, alternating transpose=True (even i) and transpose=False (odd i).
    tests[2] = True
    for i in range(10):
        M = np.random.rand(*matrix_batch_dim)
        V = np.random.rand(*vect_batch_dim)

        input_data = [M, V]
        input_data_torch = [
            torch.from_numpy(M).type(FloatTensor),
            torch.from_numpy(V).type(FloatTensor),
        ]
        transpose = i % 2 == 0
        mvp = MatVecProd(transpose)
        mvp_pt = MatVecProdPytorch(transpose)
        result_1 = mvp.fprop(input_data)
        result_2 = mvp_pt.forward(input_data_torch)
        # .cpu() is a no-op for CPU tensors; it lets .numpy() work when
        # FloatTensor is a CUDA tensor type.
        result_2_np = result_2.data.cpu().numpy()
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would make this test always report success.
        # Do not break on failure, so the global NumPy RNG advances exactly
        # as before for the tests that follow.
        if not np.allclose(result_1, result_2_np):
            tests[2] = False

if TEST3:
    tests[3] = True
    input = [
        np.random.rand(50, 32).astype('f'),
        np.random.rand(50, 32).astype('f')
    ]