示例#1
0
 def initialize_model(train_x, train_obj, state_dict=None):
     """Build a one-member GP model list for the objective plus its summed MLL.

     The objective is modelled by a ``FixedNoiseGP`` whose per-observation
     noise is ``train_yvar`` broadcast (``expand_as``) to the shape of
     ``train_obj``. Note that ``train_yvar`` is NOT a parameter: it is a free
     name resolved from an enclosing/global scope at call time.

     NOTE(review): newer BoTorch releases deprecate ``FixedNoiseGP`` in favour
     of ``SingleTaskGP(..., train_Yvar=...)`` — confirm against the pinned
     BoTorch version.

     Args:
         train_x: training inputs; the GP is moved to its device/dtype via
             ``.to(train_x)``.
         train_obj: training objective values.
         state_dict: optional state dict loaded into the model list.

     Returns:
         A ``(mll, model)`` tuple: a ``SumMarginalLogLikelihood`` and the
         ``ModelListGP`` wrapping the single objective GP.
     """
     observed_noise = train_yvar.expand_as(train_obj)
     objective_gp = FixedNoiseGP(train_x, train_obj, observed_noise).to(train_x)

     # A ModelListGP with a single member (not a genuinely multi-output model).
     model = ModelListGP(objective_gp)
     mll = SumMarginalLogLikelihood(model.likelihood, model)

     if state_dict is not None:
         # Warm-start from previously fitted hyperparameters.
         model.load_state_dict(state_dict)
     return mll, model
示例#2
0
文件: main.py 项目: stys/albo
def initialize_model(x, z, state_dict=None):
    """Build one independent ``SingleTaskGP`` per output column of ``z``.

    Args:
        x: training inputs shared by every GP.
        z: training targets; slice ``i`` along the last dimension becomes the
            ``train_Y`` (with a trailing singleton dimension restored via
            ``unsqueeze(-1)``) of the i-th GP.
        state_dict: optional state dict loaded into the resulting model list.

    Returns:
        A ``(mll, model_list)`` tuple where ``model_list`` is a ``ModelListGP``
        over the per-column GPs and ``mll`` is a ``SumMarginalLogLikelihood``
        over it.
    """
    def _column_gp(column_index):
        # Select output column i and keep it 2-D-shaped (trailing dim of 1).
        target = z[..., column_index].unsqueeze(-1)
        gp = SingleTaskGP(train_X=x, train_Y=target)
        # Replace the noise constraint so the likelihood noise stays > 1e-5.
        gp.likelihood.noise_covar.register_constraint(
            "raw_noise", GreaterThan(1e-5))
        return gp

    model_list = ModelListGP(*[_column_gp(k) for k in range(z.shape[-1])])
    mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)

    if state_dict is not None:
        model_list.load_state_dict(state_dict)
    return mll, model_list
def initialize_model(train_x, train_obj, train_con, state_dict=None):
    """Build the objective/constraint surrogate model and its marginal likelihood.

    The model structure depends on ``problem.num_constraints``, where
    ``problem`` is a free name resolved from an enclosing/global scope:

    * exactly 1 constraint: two independent ``SingleTaskGP`` models (objective
      and constraint, no outcome transform) combined in a ``ModelListGP`` and
      paired with a ``SumMarginalLogLikelihood``;
    * otherwise: objective and constraint columns are concatenated along the
      last dimension and modelled by one multi-output ``SingleTaskGP`` with a
      ``Standardize`` outcome transform, paired with an
      ``ExactMarginalLogLikelihood``.

    NOTE(review): only the second branch standardizes outcomes — confirm this
    asymmetry is intentional.

    Args:
        train_x: training inputs.
        train_obj: training objective values.
        train_con: training constraint values.
        state_dict: optional state dict loaded into the model before returning.

    Returns:
        A ``(mll, model)`` tuple.
    """
    if problem.num_constraints != 1:
        stacked_targets = torch.cat([train_obj, train_con], dim=-1)
        model = SingleTaskGP(
            train_x,
            stacked_targets,
            outcome_transform=Standardize(m=stacked_targets.shape[-1]),
        )
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
    else:
        # Objective GP is built first, then the constraint GP.
        model = ModelListGP(
            SingleTaskGP(train_x, train_obj),
            SingleTaskGP(train_x, train_con),
        )
        mll = SumMarginalLogLikelihood(model.likelihood, model)

    if state_dict is not None:
        # Warm-start from previously fitted hyperparameters.
        model.load_state_dict(state_dict)
    return mll, model