def get_C8_Decoder(name, dim_cov_est, context_rep_ids):
    """Build an ``AC.SteerDecoder`` from a named preset with ``N=8``, ``flip=False``.

    Every preset uses purely regular fiber representations (rep id ``-1``)
    and a ``'ReLU'`` non-linearity; presets differ only in depth, width and
    kernel sizes.

    Args:
        name: Preset name, one of ``"regular_little"``, ``"regular_small"``,
            ``"regular_middle"``, ``"regular_big"`` or ``"regular_huge"``.
        dim_cov_est: Forwarded unchanged to ``AC.SteerDecoder``.
        context_rep_ids: Forwarded unchanged to ``AC.SteerDecoder``.

    Returns:
        The constructed ``AC.SteerDecoder``.

    Raises:
        ValueError: If ``name`` is not a known preset.
    """
    N = 8
    flip = False

    # Family of decoders using purely regular fiber representations.
    # name -> (number of hidden layers, fields per hidden layer, kernel sizes).
    # Every preset lists one more kernel size than it has hidden layers.
    # The table is rebuilt on each call, so no list is shared between calls.
    presets = {
        "regular_little": (2, 3, list(range(7, 20, 6))),
        "regular_small": (4, 4, list(range(3, 20, 4))),
        "regular_middle": (6, 12, [3, 3, 5, 7, 7, 11, 11]),
        "regular_big": (6, 24, [5, 5, 5, 7, 7, 11, 11]),
        "regular_huge": (8, 24, [5, 5, 5, 5, 7, 7, 9, 15, 21]),
    }
    if name not in presets:
        # Raise a catchable error rather than terminating the whole process.
        raise ValueError("Unknown architecture name: {!r}.".format(name))
    n_layers, n_fields, kernel_sizes = presets[name]

    # One independent inner list per layer: ``n * [m * [-1]]`` would make
    # every layer alias the very same list object.
    hidden_reps_ids = [[-1] * n_fields for _ in range(n_layers)]
    non_linearity = ['ReLU']

    return AC.SteerDecoder(hidden_reps_ids=hidden_reps_ids,
                           kernel_sizes=kernel_sizes,
                           dim_cov_est=dim_cov_est,
                           context_rep_ids=context_rep_ids,
                           N=N,
                           flip=flip,
                           non_linearity=non_linearity)
def get_Flip_Decoder(name, dim_cov_est, context_rep_ids):
    """Build an ``AC.SteerDecoder`` from a named preset with ``N=2``, ``flip=True``.

    Two preset families are available:

    * ``"regular_*"``: purely regular fiber representations (rep id ``-1``)
      with a ``'ReLU'`` non-linearity.
    * ``"irrep_*"``: every field uses the representation id ``[1, 1]`` with
      a ``'NormReLU'`` non-linearity.

    Args:
        name: Preset name; ``"regular_"`` or ``"irrep_"`` followed by one of
            ``little``, ``small``, ``middle``, ``big`` or ``huge``.
        dim_cov_est: Forwarded unchanged to ``AC.SteerDecoder``.
        context_rep_ids: Forwarded unchanged to ``AC.SteerDecoder``.

    Returns:
        The constructed ``AC.SteerDecoder``.

    Raises:
        ValueError: If ``name`` is not a known preset.
    """
    N = 2
    flip = True

    # name -> (family, number of hidden layers, fields per hidden layer,
    #          kernel sizes). Every preset lists one more kernel size than it
    # has hidden layers. The table is rebuilt on each call, so no list is
    # shared between calls.
    presets = {
        # Family of decoders using purely regular fiber representations:
        "regular_little": ("regular", 2, 3, list(range(7, 20, 6))),
        "regular_small": ("regular", 4, 4, list(range(3, 20, 4))),
        "regular_middle": ("regular", 6, 12, [3, 3, 5, 7, 7, 11, 11]),
        "regular_big": ("regular", 6, 24, [5, 5, 5, 7, 7, 11, 11]),
        "regular_huge": ("regular", 8, 24, [5, 5, 5, 5, 7, 7, 9, 15, 21]),
        # Family of decoders using irreps:
        "irrep_little": ("irrep", 2, 4, list(range(7, 20, 6))),
        "irrep_small": ("irrep", 5, 6, list(range(3, 24, 4))),
        "irrep_middle": ("irrep", 7, 18, [3, 3, 5, 5, 11, 11, 13, 13]),
        "irrep_big": ("irrep", 8, 32, [5, 5, 7, 7, 11, 13, 15, 17, 19]),
        "irrep_huge": ("irrep", 10, 40,
                       [5, 5, 7, 7, 11, 11, 11, 13, 17, 19, 21]),
    }
    if name not in presets:
        # Raise a catchable error rather than terminating the whole process.
        raise ValueError("Unknown architecture name: {!r}.".format(name))
    family, n_layers, n_fields, kernel_sizes = presets[name]

    # Build fresh lists at every nesting level so that no two layers (or
    # fields) alias the same list object, as ``n * [m * [x]]`` would.
    if family == "regular":
        hidden_reps_ids = [[-1] * n_fields for _ in range(n_layers)]
        non_linearity = ['ReLU']
    else:
        # flip is fixed to True in this factory, so irrep fields always use
        # the id pair [1, 1] (a non-flip decoder would use the single id 1).
        hidden_reps_ids = [[[1, 1] for _ in range(n_fields)]
                           for _ in range(n_layers)]
        non_linearity = ['NormReLU']

    return AC.SteerDecoder(hidden_reps_ids=hidden_reps_ids,
                           kernel_sizes=kernel_sizes,
                           dim_cov_est=dim_cov_est,
                           context_rep_ids=context_rep_ids,
                           N=N,
                           flip=flip,
                           non_linearity=non_linearity)