コード例 #1
0
def convert_var_shape(x, idx=None, in_control_flow=False):
    """
    A function representation of the shape of variable.

    Returns the static ``x.shape`` (or ``x.shape[idx]``) when it can be used
    as is, and otherwise the result of ``nn.shape(x)`` (indexed by ``idx``
    when given).

    Args:
        x: Object whose shape is queried. Only a ``Variable`` can take the
            ``nn.shape`` path; any other object must expose ``shape``.
        idx: Optional index into the shape. ``None`` means the whole shape.
        in_control_flow (bool): True when the shape is used in a
            control-flow condition; forces the ``nn.shape`` path for a
            ``Variable``.

    Returns:
        ``nn.shape(x)`` / ``nn.shape(x)[idx]`` for a ``Variable`` that is
        used in control flow or whose (selected) static dims contain a
        negative value; otherwise ``x.shape`` / ``x.shape[idx]``.
    """

    def has_negative(list_shape, idx=None):
        # With an index, only the selected dim is inspected.
        if idx is not None:
            return list_shape[idx] < 0
        # any() stops at the first negative dim instead of counting them all.
        return any(dim < 0 for dim in list_shape)

    # When `x` is Variable, call nn.shape(x) in following cases:
    #  (1) The shape of `x` is used in control flow condition.
    #      ```
    #      if x.shape[0] == 1:
    #          y = XX
    #      ```
    #  (2) The dim to be used is negative
    #      ```
    #      # Assume x.shape=[3, -1] in static mode
    #      y = paddle.reshape(x, shape=[1, x.shape[1]])
    #      ```
    if isinstance(x, Variable) and (in_control_flow
                                    or has_negative(x.shape, idx)):
        return nn.shape(x) if idx is None else nn.shape(x)[idx]
    else:
        return x.shape if idx is None else x.shape[idx]
コード例 #2
0
def convert_len(var):
    """
    Returns variable(length) from shape ops based on var.type

    Note: In addition to some ast transformations, some block-related
          operations are added in `len` transformation, such as appending
          `shape_op` in var.block.

    Args:
        var: A ``Variable`` or any object accepted by the builtin ``len``.

    Returns:
        - ``nn.shape(var)[0]`` for a ``Variable`` of type ``LOD_TENSOR`` or
          ``SELECTED_ROWS``;
        - ``control_flow.array_length(var)`` for a ``Variable`` of type
          ``LOD_TENSOR_ARRAY``;
        - ``len(var)`` for anything that is not a ``Variable``.

    Raises:
        TypeError: If ``var`` is a ``Variable`` of any other ``var.type``.
    """
    if isinstance(var, Variable):
        if var.type in (core.VarDesc.VarType.LOD_TENSOR,
                        core.VarDesc.VarType.SELECTED_ROWS):
            # Note: Length of var may be known ahead of time in dygraph,
            # but it probably represents batch size which can be variant.
            # so we return a variable dynamically inferred from var.shape.
            return nn.shape(var)[0]
        elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
            return control_flow.array_length(var)
        else:
            # Report the rejected var.type; type(var) is always the
            # Variable class in this branch and would tell the user nothing.
            raise TypeError(
                'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
                % var.type)
    else:
        return len(var)
コード例 #3
0
def convert_var_shape_simple(x):
    """
    Return the shape of ``x``.

    For a ``Variable`` the shape comes from ``nn.shape(x)``; any other
    object must provide a ``shape`` attribute, which is returned unchanged.
    """
    if not isinstance(x, Variable):
        return x.shape
    return nn.shape(x)
コード例 #4
0
ファイル: normal.py プロジェクト: sandyhouse/Paddle
    def sample(self, shape, seed=0):
        """Generate samples of the specified shape.

        Each sample is a ``gaussian_random`` draw with mean 0 and std 1,
        multiplied by ``self.scale`` and then shifted by ``self.loc``.

        Args:
          shape (list): Shape of the generated samples. It is prepended to
            the shape of ``self.loc + self.scale``. Its type is only checked
            when ``_non_static_mode()`` is false.
          seed (int): Python integer seed forwarded to ``gaussian_random``.
            Its type is only checked when ``_non_static_mode()`` is false.

        Returns:
          Tensor: Samples with dtype ``self.dtype`` and shape
          ``shape + batch_shape``, where ``batch_shape`` is the shape of
          ``self.loc + self.scale``. In the known-batch-size path, when
          ``self.all_arg_is_float`` is true, the result is reshaped to
          ``shape``.

        """
        if not _non_static_mode():
            check_type(shape, 'shape', (list), 'sample')
            check_type(seed, 'seed', (int), 'sample')

        # Shape of loc + scale; used as the trailing (batch) dims of the output.
        batch_shape = list((self.loc + self.scale).shape)
        name = self.name + '_sample'

        if self.batch_size_unknown:
            # NOTE(review): presumably batch_shape holds an unknown (-1) dim
            # here, so the concrete output shape is obtained at run time:
            # a zero tensor is built with fill_constant_batch_size_like
            # (presumably taking its batch dim from loc + scale — confirm),
            # reshaped to shape + batch_shape, and its nn.shape is fed to
            # gaussian_random.
            output_shape = shape + batch_shape
            zero_tmp = tensor.fill_constant_batch_size_like(
                self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
            zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
            zero_tmp_shape = nn.shape(zero_tmp_reshape)
            normal_random_tmp = nn.gaussian_random(zero_tmp_shape,
                                                   mean=0.,
                                                   std=1.,
                                                   seed=seed,
                                                   dtype=self.dtype)
            # zero_tmp_reshape is all zeros, so this is draw * scale; adding
            # the zeros presumably broadcasts scale to the output shape.
            output = normal_random_tmp * (zero_tmp_reshape + self.scale)
            output = elementwise_add(output, self.loc, name=name)
            return output
        else:
            # Output shape is usable as is: draw directly with it, then
            # scale (via zeros + scale) and shift by loc.
            output_shape = shape + batch_shape
            output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
                     (tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
            output = elementwise_add(output, self.loc, name=name)
            if self.all_arg_is_float:
                # NOTE(review): presumably loc/scale were plain floats, so
                # the batch dims are a placeholder and reshaping to `shape`
                # drops them — confirm against the constructor.
                return nn.reshape(output, shape, name=name)
            else:
                return output