Exemplo n.º 1
0
    def __init__(self, conv_input, conv_filter, strides, padding, name=None):
        """Create the conv2d operation and derive the shape of its output.

        Args:
            conv_input: input tensor. Its last three dimensions are read as
                (height, width, input channels); any leading dimensions are
                carried over unchanged as batch dimensions.
            conv_filter: filter tensor with exactly four dimensions, read as
                (filter height, filter width, input channels, output
                channels).
            strides: None for unit strides, or a sequence whose entries 1 and
                2 are the strides along height and width.
            padding: "VALID" restricts the filter to positions fully inside
                the input; any other value is treated as padded
                ("SAME"-style) convolution.
            name: optional operation name; defaults to "conv2d".

        Raises:
            ValueError: if the input-channel dimension of the filter differs
                from the channel dimension of the input.
        """
        mtf.Operation.__init__(self, [conv_input, conv_filter],
                               name=name or "conv2d")
        self._padding = padding
        self._strides = strides
        self._batch_dims = conv_input.shape.dims[:-3]
        self._in_h_dim, self._in_w_dim, self._in_dim = conv_input.shape.dims[
            -3:]
        self._fh_dim, self._fw_dim = conv_filter.shape.dims[:2]
        f_in_dim, self._out_dim = conv_filter.shape.dims[2:]
        if f_in_dim != self._in_dim:
            raise ValueError("Dimensions do not match input=%s filter=%s" %
                             (conv_input, conv_filter))

        if strides is None:
            stride_h = stride_w = 1
        else:
            stride_h, stride_w = strides[1], strides[2]

        def _output_size(in_size, filter_size, stride):
            """Return the number of filter positions along one spatial axis."""
            if padding == "VALID":
                # Only windows lying entirely inside the input count.
                return (in_size - filter_size) // stride + 1
            # Padded convolution produces ceil(in_size / stride) positions.
            # Plain floor division would drop the last one whenever in_size
            # is not a multiple of the stride.
            return -(-in_size // stride)

        out_h = _output_size(self._in_h_dim.size, self._fh_dim.size, stride_h)
        out_w = _output_size(self._in_w_dim.size, self._fw_dim.size, stride_w)

        # Output spatial dimensions keep the input names but get new sizes.
        self._out_h_dim = mtf.Dimension(self._in_h_dim.name, out_h)
        self._out_w_dim = mtf.Dimension(self._in_w_dim.name, out_w)
        output_shape = mtf.Shape(
            self._batch_dims +
            [self._out_h_dim, self._out_w_dim, self._out_dim])
        self._outputs = [mtf.Tensor(self, output_shape, conv_input.dtype)]
Exemplo n.º 2
0
 def __init__(self, tensor_in, dims, name=None):
   """Initialize the iFFT3D operation and declare its single output.

   Args:
     tensor_in: input tensor; its last three dimensions are dropped from
       the output shape.
     dims: dimensions appended after the remaining leading dimensions of
       `tensor_in` to form the output shape.
     name: optional operation name; falls back to "iFFT3D".
   """
   op_name = name or "iFFT3D"
   super(iFFT3DOperation, self).__init__([tensor_in], name=op_name)
   self._dims = dims
   leading_dims = tensor_in.shape[:-3]
   self._output_shape = mtf.Shape(leading_dims + dims)
   # The output tensor gets its own Shape object built from the stored one.
   result = mtf.Tensor(self, mtf.Shape(self._output_shape), tensor_in.dtype)
   self._outputs = [result]
Exemplo n.º 3
0
 def __init__(self, tensor_in, k_dims, name=None):
     """Initialize the FFT3D operation and declare its single output.

     Args:
       tensor_in: input tensor; its last three dimensions are dropped from
         the output shape.
       k_dims: three dimensions for the output. They are appended in the
         order (k_dims[1], k_dims[2], k_dims[0]) after the remaining leading
         dimensions of `tensor_in`.
       name: optional operation name; falls back to "FFT3D".
     """
     op_name = name or "FFT3D"
     super(FFT3DOperation, self).__init__([tensor_in], name=op_name)
     self._k_dims = k_dims
     # Note the rotated order: the first entry of k_dims ends up last.
     rotated_k_dims = [k_dims[1], k_dims[2], k_dims[0]]
     self._output_shape = mtf.Shape(tensor_in.shape[:-3] + rotated_k_dims)
     result = mtf.Tensor(self, mtf.Shape(self._output_shape), tensor_in.dtype)
     self._outputs = [result]
Exemplo n.º 4
0
 def __init__(self, x, w0, w1, num_units, states, name=None):
     """Initialize the operation with inputs x, w0, w1 followed by `states`.

     Args:
       x: input tensor; the output has the same shape and dtype.
       w0: first weight tensor; the name of its first dimension must equal
         the name of the last dimension of `x`.
       w1: second weight tensor with the same first-dimension requirement;
         its mesh is used for the operation.
       num_units: stored on the instance as `num_units`.
       states: list of tensors appended to the operation inputs.
       name: optional operation name; falls back to 'rnn'.
     """
     # Both weights must start with the dimension that x ends with.
     assert all(
         w.shape[0].name == x.shape[-1].name for w in (w0, w1)
     ), (x.shape, w0.shape, w1.shape)
     op_inputs = [x, w0, w1] + states
     super().__init__(op_inputs, mesh=w1.mesh, name=name or 'rnn')
     self.num_units = num_units
     result = mtf.Tensor(self, x.shape, x.dtype)
     self._outputs = [result]
Exemplo n.º 5
0
 def __init__(self, mesh, shape, dtype, name=None):
   """Initialize the indices operation and declare its single output.

   Args:
     mesh: mesh the operation is placed on.
     shape: iterable of items, each converted with
       `mtf.convert_to_dimension`.
     dtype: dtype of the output tensor.
     name: optional operation name; falls back to "indices".
   """
   super(IndicesOperation, self).__init__([], mesh, name=name or "indices")
   self._mesh = mesh
   self._dtype = dtype
   self._shape = list(map(mtf.convert_to_dimension, shape))
   # The output has one extra trailing dimension, "ndim", whose size is the
   # number of dimensions in `shape`.
   ndim_dim = mtf.Dimension("ndim", len(self._shape))
   out_shape = mtf.Shape(self._shape + [ndim_dim])
   self._outputs = [mtf.Tensor(self, out_shape, dtype)]
Exemplo n.º 6
0
    def __init__(self, x, mesh, dim_names=None, name=None):
        """Initialize an operation that re-creates `x` on a different mesh.

        Args:
            x: input tensor; its mesh must differ from `mesh`.
            mesh: mesh the operation and its output are placed on.
            dim_names: optional new dimension names, given either as a
                sequence or as an `mtf.Shape` whose dimension names are used.
                Must have one entry per dimension of `x`; a falsy entry keeps
                the original name of that dimension. When omitted or empty,
                the names from `x` are used.
            name: optional operation name; falls back to the class name.
        """
        assert x.mesh != mesh

        self.old_mesh = x.mesh
        requested = (dim_names.dimension_names
                     if isinstance(dim_names, mtf.Shape) else dim_names)
        self.new_dim_names = requested or x.shape.dimension_names
        assert len(self.new_dim_names) == len(x.shape)

        # Dimension sizes are preserved; only the names may change.
        renamed_dims = []
        for new_name, old_dim in zip(self.new_dim_names, x.shape.dims):
            renamed_dims.append(
                mtf.Dimension(new_name or old_dim.name, old_dim.size))
        self.new_shape = mtf.Shape(renamed_dims)

        op_name = name or self.__class__.__name__
        super().__init__([x], mesh=mesh, name=op_name)
        self._outputs = [mtf.Tensor(self, self.new_shape, x.dtype)]