Exemple #1
0
def create_eeg_glow_up(n_chans, hidden_channels, kernel_length, splitter_first,
                       splitter_last, n_blocks):
    """Build the three-level, bottom-up multi-scale flow graph.

    Level a: ``splitter_first`` -> ``n_blocks`` conv flow steps on
    ``n_chans * 2`` channels -> ``splitter_last``. Its output is chunked in two
    along channels; chunk 1 is emitted and chunk 0 feeds level b.
    Level b: the same structure on ``n_chans * 4`` channels; chunk 1 is
    emitted and chunk 0 feeds level c.
    Level c: ``splitter_first`` -> ``Flatten2d`` -> ``n_blocks`` dense flow
    steps on ``n_chans * 64 // 4`` features (presumably assumes an input
    length of 64 samples -- TODO confirm).

    Args:
        n_chans: number of input channels.
        hidden_channels: hidden width of every coupling sub-network.
        kernel_length: conv kernel length of the level a/b flow steps.
        splitter_first: ``'haar'`` or ``'subsample'``; splitter at the start
            of every level.
        splitter_last: ``'haar'`` or ``'subsample'``; splitter at the end of
            levels a and b.
        n_blocks: number of flow steps per level.

    Returns:
        ``CatAsListNode`` over [level-a output, level-b output, level-c output].
    """
    def make_splitter(splitter_name, chunk_chans_first):
        # NOTE(review): the 'haar' branch ignores ``chunk_chans_first`` and
        # always passes False -- confirm this is intended.
        if splitter_name == 'haar':
            return Haar1dWavelet(chunk_chans_first=False)
        assert splitter_name == 'subsample'
        return SubsampleSplitter((2, ), chunk_chans_first=chunk_chans_first)

    def conv_level(width):
        # splitter -> n_blocks conv flow steps -> splitter
        return InvertibleSequential(
            make_splitter(splitter_first, chunk_chans_first=False),
            *[
                conv_flow_block(width,
                                hidden_channels=hidden_channels,
                                kernel_length=kernel_length)
                for _ in range(n_blocks)
            ],
            make_splitter(splitter_last, chunk_chans_first=True),
        )

    # Modules are built in the same order as before so that random
    # parameter initialisation consumes the RNG identically.
    level_a = conv_level(n_chans * 2)
    level_b = conv_level(n_chans * 4)
    level_c = InvertibleSequential(
        make_splitter(splitter_first, chunk_chans_first=False),
        Flatten2d(),
        *[
            dense_flow_block(n_chans * 64 // 4,
                             hidden_channels=hidden_channels)
            for _ in range(n_blocks)
        ],
    )

    node_a = Node(None, level_a)
    split_a = Node(node_a, ChunkChans(2))
    out_a = SelectNode(split_a, 1)
    node_b = Node(SelectNode(split_a, 0), level_b)
    split_b = Node(node_b, ChunkChans(2))
    out_b = SelectNode(split_b, 1)
    out_c = Node(SelectNode(split_b, 0), level_c)
    return CatAsListNode([out_a, out_b, out_c])
Exemple #2
0
def flow_block_down_1(n_chans, hidden_channels):
    """Build one conditional additive-coupling flow step on ``n_chans * 4`` channels.

    The step is ActNorm -> non-fixed LU ``InvPermute`` -> ``CouplingLayer``.
    The coupling's coefficient network feeds two inputs through separate
    branches:
      * a kernel-5 conv taking ``n_chans * 2`` channels, and
      * a kernel-3 conv taking ``n_chans * 8`` channels followed by 2x
        upsampling (presumably the condition at half temporal resolution --
        TODO confirm).
    It then maps ``hidden_channels`` to ``n_chans * 2`` coefficient channels.
    Conditions are merged via ``CatCondAsList(cond_preproc=None)``.
    """
    n_flow_chans = n_chans * 4
    n_coef_chans = n_flow_chans // 2

    # Construction order matches the original so parameter init is identical.
    first_branch = nn.Sequential(
        nn.Conv1d(n_chans * 2, hidden_channels, 5, padding=5 // 2))
    second_branch = nn.Sequential(
        nn.Conv1d(n_chans * 8, hidden_channels, 3, padding=3 // 2),
        nn.Upsample(scale_factor=2))
    coef_net = nn.Sequential(
        MultipleInputOutput([[first_branch, second_branch]]),
        Expression(unwrap_single_element),
        nn.Conv1d(hidden_channels, n_coef_chans, 5, padding=5 // 2),
        MultiplyFactors(n_coef_chans),
    )
    coupling = CouplingLayer(ChunkChansIn2(swap_dims=False),
                             AdditiveCoefs(coef_net),
                             AffineModifier('sigmoid', add_first=True, eps=0),
                             condition_merger=CatCondAsList(cond_preproc=None))

    norm = ActNorm(n_flow_chans, 'exp')
    permute = InvPermute(n_flow_chans, fixed=False, use_lu=True)
    return InvertibleSequential(norm, permute, coupling)
Exemple #3
0
def dense_flow_block(n_chans, hidden_channels):
    """Build one dense (fully-connected) flow step over ``n_chans`` features.

    The step is ActNorm -> non-fixed, LU-parametrised ``InvPermute`` ->
    additive ``CouplingLayer``. The coefficient MLP maps ``n_chans // 2`` ->
    ``hidden_channels`` -> ``n_chans // 2`` (assumes ``n_chans`` is even --
    TODO confirm callers guarantee this).
    """
    half = n_chans // 2

    # Built in the same order as the nested-call original so the RNG is
    # consumed identically during parameter initialisation.
    norm = ActNorm(n_chans, 'exp')
    permute = InvPermute(n_chans, fixed=False, use_lu=True)
    splitter = ChunkChansIn2(swap_dims=False)
    coef_net = nn.Sequential(
        nn.Linear(half, hidden_channels),
        nn.ELU(),
        nn.Linear(hidden_channels, half),
        MultiplyFactors(half),
    )
    coupling = CouplingLayer(splitter, AdditiveCoefs(coef_net),
                             AffineModifier('sigmoid', add_first=True, eps=0))
    return InvertibleSequential(norm, permute, coupling)
Exemple #4
0
def create_eeg_glow_no_dist(n_chans, hidden_channels, kernel_length):
    """Build the EEG Glow graph without any output distribution.

    Each of the three outputs of ``create_eeg_glow_before_amp_phase`` is passed
    through its own ``AmplitudePhase`` -> ``Flatten2d`` head. The three
    results are concatenated with ``CatAsListNode``.

    Args:
        n_chans: number of input channels.
        hidden_channels: hidden width forwarded to the backbone.
        kernel_length: conv kernel length forwarded to the backbone.

    Returns:
        ``CatAsListNode`` over the three flattened amplitude/phase outputs.
    """
    # Heads are built before the backbone, as in the original, to keep module
    # construction order unchanged.
    out_blocks = [
        InvertibleSequential(
            AmplitudePhase(),
            Flatten2d(),
        ) for _ in range(3)
    ]

    n_before_block = create_eeg_glow_before_amp_phase(
        n_chans=n_chans,
        hidden_channels=hidden_channels,
        kernel_length=kernel_length)
    out_nodes = [
        Node(SelectNode(n_before_block, i_out), block)
        for i_out, block in enumerate(out_blocks)
    ]
    return CatAsListNode(out_nodes)
Exemple #5
0
def conv_flow_block(n_chans, hidden_channels, kernel_length):
    """Build one 1-D convolutional flow step over ``n_chans`` channels.

    The step is ActNorm -> non-fixed, LU-parametrised ``InvPermute`` ->
    additive ``CouplingLayer``. The coefficient network is two ``Conv1d``
    layers with an ELU between them, mapping ``n_chans // 2`` ->
    ``hidden_channels`` -> ``n_chans // 2``. Padding of ``kernel_length // 2``
    with an odd kernel keeps the temporal length unchanged, which is why the
    kernel length must be odd.
    """
    assert kernel_length % 2 == 1
    half = n_chans // 2
    pad = kernel_length // 2

    # Same construction order as the nested-call original.
    norm = ActNorm(n_chans, 'exp')
    permute = InvPermute(n_chans, fixed=False, use_lu=True)
    splitter = ChunkChansIn2(swap_dims=False)
    coef_net = nn.Sequential(
        nn.Conv1d(half, hidden_channels, kernel_length, padding=pad),
        nn.ELU(),
        nn.Conv1d(hidden_channels, half, kernel_length, padding=pad),
        MultiplyFactors(half),
    )
    coupling = CouplingLayer(splitter, AdditiveCoefs(coef_net),
                             AffineModifier('sigmoid', add_first=True, eps=0))
    return InvertibleSequential(norm, permute, coupling)
Exemple #6
0
def create_glow_model(hidden_channels,
                      K,
                      L,
                      flow_permutation,
                      flow_coupling,
                      LU_decomposed,
                      n_chans,
                      block_type='conv',
                      use_act_norm=True):
    """Build a multi-scale Glow-style model for 32x32 inputs with ``n_chans`` channels.

    Every scale applies a 2x ``SubsampleSplitter`` (4x channels, half height
    and width) followed by ``K`` flow steps built by ``flow_block``. Every
    scale except the last chunks its output in two along channels:
      * chunk 0 goes through a per-scale ActNorm and an independent
        distribution with fixed (non-optimized) mean/std;
      * chunk 1 feeds the next scale.
    The last scale sends its whole output to its ActNorm/distribution.

    Args:
        hidden_channels: hidden width of the coupling sub-networks.
        K: number of flow steps per scale.
        L: number of scales; must be >= 1. (Previously only L == 3 was
            supported; the wiring is now generated for any L.)
        flow_permutation: forwarded to ``flow_block``.
        flow_coupling: forwarded to ``flow_block``.
        LU_decomposed: forwarded to ``flow_block``.
        n_chans: number of input channels.
        block_type: ``'conv'`` or ``'dense'``. Dense steps run on the
            flattened ``C * H * W`` vector of each scale.
        use_act_norm: forwarded to ``flow_block``.

    Returns:
        ``CatAsListNode`` named ``'m0-full'`` over the per-scale distribution
        nodes.

    Raises:
        ValueError: if ``L < 1`` or ``block_type`` is not 'conv'/'dense'.
    """
    # Validate before building any module (asserts vanish under ``-O``).
    if L < 1:
        raise ValueError(f"L must be at least 1, got {L}")
    if block_type not in ('conv', 'dense'):
        raise ValueError(f"unknown block_type {block_type}")

    image_shape = (32, 32, n_chans)

    H, W, C = image_shape
    flows_per_scale = []
    act_norms_per_scale = []
    dists_per_scale = []
    for i in range(L):
        # 2x subsampling in both spatial dims moves 4x the data into channels.
        C, H, W = C * 4, H // 2, W // 2

        splitter = SubsampleSplitter(2,
                                     via_reshape=True,
                                     chunk_chans_first=True,
                                     checkerboard=False,
                                     cat_at_end=True)

        if block_type == 'dense':
            pre_flow_layers = [Flatten2d()]
            in_channels = C * H * W
        else:
            pre_flow_layers = []
            in_channels = C

        flow_layers = [
            flow_block(in_channels=in_channels,
                       hidden_channels=hidden_channels,
                       flow_permutation=flow_permutation,
                       flow_coupling=flow_coupling,
                       LU_decomposed=LU_decomposed,
                       cond_channels=0,
                       cond_merger=None,
                       block_type=block_type,
                       use_act_norm=use_act_norm) for _ in range(K)
        ]

        if block_type == 'dense':
            # Restore the (C, H, W) layout after the dense steps.
            post_flow_layers = [ViewAs((-1, C * H * W), (-1, C, H, W))]
        else:
            post_flow_layers = []
        flow_layers = pre_flow_layers + flow_layers + post_flow_layers
        flow_this_scale = InvertibleSequential(splitter, *flow_layers)
        flows_per_scale.append(flow_this_scale)

        if i < L - 1:
            # Output is chunked in two below; only one half reaches the dist.
            C = C // 2
        # ActNorm provides the mean/std of the distribution instead of it
        # being optimized inside the distribution itself.
        act_norms_per_scale.append(
            InvertibleSequential(Flatten2d(),
                                 ActNorm((C * H * W), scale_fn='exp')))
        dists_per_scale.append(
            Unlabeled(
                NClassIndependentDist(1,
                                      C * H * W,
                                      optimize_mean=False,
                                      optimize_std=False)))

    # Wire the scales together. Node creation order and names match the
    # previous hand-written L == 3 wiring exactly.
    dist_nodes = []
    flow_input = None  # the first scale reads the model input directly
    for i_scale in range(L):
        flow_node = Node(flow_input,
                         flows_per_scale[i_scale],
                         name=f'm0-flow-{i_scale}')
        is_last = i_scale == L - 1
        if is_last:
            to_dist = flow_node
        else:
            chunked = Node(flow_node, ChunkChans(2))
            to_dist = SelectNode(chunked, 0)
        act_node = Node(to_dist,
                        act_norms_per_scale[i_scale],
                        name=f'm0-act-{i_scale}')
        dist_nodes.append(
            Node(act_node,
                 dists_per_scale[i_scale],
                 name=f'm0-dist-{i_scale}'))
        if not is_last:
            flow_input = SelectNode(chunked, 1,
                                    name=f'm0-in-flow-{i_scale + 1}')

    # NOTE(review): an earlier comment here read "changed to pre-full", but
    # the name used is 'm0-full' -- confirm which is intended.
    model = CatAsListNode(dist_nodes, name='m0-full')
    return model
Exemple #7
0
def flow_block(in_channels,
               hidden_channels,
               flow_permutation,
               flow_coupling,
               LU_decomposed,
               cond_channels,
               cond_merger,
               block_type,
               use_act_norm,
               nonlin_name='relu'):
    """Build one Glow flow step: [ActNorm] -> permutation -> coupling.

    Args:
        in_channels: number of channels (or features) of the step.
        hidden_channels: hidden width passed to the coupling sub-network.
        flow_permutation: one of ``'invconv'`` (learned ``InvPermute``),
            ``'invconvfixed'`` (fixed ``InvPermute``), ``'identity'`` or
            ``'shuffle'``.
        flow_coupling: ``'additive'`` (sub-network outputs
            ``in_channels // 2`` channels) or ``'affine'`` (outputs
            ``in_channels`` channels, passed to ``AffineCoefs`` together with
            ``EverySecondChan()``).
        LU_decomposed: forwarded as ``use_lu`` to ``InvPermute``.
        cond_channels: extra input channels of the sub-network for the
            condition; added to ``in_channels // 2``.
        cond_merger: forwarded as ``condition_merger`` to ``CouplingLayer``.
        block_type: ``'conv'``, ``'dense'``, or a callable with the signature
            ``(in_chans, out_chans, hidden_chans, nonlin_name=...)`` that
            builds the sub-network.
        use_act_norm: whether to prepend an ``ActNorm`` layer.
        nonlin_name: forwarded to the sub-network builder.

    Returns:
        ``InvertibleSequential`` of the step's layers.

    Raises:
        ValueError: for an unknown ``flow_permutation``, ``flow_coupling`` or
            string ``block_type``.
    """
    # Validate every option before building any module: the previous asserts
    # vanished under ``python -O``. Without them an unknown permutation
    # silently became a Shuffle, and an unknown coupling raised an
    # UnboundLocalError.
    if flow_permutation not in ('invconv', 'invconvfixed', 'identity',
                                'shuffle'):
        raise ValueError(f"unknown flow_permutation {flow_permutation}")
    if flow_coupling not in ('additive', 'affine'):
        raise ValueError(f"unknown flow_coupling {flow_coupling}")
    if isinstance(block_type, str) and block_type not in ('conv', 'dense'):
        raise ValueError(f"unknown block_type {block_type}")

    layers = []
    # 1. optional activation normalization
    if use_act_norm:
        layers.append(ActNorm(in_channels, scale_fn='exp', eps=0))

    # 2. permute
    if flow_permutation == 'invconv':
        permutation = InvPermute(in_channels,
                                 fixed=False,
                                 use_lu=LU_decomposed)
    elif flow_permutation == 'invconvfixed':
        permutation = InvPermute(in_channels,
                                 fixed=True,
                                 use_lu=LU_decomposed)
    elif flow_permutation == 'identity':
        permutation = Identity()
    else:  # 'shuffle'
        permutation = Shuffle(in_channels)
    layers.append(permutation)

    # 3. coupling
    if flow_coupling == 'additive':
        out_channels = in_channels // 2
    else:
        out_channels = in_channels

    if isinstance(block_type, str):
        block_fn = get_conv_block if block_type == 'conv' else get_dense_block
    else:
        block_fn = block_type

    block = block_fn(in_channels // 2 + cond_channels,
                     out_channels,
                     hidden_channels,
                     nonlin_name=nonlin_name)

    if flow_coupling == 'additive':
        coupling = CouplingLayer(ChunkChansIn2(swap_dims=True),
                                 AdditiveCoefs(block),
                                 AffineModifier(
                                     sigmoid_or_exp_scale='sigmoid',
                                     eps=0,
                                     add_first=True,
                                 ),
                                 condition_merger=cond_merger)
    else:  # 'affine'
        coupling = CouplingLayer(
            ChunkChansIn2(swap_dims=True),
            AffineCoefs(block, EverySecondChan()),
            AffineModifier(sigmoid_or_exp_scale='sigmoid',
                           eps=0,
                           add_first=True),
            condition_merger=cond_merger,
        )
    layers.append(coupling)
    return InvertibleSequential(*layers)
Exemple #8
0
def create_eeg_glow_down(n_up_block,
                         n_chans,
                         hidden_channels,
                         kernel_length,
                         n_mixes,
                         n_blocks,
                         init_dist_std=1e-1):
    """Attach conditional top-down flows and output distributions to ``n_up_block``.

    ``n_up_block`` must expose three outputs selectable by index (presumably
    the result of ``create_eeg_glow_up`` -- TODO confirm).
      * Output 1 goes through ``n_blocks`` ``flow_block_down_1`` steps,
        conditioned on output 2.
      * Output 0 goes through ``n_blocks`` ``flow_block_down_0`` steps,
        conditioned on [transformed output 1, output 2].
      * Output 2 passes through unchanged.
    All three are flattened, concatenated as a list, and scored by per-element
    ``PerDimWeightedMix`` distributions with ``64 * n_chans // 2``,
    ``64 * n_chans // 4`` and ``64 * n_chans // 4`` dims.

    Returns:
        ``Node`` applying the three distributions to the merged outputs.
    """
    def dense_cond_net():
        # n_chans*16 -> 128 -> n_chans*16, then reshaped to (-1, n_chans*8, 2)
        # without contributing a log-det.
        return nn.Sequential(
            nn.Linear(n_chans * 16, 128),
            nn.ELU(),
            nn.Linear(128, n_chans * 16),
            NoLogDet(ViewAs((-1, n_chans * 16), (-1, n_chans * 8, 2))),
        )

    def mixture(n_dims):
        return PerDimWeightedMix(2,
                                 n_mixes=n_mixes,
                                 n_dims=n_dims,
                                 optimize_mean=True,
                                 optimize_std=True,
                                 init_std=init_dist_std)

    n_up_0, n_up_1, n_up_2 = (SelectNode(n_up_block, i) for i in range(3))

    # Modules are created in the same order as before so parameter
    # initialisation draws from the RNG identically.
    steps_1 = [
        flow_block_down_1(n_chans, hidden_channels) for _ in range(n_blocks)
    ]
    flow_down_1 = ConditionTransformWrapper(InvertibleSequential(*steps_1),
                                            cond_preproc=dense_cond_net())
    n_down_1 = ConditionalNode(n_up_1, flow_down_1, condition_nodes=n_up_2)

    steps_0 = [
        flow_block_down_0(n_chans, hidden_channels, kernel_length)
        for _ in range(n_blocks)
    ]
    seq_0 = InvertibleSequential(*steps_0)
    conv_cond = nn.Sequential(
        nn.Conv1d(n_chans * 4, hidden_channels, 5, padding=5 // 2),
        nn.ELU(),
        nn.Conv1d(hidden_channels, n_chans * 4, 5, padding=5 // 2),
    )
    flow_down_0 = ConditionTransformWrapper(
        seq_0,
        cond_preproc=ApplyToListNoLogdets(conv_cond, dense_cond_net()))
    n_down_0 = ConditionalNode(n_up_0,
                               flow_down_0,
                               condition_nodes=[n_down_1, n_up_2])

    flatten_heads = [InvertibleSequential(Flatten2d(), ) for _ in range(3)]
    out_nodes = [
        Node(source, head)
        for source, head in zip((n_down_0, n_down_1, n_up_2), flatten_heads)
    ]
    n_merged = CatAsListNode(out_nodes)

    dists = ApplyToList(mixture(64 * n_chans // 2),
                        mixture(64 * n_chans // 4),
                        mixture(64 * n_chans // 4))
    return Node(n_merged, dists)