Exemplo n.º 1
0
 def __init__(self):
     """Create the instance-level ``KeysetDict`` from the class-level one.

     * a ``set`` becomes a dict mapping every member to ``None``;
     * a ``dict`` is copied with ``ccopy``;
     * anything else is coerced through ``list`` and then treated like
       the ``set`` case.

     Raises
     ------
     IllegalArgumentError
         If the class-level ``KeysetDict`` cannot be coerced (a
         ``TypeError`` is raised while building the dict).
     """
     template = self.__class__.KeysetDict

     if isinstance(template, set):
         self.KeysetDict = dict.fromkeys(template)
         return

     if isinstance(template, dict):
         self.KeysetDict = ccopy(template)
         return

     # Neither a set nor a dict: attempt to coerce to an iterable of keys.
     try:
         self.KeysetDict = dict.fromkeys(list(template))
     except TypeError:
         raise IllegalArgumentError("KeysetDict not iterable...")
Exemplo n.º 2
0
def _merge_group_attrs(group0_attrs, group_attrs, current_time):
    """Fold the attrs of an incoming group into ``group0_attrs`` in place.

    Parameters
    ----------
    group0_attrs : dict
        Attributes accumulated so far. Mutated and returned.
    group_attrs : dict
        Attributes of the group that is being concatenated.
    current_time : str
        Timestamp stored under ``created_at`` when the two groups disagree.

    Returns
    -------
    dict
        ``group0_attrs`` itself.
    """
    for attr_key, attr_values in group_attrs.items():
        group0_attr_values = group0_attrs.get(attr_key, None)
        equality = attr_values == group0_attr_values
        # Array-like comparisons give an element-wise result; reduce it.
        if hasattr(equality, "__iter__"):
            equality = np.all(equality)
        if equality:
            continue

        if attr_key in ("created_at", "previous_created_at"):
            # ``group0_attrs`` is a mapping, so the history must be looked up
            # by key; an attribute lookup never finds it and would reset the
            # history on every merge.
            previous = group0_attrs.get("previous_created_at")
            if previous is None:
                previous = []
                if group0_attr_values is not None:
                    previous.append(group0_attr_values)
            elif not isinstance(previous, list):
                previous = [previous]
            group0_attrs["previous_created_at"] = previous

            if attr_key == "previous_created_at":
                if not isinstance(attr_values, list):
                    attr_values = [attr_values]
                previous.extend(attr_values)
                continue

            # Differing "created_at": stamp the merge time, keep the old one.
            if group0_attr_values != current_time:
                group0_attrs[attr_key] = current_time
            previous.append(attr_values)

        elif attr_key in group0_attrs:
            # Conflicting values are collected under "combined_<key>".
            combined_key = "combined_{}".format(attr_key)
            if combined_key not in group0_attrs:
                group0_attrs[combined_key] = [group0_attr_values]
            group0_attrs[combined_key].append(attr_values)
        else:
            group0_attrs[attr_key] = attr_values
    return group0_attrs


def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` the function
    needs identical groups and variables.

    The variables in the data groups (``observed_data``, ``constant_data``,
    ``predictions_constant_data``) are merged if `dim` is not found in them.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None. If True, the coordinate of the
        concatenated dimension is replaced with ``range(size)``.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`.
        With no arguments (or an empty sequence) an empty InferenceData
        (or `None` when `inplace==True`) is returned.

    Raises
    ------
    TypeError
        If an argument is not an InferenceData, if `dim` is not one of
        the accepted names, if groups overlap while `dim` is None, or if
        groups, variables or dimensions do not match while `dim` is given.

    Examples
    --------
    Use ``concat`` method to concatenate InferenceData objects. This will concatenate over
    unique groups by default. We first create an ``InferenceData`` object:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = {
           ...:     "a": np.random.normal(size=(4, 100, 3)),
           ...:     "b": np.random.normal(size=(4, 100)),
           ...: }
           ...: coords = {"a_dim": ["x", "y", "z"]}
           ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataA

    We have created an ``InferenceData`` object with default group 'posterior'. Now, we will
    create another ``InferenceData`` object:

    .. ipython::

        In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataB

    We have created another ``InferenceData`` object with group 'prior'. Now, we will concatenate
    these two ``InferenceData`` objects:

    .. ipython::

        In [1]: az.concat(dataA, dataB)

    Now, we will concatenate over chain (or draw). It requires identical groups and variables.
    Here we are concatenating two identical ``InferenceData`` objects over dimension chain:

    .. ipython::

        In [1]: az.concat(dataA, dataA, dim="chain")

    It will create an ``InferenceData`` with the original group 'posterior'. In a similar way,
    we can also concatenate over draws.

    """
    # pylint: disable=undefined-loop-variable, too-many-nested-blocks
    # A single sequence argument is treated as the list of objects.
    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # Nothing to concatenate: no arguments at all, or an empty sequence.
    if len(args) == 0:
        if inplace:
            return None
        return InferenceData()

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(
                    i, type(arg)))

    # NOTE(review): "group" passes this validation but is then handled by the
    # dimension branch below, exactly like "chain"/"draw" - confirm intent.
    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(
            dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    if len(args) == 1:
        if inplace:
            return None
        return deepcopy(args[0]) if copy else args[0]

    current_time = str(datetime.now())

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups_all)
        args_groups = {}
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups_all:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                    )
                    msg += " Valid dimensions are `chain` and `draw`."
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(
                    group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        # otherwise it will merge args_groups to arg0
        # inference data object
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(
                    group_data) if copy else group_data

        other_groups = [
            group for group in args_groups if group not in SUPPORTED_GROUPS_ALL
        ]

        for group in SUPPORTED_GROUPS_ALL + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                if group.startswith(WARMUP_TAG):
                    arg0._groups_warmup.append(group)
                else:
                    arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            # Re-sort the group lists of arg0: supported groups first, then
            # the unknown ones in the order they were encountered.
            other_groups = [
                group
                for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
            ] + other_groups
            sorted_groups = [
                group for group in SUPPORTED_GROUPS + other_groups
                if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
            sorted_groups_warmup = [
                group for group in SUPPORTED_GROUPS_WARMUP + other_groups
                if group in arg0._groups_warmup
            ]
            setattr(arg0, "_groups_warmup", sorted_groups_warmup)
    else:
        arg0 = args[0]
        arg0_groups = arg0._groups_all
        # groups that are merged variable by variable instead of concatenated
        data_groups = ("observed_data", "constant_data",
                       "predictions_constant_data")
        for arg in args[1:]:
            # Every group of the first object must exist in the others;
            # only "observed_data" is allowed to be absent.
            for group0 in arg0_groups:
                if group0 not in arg._groups_all:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups_all:
                if group not in data_groups:
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    # When building a new object, keep accumulating on the
                    # result of the previous concatenation.
                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = "Dimension {} missing.".format(dim)
                            raise TypeError(msg)

                    # xr.concat
                    concatenated_group = xr.concat((group_data, group0_data),
                                                   dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(
                            concatenated_group[dim].size)

                    # handle attrs (work on a copy of the first group's attrs)
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()
                    group_attrs = getattr(group_data, "attrs", {})

                    _merge_group_attrs(group0_attrs, group_attrs,
                                       current_time)
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data, "constant_data", "predictions_constant_data",
                    # Fetch this group's data first so that the branch below
                    # never uses a value left over from a previous iteration.
                    group_data = getattr(arg, group)
                    if group not in arg0_groups:
                        # NOTE(review): the group is attached to arg0 even
                        # when inplace is False - confirm this is intended.
                        setattr(arg0, group,
                                deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        var_data = getattr(group_data, var)
                        if var not in group0_vars:
                            # Write to the working dataset so that a copy
                            # (inplace=False) receives the variable and arg0
                            # is left untouched.
                            group0_data[var] = var_data
                        else:
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                # Concatenate the two variables, not the
                                # whole datasets.
                                # NOTE(review): confirm how xarray aligns the
                                # longer variable on assignment.
                                concatenated_var = xr.concat(
                                    (var_data, var0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()
                    group_attrs = getattr(group_data, "attrs", {})

                    _merge_group_attrs(group0_attrs, group_attrs,
                                       current_time)
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    return None if inplace else InferenceData(**inference_data_dict)
Exemplo n.º 3
0
def concat(*args, copy=True, inplace=False):
    """Concatenate InferenceData objects on a group level.

    Supports only concatenating with independent unique groups.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    copy : bool
        If True, groups are copied to the new InferenceData object.
    inplace : bool
        If True, merge args to first object.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`.
        With no arguments (or an empty sequence) an empty InferenceData
        (or `None` when `inplace==True`) is returned.

    Raises
    ------
    TypeError
        If any argument is not an InferenceData object.
    NotImplementedError
        If two arguments share a group name.
    """
    if len(args) == 1 and isinstance(args[0], Sequence):
        # A single sequence argument is treated as the list of objects.
        args = args[0]
    elif len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        return deepcopy(args[0]) if copy else args[0]

    # Nothing to concatenate: no arguments at all, or an empty sequence.
    # In-place calls always return None, as documented.
    if len(args) == 0:
        return None if inplace else InferenceData()

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(
                    i, type(arg)))

    # assert that groups are independent
    first_arg = args[0]
    first_arg_groups = ccopy(first_arg._groups)
    args_groups = {}
    for arg in args[1:]:
        for group in arg._groups:
            if group in args_groups or group in first_arg_groups:
                raise NotImplementedError(
                    "Concatenating with overlapping groups is not supported.")
            group_data = getattr(arg, group)
            args_groups[group] = deepcopy(group_data) if copy else group_data

    # add first_arg to args_groups if inplace is False
    if not inplace:
        for group in first_arg_groups:
            group_data = getattr(first_arg, group)
            args_groups[group] = deepcopy(group_data) if copy else group_data

    # Well-known groups come first, in this order; everything else follows
    # in the order it was encountered.
    basic_order = [
        "posterior",
        "posterior_predictive",
        "sample_stats",
        "prior",
        "prior_predictive",
        "sample_stats_prior",
        "observed_data",
    ]
    other_groups = [group for group in args_groups if group not in basic_order]

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()
    for group in basic_order + other_groups:
        if group not in args_groups:
            continue
        if inplace:
            first_arg._groups.append(group)
            setattr(first_arg, group, args_groups[group])
        else:
            inference_data_dict[group] = args_groups[group]

    if inplace:
        # Re-sort the group list of the first object after the merge.
        other_groups = [
            group for group in first_arg_groups if group not in basic_order
        ] + other_groups
        sorted_groups = [
            group for group in basic_order + other_groups
            if group in first_arg._groups
        ]
        setattr(first_arg, "_groups", sorted_groups)
        return None
    return InferenceData(**inference_data_dict)
Exemplo n.º 4
0
def _merge_group_attrs(group0_attrs, group_attrs, current_time):
    """Fold the attrs of an incoming group into ``group0_attrs`` in place.

    Parameters
    ----------
    group0_attrs : dict
        Attributes accumulated so far. Mutated and returned.
    group_attrs : dict
        Attributes of the group that is being concatenated.
    current_time : str
        Timestamp stored under ``created_at`` when the two groups disagree.

    Returns
    -------
    dict
        ``group0_attrs`` itself.
    """
    for attr_key, attr_values in group_attrs.items():
        group0_attr_values = group0_attrs.get(attr_key, None)
        equality = attr_values == group0_attr_values
        # Array-like comparisons give an element-wise result; reduce it.
        if hasattr(equality, "__iter__"):
            equality = np.all(equality)
        if equality:
            continue

        if attr_key in ("created_at", "previous_created_at"):
            # ``group0_attrs`` is a mapping, so the history must be looked up
            # by key; an attribute lookup never finds it and would reset the
            # history on every merge.
            previous = group0_attrs.get("previous_created_at")
            if previous is None:
                previous = []
                if group0_attr_values is not None:
                    previous.append(group0_attr_values)
            elif not isinstance(previous, list):
                previous = [previous]
            group0_attrs["previous_created_at"] = previous

            if attr_key == "previous_created_at":
                if not isinstance(attr_values, list):
                    attr_values = [attr_values]
                previous.extend(attr_values)
                continue

            # Differing "created_at": stamp the merge time, keep the old one.
            if group0_attr_values != current_time:
                group0_attrs[attr_key] = current_time
            previous.append(attr_values)

        elif attr_key in group0_attrs:
            # Conflicting values are collected under "combined_<key>".
            combined_key = "combined_{}".format(attr_key)
            if combined_key not in group0_attrs:
                group0_attrs[combined_key] = [group0_attr_values]
            group0_attrs[combined_key].append(attr_values)
        else:
            group0_attrs[attr_key] = attr_values
    return group0_attrs


def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` the function
    needs identical groups and variables.

    The variables in the ``observed_data`` group are merged if `dim` is not
    found in them.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None. If True, the coordinate of the
        concatenated dimension is replaced with ``range(size)``.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`.
        With no arguments (or an empty sequence) an empty InferenceData
        (or `None` when `inplace==True`) is returned.

    Raises
    ------
    TypeError
        If an argument is not an InferenceData, if `dim` is not one of
        the accepted names, if groups overlap while `dim` is None, or if
        groups, variables or dimensions do not match while `dim` is given.
    """
    # pylint: disable=undefined-loop-variable, too-many-nested-blocks
    # A single sequence argument is treated as the list of objects.
    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # Nothing to concatenate: no arguments at all, or an empty sequence.
    if len(args) == 0:
        if inplace:
            return None
        return InferenceData()

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(
                    i, type(arg)))

    # NOTE(review): "group" passes this validation but is then handled by the
    # dimension branch below, exactly like "chain"/"draw" - confirm intent.
    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(
            dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    if len(args) == 1:
        if inplace:
            return None
        return deepcopy(args[0]) if copy else args[0]

    current_time = str(datetime.now())

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups)
        args_groups = {}
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                    )
                    msg += " Valid dimensions are `chain` and `draw`."
                    raise TypeError(msg)
                # Collect every group of the argument; this must happen
                # inside the loop, otherwise only the last group is kept.
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(
                    group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(
                    group_data) if copy else group_data

        # Well-known groups come first, in this order; everything else
        # follows in the order it was encountered.
        basic_order = [
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "observed_data",
        ]
        other_groups = [
            group for group in args_groups if group not in basic_order
        ]

        for group in basic_order + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            # Re-sort the group list of arg0 after the merge.
            other_groups = [
                group for group in arg0_groups if group not in basic_order
            ] + other_groups
            sorted_groups = [
                group for group in basic_order + other_groups
                if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
    else:
        arg0 = args[0]
        arg0_groups = arg0._groups
        for arg in args[1:]:
            # Every group of the first object must exist in the others;
            # only "observed_data" is allowed to be absent.
            for group0 in arg0_groups:
                if group0 not in arg._groups:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups:
                if group != "observed_data":
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    # When building a new object, keep accumulating on the
                    # result of the previous concatenation.
                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = "Dimension {} missing.".format(dim)
                            raise TypeError(msg)

                    # xr.concat
                    concatenated_group = xr.concat((group_data, group0_data),
                                                   dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(
                            concatenated_group[dim].size)

                    # handle attrs (work on a copy of the first group's attrs)
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()
                    group_attrs = getattr(group_data, "attrs", {})

                    _merge_group_attrs(group0_attrs, group_attrs,
                                       current_time)
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data
                    # Fetch this group's data first so that the branch below
                    # never uses a value left over from a previous iteration.
                    group_data = getattr(arg, group)
                    if group not in arg0_groups:
                        # NOTE(review): the group is attached to arg0 even
                        # when inplace is False - confirm this is intended.
                        setattr(arg0, group,
                                deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        var_data = getattr(group_data, var)
                        if var not in group0_vars:
                            # Write to the working dataset so that a copy
                            # (inplace=False) receives the variable and arg0
                            # is left untouched.
                            group0_data[var] = var_data
                        else:
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                # Concatenate the two variables, not the
                                # whole datasets.
                                # NOTE(review): confirm how xarray aligns the
                                # longer variable on assignment.
                                concatenated_var = xr.concat(
                                    (var_data, var0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()
                    group_attrs = getattr(group_data, "attrs", {})

                    _merge_group_attrs(group0_attrs, group_attrs,
                                       current_time)
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    return None if inplace else InferenceData(**inference_data_dict)
Exemplo n.º 5
0
 def copy_header(self):
     """Return a copy of this object (via ``ccopy``) with empty trade data.

     The copy gets a fresh empty ``matchedtrades`` list and a fresh empty
     ``datemap`` dict; the original object's attributes are not rebound.
     """
     header = ccopy(self)
     # clear the trades, then the datemap
     header.matchedtrades, header.datemap = [], {}
     return header
Exemplo n.º 6
0
 def copy(self):
     """Return the result of ``ccopy(self)``."""
     duplicate = ccopy(self)
     return duplicate
Exemplo n.º 7
0
 def flush(self):
     """Empty the internal stack and hand back its former content.

     A copy of ``self._stack`` is taken first, the internal container is
     then emptied in place, and the copy is returned wrapped in a new
     ``Stack``.
     """
     snapshot = ccopy(self._stack)
     # In-place slice deletion keeps the same container object alive.
     del self._stack[:]
     return Stack(snapshot)
Exemplo n.º 8
0
 def lookup(self):
     """Return a new ``Stack`` built from a copy of the internal stack.

     The internal stack itself is left untouched, so this can be used to
     inspect the current content.
     """
     snapshot = ccopy(self._stack)
     return Stack(snapshot)
Exemplo n.º 9
0
 def flush(self):
     """Empty the internal stack and hand back its former content.

     A copy of ``self._stack`` is taken first, the internal container is
     then emptied in place, and the copy is returned wrapped in a new
     ``Stack``.
     """
     snapshot = ccopy(self._stack)
     # In-place slice deletion keeps the same container object alive.
     del self._stack[:]
     return Stack(snapshot)
Exemplo n.º 10
0
 def lookup(self):
     """Return a new ``Stack`` built from a copy of the internal stack.

     The internal stack itself is left untouched, so this can be used to
     inspect the current content.
     """
     snapshot = ccopy(self._stack)
     return Stack(snapshot)
Exemplo n.º 11
0
 def copy_header(self):
     """Return a copy of this object (via ``ccopy``) with empty trade data.

     The copy gets a fresh empty ``matchedtrades`` list and a fresh empty
     ``datemap`` dict; the original object's attributes are not rebound.
     """
     header = ccopy(self)
     # clear the trades, then the datemap
     header.matchedtrades, header.datemap = [], {}
     return header
Exemplo n.º 12
0
 def copy(self):
     """Return the result of ``ccopy(self)``."""
     duplicate = ccopy(self)
     return duplicate
Exemplo n.º 13
0
def device_selected(device):
    """Handle a new selection in the devices chooser.

    Closes the current device, resets the coin/network/wallet-option
    widgets, loads the coins list matching the selected device, then:

    * selection 1 starts the seed watcher;
    * selection > 1 instantiates the device class, opens its account
      (initializing the device and/or asking the user for the
      PIN/password when needed) and stores it in ``app.device``.

    Any user cancellation or device error resets the devices chooser to
    entry 0 (where the code does so explicitly) and returns early.

    Parameters
    ----------
    device
        Object exposing ``GetInt()``, which yields the selected index
        (presumably a wx command event -- TODO confirm against the binding).
    """
    close_device()
    app.gui_panel.btn_chkaddr.Hide()
    app.gui_panel.coins_choice.Disable()
    app.gui_panel.coins_choice.SetSelection(0)
    app.gui_panel.network_choice.Clear()
    app.gui_panel.network_choice.Disable()
    app.gui_panel.wallopt_choice.Clear()
    erase_info(True)
    sel_device = device.GetInt()
    # NOTE: for sel_device == 0 this indexes DEVICES_LIST[-1] (last entry);
    # the resulting name is only used to filter the coins list in that case.
    device_sel_name = DEVICES_LIST[sel_device - 1]
    # Work on a copy so the removals below never touch SUPPORTED_COINS.
    coins_list = ccopy(SUPPORTED_COINS)
    if device_sel_name == "Ledger":
        coins_list = [
            "ETH",
            "BSC",
            "MATIC",
            "FTM",
            "OP",
            "METIS",
            "CELO",
            "GLMR",
            "ARB",
            "AVAX",
        ]
    if device_sel_name in ("OpenPGP", "Cryptnox"):
        coins_list.remove("SOL")
    app.load_coins_list(coins_list)
    if sel_device == 1:
        # Seed Watcher
        start_seedwatcher(app, cb_open_wallet)
    if sel_device > 1:
        # Real keys device
        the_device = get_device_class(device_sel_name)
        try:
            device_loaded = the_device()
        except Exception as exc:
            app.gui_panel.devices_choice.SetSelection(0)
            logger.error("Error during device loading : %s",
                         str(exc),
                         exc_info=exc,
                         stack_info=True)
            warn_modal(str(exc))
            return
        # -1 means the remaining tries counter was not read yet.
        pin_left = -1
        password_default = device_loaded.default_password
        # Loop until the account is opened (break) or the flow is aborted
        # (return). The except handlers prepare the next attempt.
        while True:
            try:
                pwd_pin = the_device.password_name
                if the_device.has_password:
                    if (not the_device.password_retries_inf and pin_left == -1
                            and device_loaded.is_init()):
                        # Goto password exception to ask for user PIN
                        pin_left = device_loaded.get_pw_left()
                        if the_device.is_HD:
                            HDwallet_settings = app.hd_setup("")
                            if HDwallet_settings is None:
                                app.gui_panel.devices_choice.SetSelection(0)
                                return
                        raise pwdException
                    # Can raise notinit
                    device_loaded.open_account(password_default)
                else:
                    device_loaded.open_account()
                break
            except NotinitException:
                # Device not initialized yet: collect the settings and
                # credentials from the user, then initialize it.
                if the_device.is_HD:
                    mnemonic = ""
                    if not the_device.internally_gen_keys:
                        # Mnemonic generated by the software wallet
                        mnemonic = device_loaded.generate_mnemonic(
                        )  # mnemonic proposal
                    # Get settings from the user
                    HDwallet_settings = app.hd_setup(mnemonic)
                    if HDwallet_settings is None:
                        app.gui_panel.devices_choice.SetSelection(0)
                        return
                if the_device.has_admin_password:
                    set_admin_message = (
                        f"Choose the {the_device.admin_pass_name} to init "
                        f"the {device_sel_name} device.\n")
                    set_admin_message += "\nFor demo, quick insecure setup, it can be left blank,\n"
                    set_admin_message += (
                        f"and a default {the_device.admin_pass_name} will be used.\n\n"
                    )
                    lenmsg = (
                        str(device_loaded.admin_pwd_minlen)
                        if device_loaded.admin_pwd_minlen
                        == device_loaded.admin_pwd_maxlen else
                        f"between {device_loaded.admin_pwd_minlen} and {device_loaded.admin_pwd_maxlen}"
                    )
                    set_admin_message += (
                        f"The chosen {the_device.admin_pass_name} must be\n{lenmsg} chars long."
                    )
                    # Ask again until the length is valid or the user cancels.
                    while True:
                        admin_password = get_password(device_sel_name,
                                                      set_admin_message)
                        if admin_password is None:
                            app.gui_panel.devices_choice.SetSelection(0)
                            return
                        if admin_password == "":
                            admin_password = device_loaded.default_admin_password
                        if (device_loaded.admin_pwd_minlen <=
                                len(admin_password) <=
                                device_loaded.admin_pwd_maxlen):
                            break
                        warn_modal(
                            f"{the_device.admin_pass_name} shall be {lenmsg} chars long.",
                            True,
                        )
                if the_device.has_password:
                    inp_message = f"Choose your {pwd_pin} to setup the {device_sel_name} wallet.\n"
                    inp_message += "\nFor demo, quick insecure setup, it can be left blank,\n"
                    inp_message += f"and a default {pwd_pin} will be used.\n\n"
                    lenmsg = (
                        str(device_loaded.password_min_len)
                        if device_loaded.password_min_len
                        == device_loaded.password_max_len else
                        f"between {device_loaded.password_min_len} and {device_loaded.password_max_len}"
                    )
                    pintype = "chars"
                    if device_loaded.is_pin_numeric:
                        pintype = "digits"
                    inp_message += f"The chosen {pwd_pin} must be\n{lenmsg} {pintype} long."
                    # Ask again until length/charset are valid or the user cancels.
                    while True:
                        password = get_password(device_sel_name, inp_message)
                        if password is None:
                            app.gui_panel.devices_choice.SetSelection(0)
                            return
                        if password == "":
                            password = device_loaded.default_password
                        if (device_loaded.password_min_len <= len(password) <=
                                device_loaded.password_max_len
                                and (not device_loaded.is_pin_numeric
                                     or password.isdigit())):
                            break
                        wmsg = f"Device {pwd_pin} shall be {lenmsg} {pintype} long."
                        if device_loaded.is_pin_numeric:
                            wmsg += f"\n\nThe {pwd_pin} must be {pintype} (0-9) only."
                        warn_modal(
                            wmsg,
                            True,
                        )
                try:
                    if the_device.has_admin_password:
                        device_loaded.set_admin(admin_password)
                    if the_device.is_HD:
                        if the_device.has_password:
                            HDwallet_settings["file_password"] = password
                        device_loaded.initialize_device(HDwallet_settings)
                    elif the_device.has_password:
                        device_loaded.initialize_device(password)
                    else:
                        device_loaded.initialize_device()
                    break
                except Exception as exc:
                    app.gui_panel.devices_choice.SetSelection(0)
                    # Pass the message itself as the %s argument (a set
                    # literal here would be logged as "{'message'}").
                    logger.error(
                        "Error during device initialization : %s",
                        str(exc),
                        exc_info=exc,
                        stack_info=True,
                    )
                    warn_modal(str(exc))
                    return
            except pwdException as excp:
                # A password/PIN is required (or the previous one was
                # rejected): check the tries left, then prompt the user.
                if not device_loaded.password_retries_inf:
                    try:
                        pin_left = device_loaded.get_pw_left()
                    except Exception as exc:
                        device_error(exc)
                        return
                    if pin_left == 0:
                        warn_modal(f"Device {pwd_pin} is locked.")
                        return
                    if (the_device.password_softlock > 0
                            and pin_left == the_device.password_softlock
                            and str(excp) == "0"):
                        warn_modal(
                            f"Device {pwd_pin} is soft locked. Restart it to try again."
                        )
                        return
                while True:
                    inp_message = f"Input your {device_sel_name} wallet {pwd_pin}.\n"
                    if not the_device.password_retries_inf:
                        inp_message += (
                            f"{pin_left} {pwd_pin} tr{'ies' if pin_left >=2 else 'y'}"
                            " left on this device.\n")
                    lenmsg = (
                        str(device_loaded.password_min_len)
                        if device_loaded.password_min_len
                        == device_loaded.password_max_len else
                        f"between {device_loaded.password_min_len} and {device_loaded.password_max_len}"
                    )
                    pintype = "chars"
                    if device_loaded.is_pin_numeric:
                        pintype = "digits"
                    inp_message += f"\nThe {pwd_pin} to provide\nis {lenmsg} {pintype} long."
                    # The entered value is used by open_account() on the
                    # next iteration of the outer loop.
                    password_default = get_password(device_sel_name,
                                                    inp_message)
                    if password_default is None:
                        app.gui_panel.devices_choice.SetSelection(0)
                        return
                    if (device_loaded.password_min_len <=
                            len(password_default) <=
                            device_loaded.password_max_len
                            and (not device_loaded.is_pin_numeric
                                 or password_default.isdigit())):
                        break
                    wmsg = f"Device {pwd_pin} shall be {lenmsg} {pintype} long."
                    if device_loaded.is_pin_numeric:
                        wmsg += f"\n\nThe {pwd_pin} must be {pintype} (0-9) only."
                    warn_modal(
                        wmsg,
                        True,
                    )
            except Exception as exc:
                return device_error(exc)
        wx.MilliSleep(100)
        if the_device.has_password and the_device.is_HD and not the_device.password_retries_inf:
            # Kind of special for Cryptnox for now
            device_loaded.set_path(HDwallet_settings)
        app.device = device_loaded
        if app.device.created:
            info_modal(
                "Device created",
                f"A new {device_sel_name} device was successfully created.",
            )
        app.gui_panel.coins_choice.Enable()
        app.gui_panel.coins_choice.Hide()
        app.gui_panel.coins_choice.ShowWithEffect(wx.SHOW_EFFECT_ROLL_TO_RIGHT,
                                                  750)
        app.gui_panel.coins_choice.SetFocus()
    else:
        # Pairs with ``if sel_device > 1``: also runs for selection 1
        # (after the seed watcher was started) and for selection 0.
        erase_info(True)