# Architecture table for get_C8_Decoder:
#   name -> (number of hidden layers, fields per hidden layer, kernel sizes).
# Every hidden field uses rep id -1, the id this file uses for its "purely
# regular fiber representation" decoder family. Each entry lists exactly one
# more kernel size than it has hidden layers; how AC.SteerDecoder uses the
# extra size is not visible here (presumably the final layer -- TODO confirm).
_C8_DECODER_SPECS = {
    "regular_little": (2, 3, (7, 13, 19)),
    "regular_small": (4, 4, (3, 7, 11, 15, 19)),
    "regular_middle": (6, 12, (3, 3, 5, 7, 7, 11, 11)),
    "regular_big": (6, 24, (5, 5, 5, 7, 7, 11, 11)),
    "regular_huge": (8, 24, (5, 5, 5, 5, 7, 7, 9, 15, 21)),
}


def _c8_decoder_config(name):
    """Return freshly built ``(hidden_reps_ids, kernel_sizes, non_linearity)`` for *name*.

    Args:
        name: key of ``_C8_DECODER_SPECS``.

    Returns:
        A tuple of three new lists on every call: one list of rep ids per
        hidden layer, the kernel sizes, and ``["ReLU"]``.

    Raises:
        SystemExit: if *name* is unknown. This is the same exception type
            the previous ``sys.exit(...)`` call raised, so callers are unaffected.
    """
    spec = _C8_DECODER_SPECS.get(name)
    if spec is None:
        raise SystemExit("Unknown architecture name: {!r}.".format(name))
    n_layers, n_fields, kernel_sizes = spec
    # Build an independent list per layer. `n * [inner]` would make every
    # layer the same list object, so mutating one layer would change them all.
    hidden_reps_ids = [[-1] * n_fields for _ in range(n_layers)]
    return hidden_reps_ids, list(kernel_sizes), ["ReLU"]


def get_C8_Decoder(name, dim_cov_est, context_rep_ids):
    """Build an ``AC.SteerDecoder`` for the C8 setting (``N=8``, ``flip=False``).

    Only the "regular_*" architectures are available here:
    "regular_little", "regular_small", "regular_middle", "regular_big",
    "regular_huge".

    Args:
        name: architecture name selecting the hidden layers and kernel sizes.
        dim_cov_est: forwarded unchanged to ``AC.SteerDecoder``.
        context_rep_ids: forwarded unchanged to ``AC.SteerDecoder``.

    Returns:
        The ``AC.SteerDecoder`` instance.

    Raises:
        SystemExit: if *name* is not one of the architectures above.
    """
    hidden_reps_ids, kernel_sizes, non_linearity = _c8_decoder_config(name)
    return AC.SteerDecoder(
        hidden_reps_ids=hidden_reps_ids,
        kernel_sizes=kernel_sizes,
        dim_cov_est=dim_cov_est,
        context_rep_ids=context_rep_ids,
        N=8,
        flip=False,
        non_linearity=non_linearity,
    )
# Architecture table for get_Flip_Decoder:
#   name -> (hidden layers, fields per hidden layer, rep id of each field,
#            kernel sizes, non-linearity name).
# A tuple rep id is expanded to a fresh list per field when a config is built.
# Each entry lists exactly one more kernel size than it has hidden layers; how
# AC.SteerDecoder uses the extra size is not visible here (TODO confirm).
_FLIP_DECODER_SPECS = {
    # Family using purely regular fiber representations (rep id -1).
    "regular_little": (2, 3, -1, (7, 13, 19), "ReLU"),
    "regular_small": (4, 4, -1, (3, 7, 11, 15, 19), "ReLU"),
    "regular_middle": (6, 12, -1, (3, 3, 5, 7, 7, 11, 11), "ReLU"),
    "regular_big": (6, 24, -1, (5, 5, 5, 7, 7, 11, 11), "ReLU"),
    "regular_huge": (8, 24, -1, (5, 5, 5, 5, 7, 7, 9, 15, 21), "ReLU"),
    # Irrep family. Because this factory always uses flip=True, each field is
    # the rep id pair [1, 1]. The previous code also had `else` branches
    # (single id 1) that could never run, since flip was hard-coded to True.
    "irrep_little": (2, 4, (1, 1), (7, 13, 19), "NormReLU"),
    "irrep_small": (5, 6, (1, 1), (3, 7, 11, 15, 19, 23), "NormReLU"),
    "irrep_middle": (7, 18, (1, 1), (3, 3, 5, 5, 11, 11, 13, 13), "NormReLU"),
    "irrep_big": (8, 32, (1, 1), (5, 5, 7, 7, 11, 13, 15, 17, 19), "NormReLU"),
    "irrep_huge": (10, 40, (1, 1), (5, 5, 7, 7, 11, 11, 11, 13, 17, 19, 21), "NormReLU"),
}


def _flip_decoder_config(name):
    """Return freshly built ``(hidden_reps_ids, kernel_sizes, non_linearity)`` for *name*.

    Args:
        name: key of ``_FLIP_DECODER_SPECS``.

    Returns:
        A tuple of new lists on every call: one list of field rep ids per
        hidden layer, the kernel sizes, and a one-element list with the
        non-linearity name. No two layers or fields share a list object.

    Raises:
        SystemExit: if *name* is unknown. This is the same exception type
            the previous ``sys.exit(...)`` call raised, so callers are unaffected.
    """
    spec = _FLIP_DECODER_SPECS.get(name)
    if spec is None:
        raise SystemExit("Unknown architecture name: {!r}.".format(name))
    n_layers, n_fields, rep, kernel_sizes, non_linearity = spec
    # `n * [inner]` would alias every layer, and every [1, 1] field, to a
    # single list object. Build independent lists instead.
    hidden_reps_ids = [
        [list(rep) if isinstance(rep, tuple) else rep for _ in range(n_fields)]
        for _ in range(n_layers)
    ]
    return hidden_reps_ids, list(kernel_sizes), [non_linearity]


def get_Flip_Decoder(name, dim_cov_est, context_rep_ids):
    """Build an ``AC.SteerDecoder`` for the flip-only setting (``N=2``, ``flip=True``).

    Available architectures:
    "regular_little" / "regular_small" / "regular_middle" / "regular_big" /
    "regular_huge" (rep id -1, ReLU), and "irrep_little" / "irrep_small" /
    "irrep_middle" / "irrep_big" / "irrep_huge" ([1, 1] fields, NormReLU).

    Args:
        name: architecture name selecting the hidden layers, kernel sizes and
            non-linearity.
        dim_cov_est: forwarded unchanged to ``AC.SteerDecoder``.
        context_rep_ids: forwarded unchanged to ``AC.SteerDecoder``.

    Returns:
        The ``AC.SteerDecoder`` instance.

    Raises:
        SystemExit: if *name* is not one of the architectures above.
    """
    hidden_reps_ids, kernel_sizes, non_linearity = _flip_decoder_config(name)
    return AC.SteerDecoder(
        hidden_reps_ids=hidden_reps_ids,
        kernel_sizes=kernel_sizes,
        dim_cov_est=dim_cov_est,
        context_rep_ids=context_rep_ids,
        N=2,
        flip=True,
        non_linearity=non_linearity,
    )