Example #1
0
def update(filename, data, validate=True, transform=None):
    """Update the file named filename with data.

    :param filename: the XML filename we should update
    :param data: the result of the submitted data.
    :param validate: validate the updated XML before writing it.
    :type filename: str
    :type data: dict style like: dict, webob.MultiDict, ...
    :type validate: bool
    :param transform: function to transform the XML string just before
        writing it.
    :type transform: function
    :return: the object generated from the data
    :rtype: :class:`Element`
    :raises Exception: if, after removing ``_xml_encoding`` and
        ``_xml_dtd_url``, the data does not hold exactly one root entry.

    NOTE(review): if ``utils.unflatten_params`` returns a plain dict, a
    missing ``_xml_encoding`` / ``_xml_dtd_url`` entry raises ``KeyError``
    from ``pop`` — confirm the return type.
    """
    data = utils.unflatten_params(data)
    # Two metadata entries travel with the submitted content; strip them so
    # only the root element's data remains.
    encoding = data.pop('_xml_encoding')
    dtd_url = data.pop('_xml_dtd_url')

    if len(data) != 1:
        raise Exception('Bad data')

    # Take the single remaining key without subscripting keys(): on
    # Python 3 keys() is a view object and ``keys()[0]`` raises TypeError.
    root_tag = next(iter(data))
    dic = dtd_parser.parse(dtd_url=dtd_url, path=os.path.dirname(filename))
    # dic is indexed by tag name; the value is called to build the root.
    obj = dic[root_tag]()

    obj.load_from_dict(data)
    obj.write(filename,
              encoding,
              dtd_url=dtd_url,
              validate=validate,
              transform=transform)
    return obj
Example #2
0
def permutify(params1, params2):
    """Permute the parameters of params2 to match params1 as closely as possible.

    Returns the permuted version of params2. Only works on sequences of Dense
    layers for now.

    For every ``Dense_<i>`` layer except the one with the highest index, the
    kernel columns of params2 (and the matching bias entries) are reordered
    so that they best match params1's kernel columns under cosine similarity,
    using a maximizing linear sum assignment. The same permutation is applied
    to the rows of the following layer's kernel.

    :param params1: reference parameter tree.
    :param params2: parameter tree whose units are permuted.
    :return: a new parameter tree built from the permuted params2 values.
    :raises ValueError: if params1 contains no ``Dense_<i>`` parameters.
    """
    p1f = flatten_params(params1)
    p2f = flatten_params(params2)

    # Private shallow copy: only this dict is updated below.
    p2f_new = dict(p2f)

    # kmatch returns a falsy value for keys that do not match the pattern,
    # so filter those out instead of calling .group() on them.
    dense_indices = [
        int(match.group(2))
        for match in (kmatch("**/Dense_*/**", k) for k in p1f)
        if match
    ]
    if not dense_indices:
        raise ValueError("permutify: params1 contains no Dense_* parameters")
    num_layers = max(dense_indices)

    # range is [0, num_layers), so we're safe here since we don't want to be
    # reordering the output of the last layer.
    for layer in range(num_layers):
        kernel_key = f"params/Dense_{layer}/kernel"
        bias_key = f"params/Dense_{layer}/bias"
        next_kernel_key = f"params/Dense_{layer + 1}/kernel"

        # Maximize since we're dealing with similarities, not distances.
        ri, ci = linear_sum_assignment(
            cosine_similarity(p1f[kernel_key].T, p2f_new[kernel_key].T),
            maximize=True)
        # Sanity check: row indices must be the identity so that ci alone
        # describes the permutation.
        assert (ri == jnp.arange(len(ri))).all()

        # Permute this layer's output columns and bias, and the next layer's
        # input rows, with the same permutation. The three keys are distinct,
        # so each assignment reads a value not yet touched in this iteration.
        p2f_new[kernel_key] = p2f_new[kernel_key][:, ci]
        p2f_new[bias_key] = p2f_new[bias_key][ci]
        p2f_new[next_kernel_key] = p2f_new[next_kernel_key][ci, :]

    return unflatten_params(p2f_new)
Example #3
0
def update(filename, data, validate=True, transform=None):
    """Update the file named filename with data.

    :param filename: the XML filename we should update
    :param data: the result of the submitted data.
    :param validate: validate the updated XML before writing it.
    :type filename: str
    :type data: dict style like: dict, webob.MultiDict, ...
    :type validate: bool
    :param transform: function to transform the XML string just before
        writing it.
    :type transform: function
    :return: the object generated from the data
    :rtype: :class:`Element`
    :raises Exception: if, after removing ``_xml_encoding`` and
        ``_xml_dtd_url``, the data does not hold exactly one root entry.

    NOTE(review): if ``utils.unflatten_params`` returns a plain dict, a
    missing ``_xml_encoding`` / ``_xml_dtd_url`` entry raises ``KeyError``
    from ``pop`` — confirm the return type.
    """
    data = utils.unflatten_params(data)
    # Two metadata entries travel with the submitted content; strip them so
    # only the root element's data remains.
    encoding = data.pop('_xml_encoding')
    dtd_url = data.pop('_xml_dtd_url')

    if len(data) != 1:
        raise Exception('Bad data')

    # Take the single remaining key without subscripting keys(): on
    # Python 3 keys() is a view object and ``keys()[0]`` raises TypeError.
    root_tag = next(iter(data))

    dic = dtd.DTD(dtd_url, path=os.path.dirname(filename)).parse()
    # dic is indexed by tag name; the value is called to build the root.
    obj = dic[root_tag]()

    obj.load_from_dict(data)
    obj.write(filename, encoding, dtd_url=dtd_url, validate=validate,
              transform=transform)
    return obj
Example #4
0
def getElementData(elt_id, data):
    """Get the dic from data to load last element of elt_id.

    ``elt_id`` is split on ``:`` and each segment is used to descend into
    the unflattened ``data``: converted to ``int`` when the current node is
    a list, used as a key otherwise.

    :param elt_id: ``:``-separated path to the element.
    :param data: submitted data, passed through ``utils.unflatten_params``.
    :return: ``{tagname: sub_data}``, where ``tagname`` is the last segment
        of ``elt_id`` and ``sub_data`` is the value found at that path, or
        ``{}`` when the path cannot be followed.
    """
    data = utils.unflatten_params(data)
    lis = elt_id.split(':')
    tagname = lis[-1]
    for v in lis:
        try:
            if isinstance(data, list):
                v = int(v)
            data = data[v]
        except (KeyError, IndexError, ValueError, TypeError):
            # The path cannot be followed. Causes:
            # - missing key or out-of-range index;
            # - non-numeric segment where a list index is needed (ValueError
            #   from int());
            # - descending into a value that cannot be indexed by the
            #   segment (TypeError).
            data = {}
            break
    return {tagname: data}
Example #5
0
def getElementData(elt_id, data):
    """Get the dic from data to load last element of elt_id.

    ``elt_id`` is split on ``:`` and each segment is used to descend into
    the unflattened ``data``: converted to ``int`` when the current node is
    a list, used as a key otherwise.

    :param elt_id: ``:``-separated path to the element.
    :param data: submitted data, passed through ``utils.unflatten_params``.
    :return: ``{tagname: sub_data}``, where ``tagname`` is the last segment
        of ``elt_id`` and ``sub_data`` is the value found at that path, or
        ``{}`` when the path cannot be followed.
    """
    data = utils.unflatten_params(data)
    lis = elt_id.split(':')
    tagname = lis[-1]
    for v in lis:
        try:
            if isinstance(data, list):
                v = int(v)
            data = data[v]
        except (KeyError, IndexError, ValueError, TypeError):
            # The path cannot be followed. Causes:
            # - missing key or out-of-range index;
            # - non-numeric segment where a list index is needed (ValueError
            #   from int());
            # - descending into a value that cannot be indexed by the
            #   segment (TypeError).
            data = {}
            break
    return {tagname: data}
Example #6
0
        return v
    elif match := kmatch("**/last/kernel", k):
        # Only drop inputs to the last layer.
        prev = max(
            int(kmatch("**/OGDense_*/**", k).group(2))
            for k in only_gains_final_params_flat.keys()
            if kmatch("**/OGDense_*/**", k))
        return v[gain_mask[match.group(1) + f"/OGDense_{prev}/gain"], :]

    else:
        raise ValueError(f"Unknown key: {k}")


# TODO: Shouldn't this be only_gains_init_params_flat instead of the final ones?
# Build the lottery-ticket initial parameters:
# - pass every flattened (key, value) pair of only_gains_final_params_flat
#   through _lotteryify;
# - restore the nested structure with unflatten_params.
lottery_init_params = unflatten_params(
    {k: _lotteryify(k, v)
     for k, v in only_gains_final_params_flat.items()})
print("  lottery ticket param shapes:")
# Flattened key -> array shape, for logging and for sizing the new network.
shapes = tree_map(jnp.shape, flatten_params(lottery_init_params))
print(shapes)

print("  structured LTH, train all params...")
# See https://github.com/google/flax/discussions/1555.
# Rebuild the network from the lottery parameter shapes. The list passed to
# make_net is the first dimension of params/first/gain, followed by that of
# each params/OGDense_{i}/gain for i in range(len(gain_params) - 1).
# NOTE(review): presumably each gain vector's length is that layer's width
# after _lotteryify — confirm against make_net.
net = make_net([shapes["params/first/gain"][0]] + [
    shapes[f"params/OGDense_{i}/gain"][0] for i in range(len(gain_params) - 1)
])
# Train from the lottery initialization with every parameter trainable
# (the predicate returns True for any key).
train(net,
      init_params=lottery_init_params,
      trainable_predicate=lambda k: True,
      log_prefix="only_gain_lottery")