from torch import nn

# NOTE: InvertibleSequential, ActNorm, Node, SelectNode and the other
# invertible-flow building blocks used below are assumed to be provided by the
# surrounding project; their import lines are not shown here.


def create_eeg_glow_up(n_chans, hidden_channels, kernel_length, splitter_first,
                       splitter_last, n_blocks):
    def get_splitter(splitter_name, chunk_chans_first):
        if splitter_name == 'haar':
            return Haar1dWavelet(chunk_chans_first=False)
        else:
            assert splitter_name == 'subsample'
            return SubsampleSplitter((2,), chunk_chans_first=chunk_chans_first)

    block_a = InvertibleSequential(
        get_splitter(splitter_first, chunk_chans_first=False),
        *[conv_flow_block(n_chans * 2, hidden_channels=hidden_channels,
                          kernel_length=kernel_length)
          for _ in range(n_blocks)],
        get_splitter(splitter_last, chunk_chans_first=True),
    )
    block_b = InvertibleSequential(
        get_splitter(splitter_first, chunk_chans_first=False),
        *[conv_flow_block(n_chans * 4, hidden_channels=hidden_channels,
                          kernel_length=kernel_length)
          for _ in range(n_blocks)],
        get_splitter(splitter_last, chunk_chans_first=True),
    )
    block_c = InvertibleSequential(
        get_splitter(splitter_first, chunk_chans_first=False),
        Flatten2d(),
        *[dense_flow_block(n_chans * 64 // 4, hidden_channels=hidden_channels)
          for _ in range(n_blocks)],
    )

    # Multi-scale wiring: after each block, half of the channels are split off
    # as an output, the other half continues to the next (coarser) block.
    n_a = Node(None, block_a)
    n_a_split = Node(n_a, ChunkChans(2))
    n_a_out = SelectNode(n_a_split, 1)
    n_b = Node(SelectNode(n_a_split, 0), block_b)
    n_b_split = Node(n_b, ChunkChans(2))
    n_b_out = SelectNode(n_b_split, 1)
    n_c = Node(SelectNode(n_b_split, 0), block_c)
    n_c_out = n_c
    n_merged = CatAsListNode([n_a_out, n_b_out, n_c_out])
    return n_merged
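# Hypothetical usage sketch (not in the original source): build the upstream
# three-scale graph. All hyperparameter values are illustrative assumptions;
# kernel_length must be odd (see conv_flow_block below). The returned node
# yields a list of three per-scale outputs, which create_eeg_glow_down
# (further below) consumes via SelectNode.
def _example_build_eeg_glow_up():
    return create_eeg_glow_up(
        n_chans=21,
        hidden_channels=256,
        kernel_length=9,
        splitter_first='subsample',
        splitter_last='subsample',
        n_blocks=2,
    )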
def flow_block_down_1(n_chans, hidden_channels):
    # Coupling network with two input branches: one for the chunked half of the
    # current scale (n_chans * 2 channels) and one for the coarser-scale
    # condition (n_chans * 8 channels), which is upsampled to match the length.
    in_to_outs = [[
        nn.Sequential(
            nn.Conv1d(n_chans * 2, hidden_channels, 5, padding=5 // 2)),
        nn.Sequential(
            nn.Conv1d(n_chans * 8, hidden_channels, 3, padding=3 // 2),
            nn.Upsample(scale_factor=2)),
    ]]
    m = nn.Sequential(
        MultipleInputOutput(in_to_outs),
        Expression(unwrap_single_element),
        nn.Conv1d(hidden_channels, n_chans * 4 // 2, 5, padding=5 // 2),
        MultiplyFactors(n_chans * 4 // 2),
    )
    c = CouplingLayer(
        ChunkChansIn2(swap_dims=False),
        AdditiveCoefs(m),
        AffineModifier('sigmoid', add_first=True, eps=0),
        condition_merger=CatCondAsList(cond_preproc=None))
    f = InvertibleSequential(
        ActNorm(n_chans * 4, 'exp'),
        InvPermute(n_chans * 4, fixed=False, use_lu=True),
        c,
    )
    return f
def dense_flow_block(n_chans, hidden_channels):
    return InvertibleSequential(
        ActNorm(n_chans, 'exp'),
        InvPermute(n_chans, fixed=False, use_lu=True),
        CouplingLayer(
            ChunkChansIn2(swap_dims=False),
            AdditiveCoefs(
                nn.Sequential(
                    nn.Linear(n_chans // 2, hidden_channels),
                    nn.ELU(),
                    nn.Linear(hidden_channels, n_chans // 2),
                    MultiplyFactors(n_chans // 2),
                )),
            AffineModifier('sigmoid', add_first=True, eps=0)))
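# Minimal construction sketch (not in the original source): a dense flow step
# chunks its feature dimension in two inside the coupling, so n_chans must be
# even. The values below are illustrative assumptions chosen to match the
# example above (21 channels * 16 after flattening).
def _example_dense_flow_block():
    return dense_flow_block(n_chans=336, hidden_channels=256)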
def create_eeg_glow_no_dist(n_chans, hidden_channels, kernel_length):
    block_a_out = InvertibleSequential(AmplitudePhase(), Flatten2d())
    block_b_out = InvertibleSequential(AmplitudePhase(), Flatten2d())
    block_c_out = InvertibleSequential(AmplitudePhase(), Flatten2d())
    n_before_block = create_eeg_glow_before_amp_phase(
        n_chans=n_chans,
        hidden_channels=hidden_channels,
        kernel_length=kernel_length)
    n_a_out = Node(SelectNode(n_before_block, 0), block_a_out)
    n_b_out = Node(SelectNode(n_before_block, 1), block_b_out)
    n_c_out = Node(SelectNode(n_before_block, 2), block_c_out)
    n_merged = CatAsListNode([n_a_out, n_b_out, n_c_out])
    return n_merged
def conv_flow_block(n_chans, hidden_channels, kernel_length):
    assert kernel_length % 2 == 1
    return InvertibleSequential(
        ActNorm(n_chans, 'exp'),
        InvPermute(n_chans, fixed=False, use_lu=True),
        CouplingLayer(
            ChunkChansIn2(swap_dims=False),
            AdditiveCoefs(
                nn.Sequential(
                    nn.Conv1d(n_chans // 2, hidden_channels, kernel_length,
                              padding=kernel_length // 2),
                    nn.ELU(),
                    nn.Conv1d(hidden_channels, n_chans // 2, kernel_length,
                              padding=kernel_length // 2),
                    MultiplyFactors(n_chans // 2),
                )),
            AffineModifier('sigmoid', add_first=True, eps=0)))
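# Minimal construction sketch (not in the original source): a convolutional
# flow step; the assert above requires kernel_length to be odd so that
# padding=kernel_length // 2 preserves the temporal length. The values below
# are illustrative assumptions only.
def _example_conv_flow_block():
    return conv_flow_block(n_chans=42, hidden_channels=256, kernel_length=9)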
def create_glow_model(hidden_channels, K, L, flow_permutation, flow_coupling,
                      LU_decomposed, n_chans, block_type='conv',
                      use_act_norm=True):
    image_shape = (32, 32, n_chans)
    H, W, C = image_shape
    flows_per_scale = []
    act_norms_per_scale = []
    dists_per_scale = []
    for i in range(L):
        # Each scale starts with a 2x2 subsampling that trades spatial size
        # for channels.
        C, H, W = C * 4, H // 2, W // 2
        splitter = SubsampleSplitter(2, via_reshape=True,
                                     chunk_chans_first=True,
                                     checkerboard=False,
                                     cat_at_end=True)
        if block_type == 'dense':
            pre_flow_layers = [Flatten2d()]
            in_channels = C * H * W
        else:
            assert block_type == 'conv'
            pre_flow_layers = []
            in_channels = C
        flow_layers = [
            flow_block(in_channels=in_channels,
                       hidden_channels=hidden_channels,
                       flow_permutation=flow_permutation,
                       flow_coupling=flow_coupling,
                       LU_decomposed=LU_decomposed,
                       cond_channels=0,
                       cond_merger=None,
                       block_type=block_type,
                       use_act_norm=use_act_norm)
            for _ in range(K)
        ]
        if block_type == 'dense':
            post_flow_layers = [ViewAs((-1, C * H * W), (-1, C, H, W))]
        else:
            assert block_type == 'conv'
            post_flow_layers = []
        flow_layers = pre_flow_layers + flow_layers + post_flow_layers
        flow_this_scale = InvertibleSequential(splitter, *flow_layers)
        flows_per_scale.append(flow_this_scale)
        if i < L - 1:
            # there will be a chunking here
            C = C // 2
        # act norms for distribution (mean/std as actnorm instead of
        # integrated into dist)
        act_norms_per_scale.append(
            InvertibleSequential(Flatten2d(),
                                 ActNorm(C * H * W, scale_fn='exp')))
        dists_per_scale.append(
            Unlabeled(
                NClassIndependentDist(1, C * H * W,
                                      optimize_mean=False,
                                      optimize_std=False)))

    assert len(flows_per_scale) == 3
    nd_1_o = Node(None, flows_per_scale[0], name='m0-flow-0')
    nd_1_ab = Node(nd_1_o, ChunkChans(2))
    nd_1_a = SelectNode(nd_1_ab, 0)
    nd_1_an = Node(nd_1_a, act_norms_per_scale[0], name='m0-act-0')
    nd_1_ad = Node(nd_1_an, dists_per_scale[0], name='m0-dist-0')
    nd_1_b = SelectNode(nd_1_ab, 1, name='m0-in-flow-1')

    nd_2_o = Node(nd_1_b, flows_per_scale[1], name='m0-flow-1')
    nd_2_ab = Node(nd_2_o, ChunkChans(2))
    nd_2_a = SelectNode(nd_2_ab, 0)
    nd_2_an = Node(nd_2_a, act_norms_per_scale[1], name='m0-act-1')
    nd_2_ad = Node(nd_2_an, dists_per_scale[1], name='m0-dist-1')
    nd_2_b = SelectNode(nd_2_ab, 1, name='m0-in-flow-2')

    nd_3_o = Node(nd_2_b, flows_per_scale[2], name='m0-flow-2')
    nd_3_n = Node(nd_3_o, act_norms_per_scale[2], name='m0-act-2')
    nd_3_d = Node(nd_3_n, dists_per_scale[2], name='m0-dist-2')

    model = CatAsListNode([nd_1_ad, nd_2_ad, nd_3_d],
                          name='m0-full')  # changed to pre-full
    return model
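# Hypothetical usage sketch (not in the original source): the node graph built
# here hard-codes three scales (assert len(flows_per_scale) == 3), so L must be
# 3. K, hidden_channels and n_chans below are illustrative assumptions.
def _example_create_glow_model():
    return create_glow_model(
        hidden_channels=512,
        K=32,
        L=3,
        flow_permutation='invconv',
        flow_coupling='affine',
        LU_decomposed=True,
        n_chans=3,
        block_type='conv',
        use_act_norm=True,
    )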
def flow_block(in_channels, hidden_channels, flow_permutation, flow_coupling,
               LU_decomposed, cond_channels, cond_merger, block_type,
               use_act_norm, nonlin_name='relu'):
    # 1. actnorm
    if use_act_norm:
        actnorm = ActNorm(in_channels, scale_fn='exp', eps=0)

    # 2. permute
    if flow_permutation == "invconv":
        flow_permutation = InvPermute(in_channels, fixed=False,
                                      use_lu=LU_decomposed)
    elif flow_permutation == 'invconvfixed':
        flow_permutation = InvPermute(in_channels, fixed=True,
                                      use_lu=LU_decomposed)
    elif flow_permutation == "identity":
        flow_permutation = Identity()
    else:
        assert flow_permutation == 'shuffle'
        flow_permutation = Shuffle(in_channels)

    # 3. coupling
    if flow_coupling == "additive":
        out_channels = in_channels // 2
    else:
        out_channels = in_channels

    if type(block_type) is str:
        if block_type == 'conv':
            block_fn = get_conv_block
        else:
            assert block_type == 'dense'
            block_fn = get_dense_block
    else:
        block_fn = block_type
    block = block_fn(in_channels // 2 + cond_channels,
                     out_channels,
                     hidden_channels,
                     nonlin_name=nonlin_name)

    if flow_coupling == "additive":
        coupling = CouplingLayer(
            ChunkChansIn2(swap_dims=True),
            AdditiveCoefs(block),
            AffineModifier(sigmoid_or_exp_scale='sigmoid', eps=0,
                           add_first=True),
            condition_merger=cond_merger)
    elif flow_coupling == "affine":
        coupling = CouplingLayer(
            ChunkChansIn2(swap_dims=True),
            AffineCoefs(block, EverySecondChan()),
            AffineModifier(sigmoid_or_exp_scale='sigmoid', eps=0,
                           add_first=True),
            condition_merger=cond_merger,
        )
    else:
        assert False, f"unknown flow_coupling {flow_coupling}"

    if use_act_norm:
        sequential = InvertibleSequential(actnorm, flow_permutation, coupling)
    else:
        sequential = InvertibleSequential(flow_permutation, coupling)
    return sequential
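# Minimal construction sketch (not in the original source): a single Glow step
# (actnorm -> invertible permutation -> affine coupling) without conditioning.
# in_channels must be even for the channel chunking; the values below are
# illustrative assumptions only.
def _example_flow_block():
    return flow_block(
        in_channels=12,
        hidden_channels=512,
        flow_permutation='invconv',
        flow_coupling='affine',
        LU_decomposed=True,
        cond_channels=0,
        cond_merger=None,
        block_type='conv',
        use_act_norm=True,
    )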
def create_eeg_glow_down(n_up_block, n_chans, hidden_channels, kernel_length,
                         n_mixes, n_blocks, init_dist_std=1e-1):
    n_up_0, n_up_1, n_up_2 = [SelectNode(n_up_block, i) for i in (0, 1, 2)]
    # maybe add n_down_2 as well, postprocessing
    flow_down_1 = InvertibleSequential(
        *[flow_block_down_1(n_chans, hidden_channels)
          for _ in range(n_blocks)])
    cond_preproc = nn.Sequential(
        nn.Linear(n_chans * 16, 128),
        nn.ELU(),
        nn.Linear(128, n_chans * 16),
        NoLogDet(ViewAs((-1, n_chans * 16), (-1, n_chans * 8, 2))))
    flow_down_1 = ConditionTransformWrapper(flow_down_1,
                                            cond_preproc=cond_preproc)
    n_down_1 = ConditionalNode(n_up_1, flow_down_1, condition_nodes=n_up_2)

    flow_down_0 = InvertibleSequential(*[
        flow_block_down_0(n_chans, hidden_channels, kernel_length)
        for _ in range(n_blocks)
    ])
    cond_preproc = ApplyToListNoLogdets(
        nn.Sequential(
            nn.Conv1d(n_chans * 4, hidden_channels, 5, padding=5 // 2),
            nn.ELU(),
            nn.Conv1d(hidden_channels, n_chans * 4, 5, padding=5 // 2),
        ),
        nn.Sequential(
            nn.Linear(n_chans * 16, 128),
            nn.ELU(),
            nn.Linear(128, n_chans * 16),
            NoLogDet(ViewAs((-1, n_chans * 16), (-1, n_chans * 8, 2))),
        ))
    flow_down_0 = ConditionTransformWrapper(flow_down_0,
                                            cond_preproc=cond_preproc)
    n_down_0 = ConditionalNode(n_up_0, flow_down_0,
                               condition_nodes=[n_down_1, n_up_2])

    block_a_out = InvertibleSequential(Flatten2d())
    block_b_out = InvertibleSequential(Flatten2d())
    block_c_out = InvertibleSequential(Flatten2d())
    n_0_out = Node(n_down_0, block_a_out)
    n_1_out = Node(n_down_1, block_b_out)
    n_2_out = Node(n_up_2, block_c_out)
    n_merged = CatAsListNode([n_0_out, n_1_out, n_2_out])

    dist0 = PerDimWeightedMix(2, n_mixes=n_mixes, n_dims=64 * n_chans // 2,
                              optimize_mean=True, optimize_std=True,
                              init_std=init_dist_std)
    dist1 = PerDimWeightedMix(2, n_mixes=n_mixes, n_dims=64 * n_chans // 4,
                              optimize_mean=True, optimize_std=True,
                              init_std=init_dist_std)
    dist2 = PerDimWeightedMix(2, n_mixes=n_mixes, n_dims=64 * n_chans // 4,
                              optimize_mean=True, optimize_std=True,
                              init_std=init_dist_std)
    n_dist = Node(n_merged, ApplyToList(dist0, dist1, dist2))
    return n_dist
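# Hypothetical end-to-end wiring sketch (not in the original source): the
# downstream graph is conditioned on the outputs of the upstream graph. All
# hyperparameter values are illustrative assumptions chosen to match the
# examples above; flow_block_down_0 is assumed to be defined elsewhere in the
# project, as it already is for create_eeg_glow_down itself.
def _example_build_full_eeg_glow():
    n_up = create_eeg_glow_up(
        n_chans=21,
        hidden_channels=256,
        kernel_length=9,
        splitter_first='subsample',
        splitter_last='subsample',
        n_blocks=2,
    )
    return create_eeg_glow_down(
        n_up_block=n_up,
        n_chans=21,
        hidden_channels=256,
        kernel_length=9,
        n_mixes=8,
        n_blocks=2,
    )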