예제 #1
0
    def build_model():
        """Assemble a Glow-style normalizing flow from enclosing-scope settings.

        Reads ``num_inputs``, ``num_hidden``, ``num_cond_inputs``, ``device``
        and ``args.num_blocks`` from the surrounding scope. Each of the
        ``args.num_blocks`` blocks is ActNorm -> LUInvertibleMM ->
        CouplingLayer(s_act='tanh', t_act='relu'). The coupling mask starts as
        [0, 1, 0, 1, ...] and is inverted after every block.

        Every ``nn.Linear`` inside the flow gets an orthogonal weight
        initialization and a zeroed bias (when it has one).

        Returns:
            The ``fnn.FlowSequential`` model, moved to ``device``.
        """
        # Alternating 0/1 mask over the input dimensions, as floats on device.
        mask = (torch.arange(0, num_inputs) % 2).to(device).float()

        layers = []
        for _ in range(args.num_blocks):
            layers.extend((
                fnn.ActNorm(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs, num_hidden, mask, num_cond_inputs,
                                  s_act='tanh', t_act='relu'),
            ))
            # `1 - mask` builds a new tensor, so the layer above keeps its own.
            mask = 1 - mask

        model = fnn.FlowSequential(*layers)

        linear_layers = (m for m in model.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight)
            if getattr(layer, 'bias', None) is not None:
                layer.bias.data.fill_(0)

        # nn.Module.to moves parameters in place and returns the same module.
        return model.to(device)
예제 #2
0
}[args.dataset]

# Activation choice: tanh for the GAS dataset, relu for everything else.
# Use `==`, not `is`: `is` tests object identity, which for strings only
# happens to match when CPython interns them (and warns on Python 3.8+).
act = 'tanh' if args.dataset == 'GAS' else 'relu'

modules = []

assert args.flow in ['maf', 'maf-split', 'maf-split-glow', 'realnvp', 'glow'], \
    'unsupported flow: %r' % (args.flow,)
if args.flow == 'glow':
    mask = torch.arange(0, num_inputs) % 2
    mask = mask.to(device).float()

    print("Warning: Results for GLOW are not as good as for MAF yet.")
    for _ in range(args.num_blocks):
        modules += [
            fnn.BatchNormFlow(num_inputs),
            fnn.LUInvertibleMM(num_inputs),
            fnn.CouplingLayer(num_inputs,
                              num_hidden,
                              mask,
                              num_cond_inputs,
                              s_act='tanh',
                              t_act='relu')
        ]
        mask = 1 - mask
elif args.flow == 'realnvp':
    mask = torch.arange(0, num_inputs) % 2
    mask = mask.to(device).float()

    for _ in range(args.num_blocks):
        modules += [
            fnn.CouplingLayer(num_inputs,
예제 #3
0
def _alternating_mask(num_inputs, device):
    """Return a float tensor [0, 1, 0, 1, ...] of length ``num_inputs`` on ``device``.

    Callers invert it with ``1 - mask`` between blocks so consecutive
    coupling layers receive complementary masks.
    """
    return (torch.arange(0, num_inputs) % 2).to(device).float()


def init_model(args, num_inputs=72):
    """Build the normalizing-flow model selected by ``args.flow``.

    Args:
        args: namespace-like config. Reads ``no_cuda``, ``device``,
            ``num_hidden``, ``flow`` and ``num_blocks``; sets ``args.cuda``
            as a side effect.
        num_inputs: number of input dimensions the flow models.

    Returns:
        An ``fnn.FlowSequential`` whose ``nn.Linear`` layers have orthogonal
        weights and zeroed biases, moved to the selected device.

    Raises:
        AssertionError: if ``args.flow`` is not 'maf', 'realnvp' or 'glow'
            (not raised when Python runs with ``-O``).
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        # NOTE(review): torch.cuda.is_available() may already have initialised
        # CUDA, in which case this env var only affects child processes. If it
        # did take effect, "cuda:<N>" with N != 0 would no longer exist (only
        # index 0 is visible). Kept as-is for compatibility — confirm intent.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        device = torch.device("cuda:" + args.device)
    else:
        device = torch.device("cpu")

    # network structure
    num_hidden = args.num_hidden
    num_cond_inputs = None  # no conditioning inputs
    act = 'relu'  # activation passed to the MADE blocks

    # normalization flow
    assert args.flow in ['maf', 'realnvp', 'glow'], \
        'unsupported flow: %r' % (args.flow,)

    modules = []

    if args.flow == 'glow':
        mask = _alternating_mask(num_inputs, device)

        print("Warning: Results for GLOW are not as good as for MAF yet.")
        for _ in range(args.num_blocks):
            modules += [
                fnn.BatchNormFlow(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs,
                                  num_hidden,
                                  mask,
                                  num_cond_inputs,
                                  s_act='tanh',
                                  t_act='relu')
            ]
            mask = 1 - mask

    elif args.flow == 'realnvp':
        mask = _alternating_mask(num_inputs, device)

        for _ in range(args.num_blocks):
            modules += [
                fnn.CouplingLayer(num_inputs,
                                  num_hidden,
                                  mask,
                                  num_cond_inputs,
                                  s_act='tanh',
                                  t_act='relu'),
                fnn.BatchNormFlow(num_inputs)
            ]
            mask = 1 - mask

    elif args.flow == 'maf':
        for _ in range(args.num_blocks):
            modules += [
                fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]

    model = fnn.FlowSequential(*modules)

    # Orthogonal weights and zero biases for every linear layer in the flow.
    # nn.Linear always defines `bias` (None when created with bias=False).
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    # The coupling masks were created on `device`; move the parameters there
    # too, otherwise a CUDA run mixes GPU masks with CPU weights. Module.to
    # is in-place and returns the module, so a later caller-side .to(device)
    # is a harmless no-op.
    return model.to(device)