def __init__(self, conv_input, conv_filter, strides, padding, name=None):
    """Build a 2-D convolution op over the last three dims of `conv_input`.

    The last three dims of `conv_input` are treated as (height, width,
    in-channels); any leading dims are kept as batch dims. `conv_filter`
    dims are (filter_h, filter_w, in-channels, out-channels).

    Output spatial sizes follow the tf.nn.conv2d convention:
      * padding == "VALID":  floor((in - f) / s) + 1
      * any other padding (treated as "SAME"):  ceil(in / s)
    where `s` is 1 when `strides` is None.

    Args:
      conv_input: input mtf.Tensor.
      conv_filter: filter mtf.Tensor.
      strides: None, or an indexable whose entries [1] and [2] are the
        height/width strides (presumably a 4-element NHWC strides tuple as
        in tf.nn.conv2d -- confirm with callers).
      padding: "VALID", or anything else, which is sized like "SAME".
      name: optional op name; defaults to "conv2d".

    Raises:
      ValueError: if the filter's in-channel dim differs from the input's,
        or if the computed output height/width would be smaller than 1.
    """
    mtf.Operation.__init__(
        self, [conv_input, conv_filter], name=name or "conv2d")
    self._padding = padding
    self._batch_dims = conv_input.shape.dims[:-3]
    self._in_h_dim, self._in_w_dim, self._in_dim = conv_input.shape.dims[-3:]
    self._fh_dim, self._fw_dim = conv_filter.shape.dims[:2]
    f_in_dim, self._out_dim = conv_filter.shape.dims[2:]
    if f_in_dim != self._in_dim:
        raise ValueError("Dimensions do not match input=%s filter=%s" %
                         (conv_input, conv_filter))

    self._strides = strides
    if strides is None:
        stride_h, stride_w = 1, 1
    else:
        stride_h, stride_w = strides[1], strides[2]

    in_h = self._in_h_dim.size
    in_w = self._in_w_dim.size
    if padding == "VALID":
        # Only positions where the whole filter fits inside the input.
        out_h = (in_h - self._fh_dim.size) // stride_h + 1
        out_w = (in_w - self._fw_dim.size) // stride_w + 1
    else:
        # SAME: ceil(in / s). The previous floor division under-sized the
        # output whenever `in` was not a multiple of the stride.
        # -(-a // b) is integer ceil division.
        out_h = -(-in_h // stride_h)
        out_w = -(-in_w // stride_w)
    if out_h < 1 or out_w < 1:
        raise ValueError(
            "conv2d output would be empty (out_h=%d, out_w=%d) for "
            "input=%s filter=%s padding=%s" %
            (out_h, out_w, conv_input, conv_filter, padding))

    # Output spatial dims keep the input dims' names with the new sizes.
    self._out_h_dim = mtf.Dimension(self._in_h_dim.name, out_h)
    self._out_w_dim = mtf.Dimension(self._in_w_dim.name, out_w)
    output_shape = mtf.Shape(
        self._batch_dims + [self._out_h_dim, self._out_w_dim, self._out_dim])
    self._outputs = [mtf.Tensor(self, output_shape, conv_input.dtype)]
def __init__(self, tensor_in, dims, name=None):
    """Build an inverse 3-D FFT op.

    The output keeps every dim of `tensor_in` except the last three, which
    are replaced by `dims` in the order given. The dtype is unchanged.

    Args:
      tensor_in: input mtf.Tensor.
      dims: sequence of dims for the trailing output axes. This may be a
        list, a tuple or an mtf.Shape; it is stored unchanged on the op.
      name: optional op name; defaults to "iFFT3D".
    """
    super(iFFT3DOperation, self).__init__([tensor_in], name=name or "iFFT3D")
    self._dims = dims
    # list(...) lets a tuple or mtf.Shape be concatenated; previously only a
    # plain list worked here.
    self._output_shape = mtf.Shape(tensor_in.shape.dims[:-3] + list(dims))
    self._outputs = [
        mtf.Tensor(self, mtf.Shape(self._output_shape), tensor_in.dtype)
    ]
def __init__(self, tensor_in, k_dims, name=None):
    """Build a forward 3-D FFT op.

    Every dim of `tensor_in` except the last three is kept. The trailing
    three are replaced by `k_dims` in the order (k_dims[1], k_dims[2],
    k_dims[0]). That rotated order presumably reflects the transposed
    layout the distributed FFT produces; TODO confirm against lower().
    The output dtype is the input dtype.

    Args:
      tensor_in: input mtf.Tensor.
      k_dims: indexable of (at least) three dims for the output axes.
      name: optional op name; defaults to "FFT3D".
    """
    super(FFT3DOperation, self).__init__([tensor_in], name=name or "FFT3D")
    self._k_dims = k_dims
    leading_dims = tensor_in.shape[:-3]
    rotated_k_dims = [k_dims[1], k_dims[2], k_dims[0]]
    self._output_shape = mtf.Shape(leading_dims + rotated_k_dims)
    result_shape = mtf.Shape(self._output_shape)
    self._outputs = [mtf.Tensor(self, result_shape, tensor_in.dtype)]
def __init__(self, x, w0, w1, num_units, states, name=None):
    """Build an RNN op over input `x` with weights `w0`/`w1`.

    The op's inputs are x, w0, w1 followed by the entries of `states`. Its
    single output has exactly the shape and dtype of `x`, and it is placed
    on `w1`'s mesh.

    Args:
      x: input mtf.Tensor; its last dim must share its name with the first
        dim of `w0` and of `w1`.
      w0: weight mtf.Tensor.
      w1: weight mtf.Tensor; its mesh is used for the op.
      num_units: stored as `self.num_units` (presumably the hidden size;
        confirm with lower()).
      states: iterable of extra state tensors. A list or a tuple both work.
      name: optional op name; defaults to 'rnn'.
    """
    assert (x.shape[-1].name == w0.shape[0].name == w1.shape[0].name), (
        x.shape, w0.shape, w1.shape)
    # list(...) so that a tuple of states no longer raises TypeError on
    # concatenation.
    super().__init__([x, w0, w1] + list(states), mesh=w1.mesh,
                     name=name or 'rnn')
    self.num_units = num_units
    self._outputs = [mtf.Tensor(self, x.shape, x.dtype)]
def __init__(self, mesh, shape, dtype, name=None):
    """Build a source op (no tensor inputs) on `mesh`.

    The single output has the dims of `shape` followed by one extra
    dimension named "ndim" whose size is the number of dims in `shape`.
    That extra axis presumably holds each position's coordinate along every
    dim; TODO confirm against lower().

    Args:
      mesh: mesh the op belongs to.
      shape: iterable of dims (anything mtf.convert_to_dimension accepts).
      dtype: dtype of the output tensor.
      name: optional op name; defaults to "indices".
    """
    super(IndicesOperation, self).__init__([], mesh, name=name or "indices")
    self._mesh = mesh
    index_dims = [mtf.convert_to_dimension(d) for d in shape]
    self._shape = index_dims
    self._dtype = dtype
    coord_dim = mtf.Dimension("ndim", len(index_dims))
    result_shape = mtf.Shape(index_dims + [coord_dim])
    self._outputs = [mtf.Tensor(self, result_shape, dtype)]
def __init__(self, x, mesh, dim_names=None, name=None):
    """Build an op whose output lives on `mesh` while its input `x` does not.

    The output has the same dim sizes and dtype as `x`. Its dims may be
    renamed via `dim_names`.

    Args:
      x: input mtf.Tensor; its mesh must differ from `mesh` (asserted).
      mesh: mesh the op and its output belong to.
      dim_names: optional new dim names, either as a sequence of names or as
        an mtf.Shape (whose dimension names are used). The length must equal
        the rank of `x`. A falsy entry keeps the corresponding original name,
        and a falsy `dim_names` keeps all names.
      name: optional op name; defaults to the class name.
    """
    assert x.mesh != mesh
    self.old_mesh = x.mesh
    if isinstance(dim_names, mtf.Shape):
        dim_names = dim_names.dimension_names
    self.new_dim_names = dim_names or x.shape.dimension_names
    assert len(self.new_dim_names) == len(x.shape)
    # The loop variable is `new_name`, not `name`, so it cannot be confused
    # with (or, under different scoping, clobber) the `name` parameter that
    # is used just below for the op name.
    self.new_shape = mtf.Shape([
        mtf.Dimension(new_name or dim.name, dim.size)
        for new_name, dim in zip(self.new_dim_names, x.shape.dims)
    ])
    super().__init__([x], mesh=mesh, name=name or self.__class__.__name__)
    self._outputs = [mtf.Tensor(self, self.new_shape, x.dtype)]