예제 #1
0
def run_experiment(args):
    """Sample chain-SEM environments and compare every selected estimator to the SEM.

    Args:
        args: mapping of experiment options. Read here: "seed" (negative
            disables seeding), "setup_sem" ("chain"; "icp" only produces a
            setup label and then raises as soon as a repetition runs),
            "setup_hidden", "setup_hetero", "setup_scramble", "methods"
            ("all" or a comma-separated subset of ERM, ICP, IRM), "n_reps",
            "dim" and "n_samples". The whole mapping is also handed to each
            method constructor.

    Returns:
        A flat list of report lines. For every repetition: one "SEM" line
        (errors fixed at 0) followed by one line per method holding its
        solution and its causal / non-causal errors.

    Raises:
        NotImplementedError: if "setup_sem" is not supported.
        KeyError: if "methods" names an unknown method.
    """
    seed = args["seed"]
    if seed >= 0:
        torch.manual_seed(seed)
        numpy.random.seed(seed)
        torch.set_num_threads(1)

    sem_kind = args["setup_sem"]
    if sem_kind == "chain":
        setup_str = "chain_hidden={}_hetero={}_scramble={}".format(
            args["setup_hidden"], args["setup_hetero"], args["setup_scramble"])
    elif sem_kind == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }
    requested = args["methods"]
    methods = all_methods if requested == "all" else {
        name: all_methods[name] for name in requested.split(',')
    }

    # Pass 1: sample every repetition's SEM and environments before any method
    # is fitted, so data generation is not interleaved with whatever randomness
    # the methods themselves consume.
    sems_and_envs = []
    for _ in range(args["n_reps"]):
        if sem_kind != "chain":
            raise NotImplementedError
        sem = ChainEquationModel(args["dim"],
                                 hidden=args["setup_hidden"],
                                 scramble=args["setup_scramble"],
                                 hetero=args["setup_hetero"])
        envs = [sem(args["n_samples"], env_param) for env_param in (.2, 2., 5.)]
        sems_and_envs.append((sem, envs))

    # Pass 2: fit each method on each repetition and report its errors.
    report = []
    for sem, envs in sems_and_envs:
        report.append("{} SEM {} {:.5f} {:.5f}".format(
            setup_str, pretty(sem.solution()), 0, 0))
        for method_name, method_constructor in methods.items():
            estimate = method_constructor(envs, args).solution()
            err_causal, err_noncausal = errors(sem.solution(), estimate)
            report.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, method_name, pretty(estimate), err_causal,
                err_noncausal))

    return report
예제 #2
0
def run_experiment(args):
    """Build SEM environments, fit every selected method and score it against the SEM.

    Args:
        args: mapping of experiment options. Read here: "seed" (negative
            disables seeding), "setup_sem" ("chain" or "simple"; "icp" only
            produces a setup label and raises once a repetition runs),
            "methods" ("all" or a comma-separated subset of ERM, ICP, IRM),
            "n_reps", "dim", "n_samples" and "env_list". "chain" also reads
            "setup_ones", "setup_hidden", "setup_hetero", "setup_scramble"
            and parses "env_list" as comma-separated floats; "simple" reads
            "k", "env_shuffle", "env_rat" (colon-separated integer ratios)
            and parses "env_list" as an integer environment count. The whole
            mapping is also handed to each method constructor.

    Returns:
        (all_solutions, sol_dict). all_solutions is a flat list of report
        lines: per repetition, one "SEM" line (errors fixed at 0) followed by
        one line per method. sol_dict maps each method name to
        [solution, ecaus, enoncaus] as flattened numpy arrays; it is
        overwritten on every repetition, so only the last one is kept.

    Raises:
        NotImplementedError: if "setup_sem" is not supported.
        KeyError: if "methods" names an unknown method.
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        torch.set_num_threads(1)
        # Seed Python's RNG from the same seed; a hard-coded constant here
        # made anything drawn from `random` ignore the requested seed.
        random.seed(args["seed"])

    if args["setup_sem"] == "chain":
        setup_str = "chain_ones={}_hidden={}_hetero={}_scramble={}".format(
            args["setup_ones"], args["setup_hidden"], args["setup_hetero"],
            args["setup_scramble"])
    elif args["setup_sem"] == "simple":
        setup_str = ""
    elif args["setup_sem"] == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }

    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    all_sems = []
    all_environments = []
    # Per-repetition environment orders, kept in lockstep with all_sems so each
    # SEM is fitted with its own orders (not the last repetition's).
    all_env_orders = []

    for _ in range(args["n_reps"]):
        if args["setup_sem"] == "chain":
            sem = ChainEquationModel(args["dim"],
                                     ones=args["setup_ones"],
                                     hidden=args["setup_hidden"],
                                     scramble=args["setup_scramble"],
                                     hetero=args["setup_hetero"])

            env_list = [float(e) for e in args["env_list"].split(",")]
            environments = [sem(args["n_samples"], e) for e in env_list]
            # NOTE(review): chain environments carry no ordering information,
            # so None is forwarded to the method constructors -- confirm they
            # accept it (previously this path failed with UnboundLocalError).
            env_orders = None
        elif args["setup_sem"] == "simple":
            sem = SEM_X1YX2X3(args["dim"], args["k"], args["env_shuffle"])

            env_list = range(int(args["env_list"]))
            ratios = list(map(int, args["env_rat"].split(':')))
            total = sum(ratios)
            n = args["n_samples"]
            # Split the n_samples budget across environments by the given
            # ratios, rounding each share up.
            n_samples = [math.ceil(ni * 1.0 / total * n) for ni in ratios]
            print("sample in envs ", n_samples)
            environments = []
            env_orders = []
            for e in env_list:
                # The SEM returns a triple: the first two items form the
                # environment's data, the third is its order entry.
                res = sem(n_samples[e], e)
                environments.append(res[:2])
                env_orders.append(res[2])
        else:
            raise NotImplementedError

        all_sems.append(sem)
        all_environments.append(environments)
        all_env_orders.append(env_orders)

    all_solutions = []
    sol_dict = {}
    for sem, environments, env_orders in zip(all_sems, all_environments,
                                             all_env_orders):
        sem_solution, sem_scramble = sem.solution()

        all_solutions.append("{} SEM {} {:.5f} {:.5f}".format(
            setup_str, pretty(sem_solution), 0, 0))

        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args, env_orders)

            # Presumably maps the estimate back through the SEM's scrambling
            # so it is comparable with sem_solution.
            method_solution = sem_scramble @ method.solution()

            err_causal, err_noncausal, ecaus, enoncaus = errors(
                sem_solution, method_solution)

            all_solutions.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, method_name, pretty(method_solution), err_causal,
                err_noncausal))

            sol_dict[method_name] = [
                method_solution.detach().view(-1).numpy(),
                ecaus.detach().view(-1).numpy(),
                enoncaus.detach().view(-1).numpy(),
            ]

    return all_solutions, sol_dict
예제 #3
0
파일: main_v1.py 프로젝트: keerthi166/OoD
def run_experiment_IRM(args):
    """Build chain-SEM train/validation environments and score each method against the SEM.

    Args:
        args: mapping of experiment options. Read here: "seed" (negative
            disables seeding), "setup_sem" ("chain"; "icp" only produces a
            setup label and raises once a repetition runs), "setup_hidden",
            "setup_hetero", "setup_scramble", "methods" ("all" or a
            comma-separated subset of ERM, ICP, IRM), "n_reps", "dim",
            "n_samples" and "env_list" (a sequence of per-environment values
            passed to the SEM). The whole mapping is also handed to each
            method constructor.

    Returns:
        (all_solutions, all_environments, msolution, sem_solution): the flat
        list of report lines, the environments of every repetition, the
        solution of the last method fitted on the last repetition, and that
        repetition's SEM solution. The last two are None when no repetition
        (or no method) ran.

    Raises:
        NotImplementedError: if "setup_sem" is not supported.
        KeyError: if "methods" names an unknown method.
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        torch.set_num_threads(1)

    if args["setup_sem"] == "chain":
        setup_str = "chain_hidden={}_hetero={}_scramble={}".format(
            args["setup_hidden"], args["setup_hetero"], args["setup_scramble"])
    elif args["setup_sem"] == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }

    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    all_sems = []
    all_solutions = []
    all_environments = []
    # Split each environment's sample budget 80:20 into train:val. Builtin
    # int() truncates exactly like the np.int alias did (np.int was removed
    # in NumPy 1.24).
    train_size = int(args["n_samples"] * 0.8)
    val_size = int(args["n_samples"] * 0.2)
    for _ in range(args["n_reps"]):
        if args["setup_sem"] == "chain":
            sem = ChainEquationModel(args["dim"],
                                     hidden=args["setup_hidden"],
                                     scramble=args["setup_scramble"],
                                     hetero=args["setup_hetero"])

            # All training environments come first, then all validation
            # environments, each split following env_list order.
            # NOTE(review): every environment is sampled twice per split
            # (4 * len(env_list) environments in total); confirm this
            # duplication is intended rather than one sample per split.
            environments = []
            for size in (train_size, val_size):
                for env_value in args["env_list"]:
                    environments.append(sem(size, env_value))
                    environments.append(sem(size, env_value))
        else:
            raise NotImplementedError

        all_sems.append(sem)
        all_environments.append(environments)

    # Defined up front so the return below is valid even when no repetition
    # or no method ran.
    msolution = None
    sem_solution = None
    for sem, environments in zip(all_sems, all_environments):
        sem_solution, sem_scramble = sem.solution()
        solutions = [
            "{} SEM {} {:.5f} {:.5f}".format(setup_str, pretty(sem_solution),
                                             0, 0)
        ]

        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args)
            # Presumably maps the estimate back through the SEM's scrambling
            # so it is comparable with sem_solution.
            msolution = sem_scramble @ method.solution()
            err_causal, err_noncausal = errors(sem_solution, msolution)

            solutions.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, method_name, pretty(msolution), err_causal,
                err_noncausal))

        all_solutions += solutions

    return all_solutions, all_environments, msolution, sem_solution
def run_experiment(args):
    """Run every selected method on a grid of chain-SEM setups in parallel.

    Args:
        args: mapping of experiment options. Read here: "seed" (negative
            disables seeding), "methods" ("all" or a comma-separated subset
            of ERM, ICP, IRM, IRMS, RMG, MAML, MSGD, SPC, SPP, GVP),
            "setup_sem" (only "chain" is supported), "n_reps",
            "setup_hidden", "setup_hetero", "setup_scramble" (iterables;
            every combination becomes its own setup), "dim" and "n_samples".
            The whole mapping is also forwarded to ``run_methods``.

    Returns:
        A list with one ``run_methods`` result per
        (repetition, hidden, hetero, scramble) combination, in nested-loop
        order.

    Raises:
        KeyError: if "methods" names an unknown method.
        NotImplementedError: if "setup_sem" is not "chain".
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        random.seed(args["seed"])
        # NOTE(review): these seeds affect the current process only; joblib
        # worker processes presumably keep their own RNG state, so randomness
        # inside run_methods may not be governed by args["seed"] -- confirm.

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
        "IRMS": InvariantRiskMinimizationSimple,
        "RMG": RiskMinimizationGames,
        "MAML": MAML,
        "MSGD": MetaSGD,
        "SPC": SpecialistRiskGames,
        "SPP": SpecialistPenalty,
        "GVP": GradVarPenalty,
    }

    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    if args["setup_sem"] != "chain":
        raise NotImplementedError

    all_sems = []
    all_environments = []
    all_setup_strs = []
    # One setup per (repetition, hidden, hetero, scramble) combination; all
    # data is sampled here, in this process, before any job starts.
    for _ in range(args["n_reps"]):
        for hidden in args["setup_hidden"]:
            for hetero in args["setup_hetero"]:
                for scramble in args["setup_scramble"]:
                    sem = ChainEquationModel(args["dim"],
                                             hidden=hidden,
                                             scramble=scramble,
                                             hetero=hetero)
                    environments = [
                        sem(args["n_samples"], .2),
                        sem(args["n_samples"], 2.),
                        sem(args["n_samples"], 5.)
                    ]
                    all_sems.append(sem)
                    all_environments.append(environments)
                    all_setup_strs.append(
                        "chain_hidden={}_hetero={}_scramble={}".format(
                            hidden, hetero, scramble))

    # Manager-backed write lock shared with the worker jobs, to avoid
    # overwriting the output file from concurrent jobs. Using the manager as
    # a context manager shuts its server process down once every job has
    # finished; previously it was never shut down and leaked one process per
    # call.
    with mp.Manager() as manager:
        write_lock = manager.Lock()
        solutions = Parallel(n_jobs=4)(
            delayed(run_methods)(methods, setup_str, sem, environments, args,
                                 write_lock)
            for sem, environments, setup_str in zip(
                all_sems, all_environments, all_setup_strs))

    return list(solutions)