def retrieve_params(mlp_full, mlp_sub, dropout, exo_dropout=0.):
    """Copy a dropout-selected slice of ``mlp_full``'s parameters into ``mlp_sub``.

    Weights and biases of both networks are collected, sorted by layer number
    and paired layer by layer (their names must match). For every layer the
    indices computed by ``dropout_indices`` select the part of the full
    parameter to keep; that slice is multiplied by ``1 / (1 - dropout)`` and
    written into the matching ``mlp_sub`` parameter with ``set_value``.

    Args:
        mlp_full: the full network. Its parameters are read through
            ``get_Params()`` (or ``get_Params(option="")`` when ``mlp_sub`` is
            not a ``CompositeSequence``).
        mlp_sub: the network whose parameters are overwritten. If it is a
            ``CompositeSequence`` its parameters come from ``get_Params()``;
            otherwise they are found in the computation graph of
            ``mlp_sub.apply`` and renamed ``layer_<i>_W`` / ``layer_<i>_B``.
        dropout: rate in ``[0, 1)`` forwarded to ``dropout_indices`` and used
            for the rescaling factor.
        exo_dropout: rate in ``[0, 1)``. It is only validated here and is
            otherwise unused by this function.

    Returns:
        None. The parameters of ``mlp_sub`` are modified in place (and, in the
        non-``CompositeSequence`` case, renamed).

    Raises:
        AssertionError: if ``dropout`` or ``exo_dropout`` is outside ``[0, 1)``.
        Exception: if paired parameter names differ, or a layer kind returned
            by ``dropout_indices`` is not handled.
    """
    assert exo_dropout >= 0 and exo_dropout < 1
    assert dropout >= 0 and dropout < 1

    if isinstance(mlp_sub, CompositeSequence):
        L_W_full, L_B_full = mlp_full.get_Params()
        L_W_sub, L_B_sub = mlp_sub.get_Params()
    else:
        L_W_full, L_B_full = mlp_full.get_Params(option="")
        # Build a throw-away symbolic graph only to enumerate the weight and
        # bias variables that mlp_sub uses.
        x = T.matrix()
        cost = mlp_sub.apply(x).sum()
        cg = ComputationGraph(cost)
        W = VariableFilter(roles=[WEIGHT])(cg.variables)
        B = VariableFilter(roles=[BIAS])(cg.variables)
        # Heuristic: if the first two weight matrices do not chain
        # (columns of W[0] != rows of W[1]) the graph listed them in reverse
        # order. The check needs at least two weights; with a single weight
        # there is nothing to reorder.
        if len(W) > 1 and W[0].shape.eval()[1] != W[1].shape.eval()[0]:
            W.reverse()
            B.reverse()
        # Rename so the sub-network names can be matched against the full one.
        for index, (w, b) in enumerate(zip(W, B)):
            w.name = "layer_" + str(index) + "_W"
            b.name = "layer_" + str(index) + "_B"
        L_W_sub = W
        L_B_sub = B

    # Sort every list (in place) by layer number so zip() below pairs the
    # parameters of the same layer.
    def _layer_key(param):
        return analyze_param_name(param.name)[1]

    for params in (L_W_full, L_W_sub, L_B_full, L_B_sub):
        params.sort(key=_layer_key)

    # Per-layer selections and layer kinds, keyed by layer name + number.
    splits_W, splits_B, rough_kinds = dropout_indices(L_W_full, L_B_full,
                                                      dropout)
    # When applying dropout, the kept parameters must be rescaled.
    proba = 1. / (1 - dropout)

    for w_full, w_sub, b_full, b_sub in zip(L_W_full, L_W_sub,
                                            L_B_full, L_B_sub):
        if w_sub.name != w_full.name:
            raise Exception("names %s and %s do not match"
                            % (w_sub.name, w_full.name))
        if b_sub.name != b_full.name:
            raise Exception("names %s and %s do not match"
                            % (b_sub.name, b_full.name))

        layer_name, layer_number, _ = analyze_param_name(w_full.name)
        layer_key = layer_name + str(layer_number)
        w_split = splits_W[layer_key]
        b_split = splits_B[layer_key]
        rough_kind = rough_kinds[layer_key]

        # w_split is indexed as a pair below; NOTE(review): presumably
        # (kept input indices, kept output indices) -- confirm against
        # dropout_indices.
        full_value = w_full.get_value()
        if rough_kind == "FULLY_CONNECTED":
            kept_w = full_value[w_split[0]][:, w_split[1]]
        elif rough_kind == "CONV_FILTER":
            # Axis order is swapped relative to the fully-connected case.
            kept_w = full_value[w_split[1]][:, w_split[0]]
        elif rough_kind == "BATCH_NORM":
            kept_w = full_value[w_split[0]]
        else:
            raise Exception("unknown type :%s" % (rough_kind,))

        w_sub.set_value(kept_w * proba)
        b_sub.set_value(b_full.get_value()[b_split] * proba)