示例#1
0
    def test_conversion(self):
        """Check that every input kind converts to every output type and dtype.

        A dense numpy array, its scipy CSR counterpart and a torch tensor
        copy are each passed to ``utils.convert`` for every member of
        ``utils.OutputType``. Tensor targets are tried with torch float
        dtypes, and all other targets with numpy float dtypes. The loop
        only checks that no exception is raised. Finally, a plain string
        must be rejected with ``ValueError``.
        """
        dense = np.arange(10, dtype=np.float32).reshape((5, 2))
        inputs = (
            dense,
            csr_matrix(dense),
            torch.from_numpy(dense).to(utils.device),
        )
        for source in inputs:
            for target in utils.OutputType:
                # Tensor targets take torch dtypes; everything else takes numpy dtypes.
                if target is utils.OutputType.tensor:
                    candidate_dtypes = (torch.float32, torch.float64)
                else:
                    candidate_dtypes = (np.float32, np.float64)
                for dtype in candidate_dtypes:
                    utils.convert(source, target, dtype)  # must not raise

        with self.assertRaises(ValueError):
            # A bare string is not an accepted input.
            utils.convert("a", utils.OutputType.tensor)
示例#2
0
 def forward(self, x):
     """Run every sub-embedding on ``x`` and concatenate the results.

     The outputs of ``self.embeddings`` may be torch tensors, numpy arrays
     or sparse matrices. They are first brought to a common type so that a
     single concatenation routine can be used:

     * if any output is a tensor, all outputs are converted to tensors;
     * if numpy and sparse outputs are mixed, all are converted to arrays;
     * otherwise the single shared type is kept as-is.

     When the first sub-embedding returns a list (word embeddings before
     pooling), every output is treated as a list. The lists are then
     concatenated element-wise.

     Args:
         x: Input passed unchanged to each embedding in ``self.embeddings``.

     Returns:
         The concatenated embedding, or a list of concatenated embeddings
         when the sub-embeddings return lists.
     """
     embeds = [embedding(x) for embedding in self.embeddings]
     # Only the first output is inspected. All sub-embeddings are assumed to
     # agree on whether they return lists (happens for word embeddings
     # before pooling).
     is_list = isinstance(embeds[0], list)
     # For list outputs, the first element of each list stands in for the
     # type of the whole list.
     types = {
         OutputType.from_object(embed[0] if is_list else embed)
         for embed in embeds
     }
     if OutputType.tensor in types:  # need to convert everything to torch
         embeds = convert(embeds, OutputType.tensor)
         cat_fn = _Concat._tensor_concat
     elif len(types) == 2:  # both numpy and sparse
         embeds = convert(embeds, OutputType.array)
         cat_fn = _Concat._array_concat
     elif OutputType.array in types:  # only numpy
         cat_fn = _Concat._array_concat
     else:  # only sparse
         cat_fn = _Concat._sparse_concat
     if is_list:
         # Concatenate the i-th element of every sub-embedding's list.
         return [cat_fn(embed) for embed in zip(*embeds)]
     return cat_fn(embeds)
示例#3
0
 def _convert_to_dtype(self, x, detach=False):
     """Convert ``x`` to ``self.dtype``, picking the container from the dtype.

     A ``torch.dtype`` selects a tensor output; any other dtype selects an
     array output.

     Args:
         x: Object to pass to ``convert``.
         detach: Forwarded unchanged to ``convert``.

     Returns:
         The result of ``convert`` for the chosen output type and dtype.
     """
     if isinstance(self.dtype, torch.dtype):
         target = OutputType.tensor
     else:
         target = OutputType.array
     return convert(x, target, self.dtype, detach=detach)
示例#4
0
 def fit_transform(self, x, y=None):
     """Fit the vectorizer if required and return the transformed ``x``.

     If ``self._init_vectorizer()`` returns a falsy value, the existing
     vectorizer only transforms ``x``. Otherwise ``fit_transform`` is
     called with ``y`` converted to ``OutputType.array``.
     """
     if not self._init_vectorizer():
         return self.vectorizer.transform(x)
     return self.vectorizer.fit_transform(x, convert(y, OutputType.array))
示例#5
0
 def fit(self, x, y=None):
     """Fit the vectorizer on ``x`` and ``y`` when ``_init_vectorizer`` asks for it.

     If ``self._init_vectorizer()`` returns a falsy value, nothing else
     happens. Otherwise the vectorizer is fitted with ``y`` converted to
     ``OutputType.array``. Always returns ``None``.
     """
     if not self._init_vectorizer():
         return
     self.vectorizer.fit(x, convert(y, OutputType.array))
示例#6
0
 def _fit_transform(self, x, y=None):
     """Create a fresh ``self.Model``, fit it on ``x``/``y`` and return float32 output.

     ``self.model`` is replaced with ``self.Model(**self.init_args)``. ``y``
     is converted to ``OutputType.array`` before being passed to the model.
     The model's ``fit_transform`` result is cast with ``astype(np.float32)``.
     """
     self.model = self.Model(**self.init_args)
     targets = convert(y, OutputType.array)
     transformed = self.model.fit_transform(x, targets)
     return transformed.astype(np.float32)
示例#7
0
 def _fit(self, x, y=None):
     """Replace ``self.model`` with a new ``self.Model`` and fit it on ``x``/``y``.

     ``y`` is converted to ``OutputType.array`` before fitting. Returns ``None``.
     """
     self.model = self.Model(**self.init_args)
     targets = convert(y, OutputType.array)
     self.model.fit(x, targets)
示例#8
0
 def _check_input(self, x):
     if not isinstance(x, tuple(t.value for t in self.input_types)):
         return convert(x, self.input_types[0])
     return x