def get_model(prm, model_type='Stochastic'):
    """Build, place and initialise the network selected by ``prm.model_name``.

    Args:
        prm: Run-parameters object. This function reads ``prm.model_name``,
            ``prm.log_var_init`` and, if present, ``prm.override_eps_std``.
            It is also forwarded to ``data_gen.get_info`` and to the
            stochastic layer constructors.
        model_type: ``'Standard'`` builds plain ``nn.Linear`` / ``nn.Conv2d``
            layers; ``'Stochastic'`` builds ``StochasticLinear`` /
            ``StochasticConv2d`` layers.

    Returns:
        The constructed model. It is moved to GPU when
        ``Utils.config.USE_GPU`` is set, then initialised with ``init_layers``.

    Raises:
        ValueError: if ``prm.model_name`` is not a known model, or if a layer
            factory is invoked with an unsupported ``model_type``.
    """
    model_name = prm.model_name

    # Get task info:
    task_info = data_gen.get_info(prm)

    # Default layer factories handed to the model constructors.
    def linear_layer(in_dim, out_dim, use_bias=True):
        if model_type == 'Standard':
            return nn.Linear(in_dim, out_dim, bias=use_bias)
        if model_type == 'Stochastic':
            return StochasticLinear(in_dim, out_dim, prm, use_bias)
        # Previously fell through and returned None, failing later and obscurely.
        raise ValueError('Invalid model_type: {!r}'.format(model_type))

    def conv2d_layer(in_channels,
                     out_channels,
                     kernel_size,
                     use_bias=True,
                     stride=1,
                     padding=0,
                     dilation=1):
        if model_type == 'Standard':
            # Forward every geometry/bias argument so a 'Standard' layer has
            # the same configuration as the 'Stochastic' one built by the
            # same call (these were previously dropped).
            return nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             bias=use_bias)
        if model_type == 'Stochastic':
            return StochasticConv2d(in_channels, out_channels, kernel_size,
                                    prm, use_bias, stride, padding, dilation)
        raise ValueError('Invalid model_type: {!r}'.format(model_type))

    # Positional arguments shared by the FcNet3/ConvNet3/OmConvNet constructors.
    common_args = (model_type, model_name, linear_layer, conv2d_layer,
                   task_info)

    #  Return selected model:
    if model_name == 'FcNet3':
        model = FcNet3(*common_args)

    elif model_name == 'ConvNet3':
        model = ConvNet3(*common_args)

    elif model_name == 'BayesDenseNet':
        from Models.densenetBayes import get_bayes_densenet_model_class
        densenet_model = get_bayes_densenet_model_class(prm, task_info)
        model = densenet_model(depth=20)

    elif model_name == 'OmConvNet':
        model = OmConvNet(*common_args)

    else:
        raise ValueError('Invalid model_name')

    from Utils.config import USE_GPU
    if USE_GPU:
        model.cuda()  # only moved when the USE_GPU config flag is set
    init_layers(model, prm.log_var_init)  # init model

    # For debug: set the STD of epsilon variable for re-parametrization trick (default=1.0)
    if hasattr(prm, 'override_eps_std'):
        model.set_eps_std(prm.override_eps_std)  # debug
    return model
def get_model(device, log_var_init, input_size, output_size):
    """Build a stochastic ``FcNet3``, move it to *device* and initialise it.

    Args:
        device: Target passed to ``model.to``.
        log_var_init: Forwarded to every ``StochasticLinear`` layer and to
            ``init_layers``.
        input_size: Network input dimension, forwarded to ``FcNet3``.
        output_size: Network output dimension, forwarded to ``FcNet3``.

    Returns:
        The initialised ``FcNet3`` instance, with ``weights_count`` set from
        ``count_weights``.
    """
    # Layer factory handed to FcNet3. Its parameter names are kept as-is
    # (presumably FcNet3 may pass them by keyword — verify), even though
    # they shadow the enclosing arguments.
    def linear_layer(input_size, output_size, use_bias=True):
        return StochasticLinear(input_size, output_size, log_var_init, use_bias)

    net = FcNet3(linear_layer, input_size, output_size)
    net.to(device)  # place parameters on GPU/CPU before initialising them
    init_layers(net, log_var_init)
    net.weights_count = count_weights(net)
    return net
Example #3
0
 def _init_weights(self):
     """Initialise this module's layers with the shared ``init_layers`` helper.

     NOTE(review): no ``log_var_init`` value is passed here; confirm that
     ``init_layers`` accepts a single-argument call in this codebase.
     """
     init_layers(self)
Example #4
0
 def _init_weights(self, log_var_init):
     """Initialise this module's layers with the shared ``init_layers`` helper.

     Args:
         log_var_init: Forwarded unchanged to ``init_layers``.
     """
     init_layers(self, log_var_init)
Example #5
0
def get_model(prm, model_type='Stochastic'):
    """Build, place and initialise the network selected by ``prm.model_name``.

    Args:
        prm: Run-parameters object. This function reads ``prm.model_name``,
            ``prm.device`` and ``prm.log_var_init``. It is also forwarded to
            ``data_gen.get_info`` and to the stochastic layer constructors.
        model_type: ``'Standard'`` builds plain ``nn.Linear`` / ``nn.Conv2d``
            layers; ``'Stochastic'`` builds ``StochasticLinear`` /
            ``StochasticConv2d`` layers.

    Returns:
        The constructed model. It is moved to ``prm.device``, initialised
        with ``init_layers``, and carries a ``weights_count`` attribute.

    Raises:
        ValueError: if ``prm.model_name`` is not a known model, or if a layer
            factory is invoked with an unsupported ``model_type``.
    """
    model_name = prm.model_name

    # Get task info:
    task_info = data_gen.get_info(prm)

    # Default layer factories handed to the model constructors.
    def linear_layer(in_dim, out_dim, use_bias=True):
        if model_type == 'Standard':
            return nn.Linear(in_dim, out_dim, bias=use_bias)
        if model_type == 'Stochastic':
            return StochasticLinear(in_dim, out_dim, prm, use_bias)
        # Previously fell through and returned None, failing later and obscurely.
        raise ValueError('Invalid model_type: {!r}'.format(model_type))

    def conv2d_layer(in_channels,
                     out_channels,
                     kernel_size,
                     use_bias=True,
                     stride=1,
                     padding=0,
                     dilation=1):
        if model_type == 'Standard':
            # Forward every geometry/bias argument so a 'Standard' layer has
            # the same configuration as the 'Stochastic' one built by the
            # same call (these were previously dropped).
            return nn.Conv2d(in_channels,
                             out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             bias=use_bias)
        if model_type == 'Stochastic':
            return StochasticConv2d(in_channels, out_channels, kernel_size,
                                    prm, use_bias, stride, padding, dilation)
        raise ValueError('Invalid model_type: {!r}'.format(model_type))

    # Positional arguments shared by every supported model constructor.
    common_args = (model_type, model_name, linear_layer, conv2d_layer,
                   task_info)

    #  Return selected model:
    if model_name == 'FcNet3':
        model = FcNet3(*common_args)
    elif model_name == 'ConvNet3':
        model = ConvNet3(*common_args)
    elif model_name == 'OmConvNet':
        model = OmConvNet(*common_args)
    elif model_name == 'OmConvNet_NoBN':
        model = OmConvNet_NoBN(*common_args)
    elif model_name == 'OmConvNet_NoBN_32':
        model = OmConvNet_NoBN(*common_args, filt_size=32)
    elif model_name == 'OmConvNet_NoBN_16':
        model = OmConvNet_NoBN(*common_args, filt_size=16)
    elif model_name == 'OmConvNet_NoBN_elu':
        model = OmConvNet_NoBN_elu(*common_args)
    else:
        raise ValueError('Invalid model_name')

    # Move model to device (GPU/CPU) before initialising its parameters.
    model.to(prm.device)

    # init model:
    init_layers(model, prm.log_var_init)

    model.weights_count = count_weights(model)

    return model