def run_experiment_IRM(args):
    """Fit the selected methods on chain SEMs with train/validation environments.

    For each of ``args["n_reps"]`` repetitions a ``ChainEquationModel`` is
    built and environments are sampled from it.
    - First, for every entry of ``args["env_list"]`` (in order), two
      environments of ``train_size`` samples are drawn.
    - Then, for every entry again, two environments of ``val_size`` samples
      are drawn.
    ``train_size``/``val_size`` are 80%/20% of ``args["n_samples"]``,
    truncated to int.

    Each method named in ``args["methods"]`` (comma-separated, or ``"all"``)
    is constructed as ``method_constructor(environments, args)`` for every
    sampled SEM. Its solution is then compared with the SEM's own solution
    via ``errors``.

    Args:
        args: dict of experiment options. Keys read: ``"seed"``,
            ``"setup_sem"``, ``"methods"``, ``"n_samples"``, ``"n_reps"``.
            The chain setup also reads ``"setup_hidden"``, ``"setup_hetero"``,
            ``"setup_scramble"``, ``"dim"`` and ``"env_list"``; ``env_list``
            must support ``len``-free iteration in order. The dict is also
            forwarded to every method constructor.

    Returns:
        tuple ``(all_solutions, all_environments, msolution, sem_solution)``:
        - ``all_solutions``: the report lines (one ``SEM`` line plus one line
          per method for each repetition).
        - ``all_environments``: the list of per-repetition environment lists.
        - ``msolution``: the last fitted method's solution, computed as
          ``sem_scramble @ method.solution()``.
        - ``sem_solution``: the last repetition's SEM solution.
        ``msolution`` and ``sem_solution`` are ``None`` when nothing was
        fitted / sampled.

    Raises:
        NotImplementedError: if ``args["setup_sem"]`` is neither ``"chain"``
            nor ``"icp"``, or is ``"icp"`` with ``args["n_reps"] > 0`` (only
            the chain SEM can be sampled here).
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        torch.set_num_threads(1)

    if args["setup_sem"] == "chain":
        setup_str = "chain_hidden={}_hetero={}_scramble={}".format(
            args["setup_hidden"], args["setup_hetero"], args["setup_scramble"])
    elif args["setup_sem"] == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }
    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    all_sems = []
    all_solutions = []
    all_environments = []

    # Split each environment's sample budget 80:20 into train:val.
    # Builtin int() truncates exactly like the removed np.int alias did.
    train_size = int(args["n_samples"] * 0.8)
    val_size = int(args["n_samples"] * 0.2)

    for _ in range(args["n_reps"]):
        if args["setup_sem"] == "chain":
            sem = ChainEquationModel(args["dim"],
                                     hidden=args["setup_hidden"],
                                     scramble=args["setup_scramble"],
                                     hetero=args["setup_hetero"])
            env_list = args["env_list"]
            environments = []
            # All training environments first, then all validation ones,
            # each pass in env_list order.
            # NOTE(review): every environment value is sampled twice per
            # split, giving 4 * len(env_list) environments rather than
            # 2 * len(env_list) -- confirm the duplication is intended.
            for env in env_list:
                environments.append(sem(train_size, env))
                environments.append(sem(train_size, env))
            for env in env_list:
                environments.append(sem(val_size, env))
                environments.append(sem(val_size, env))
        else:
            raise NotImplementedError
        all_sems.append(sem)
        all_environments.append(environments)

    # Bound up front so the return cannot hit an unbound name when
    # n_reps == 0 or no methods were selected.
    msolution = None
    sem_solution = None
    for sem, environments in zip(all_sems, all_environments):
        sem_solution, sem_scramble = sem.solution()
        solutions = [
            "{} SEM {} {:.5f} {:.5f}".format(setup_str, pretty(sem_solution), 0, 0)
        ]
        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args)
            msolution = sem_scramble @ method.solution()
            err_causal, err_noncausal = errors(sem_solution, msolution)
            solutions.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, method_name, pretty(msolution),
                err_causal, err_noncausal))
        all_solutions += solutions

    return all_solutions, all_environments, msolution, sem_solution
def run_experiment(args):
    """Run ERM/ICP/IRM on freshly sampled chain SEMs and report their errors.

    NOTE(review): ``run_experiment`` is defined again further down in this
    file, and that later definition rebinds the name at import time. This
    version is therefore not reachable by name.

    This version also passes ``sem.solution()`` straight to ``pretty`` and
    ``errors``. The other runners in this file instead unpack it as a
    ``(solution, scramble)`` pair. Confirm which SEM API this targets.

    Args:
        args: dict of experiment options. Keys read: ``"seed"``,
            ``"setup_sem"``, ``"methods"``, ``"n_reps"``. The chain setup also
            reads ``"setup_hidden"``, ``"setup_hetero"``,
            ``"setup_scramble"``, ``"dim"`` and ``"n_samples"``. The dict is
            forwarded to every method constructor.

    Returns:
        list of str: for each repetition, one ``"<setup> SEM <solution>
        0.00000 0.00000"`` line followed by one ``"<setup> <method>
        <solution> <err_causal> <err_noncausal>"`` line per selected method.

    Raises:
        NotImplementedError: if ``args["setup_sem"]`` is neither ``"chain"``
            nor ``"icp"``, or is ``"icp"`` with ``args["n_reps"] > 0``.
    """
    seed = args["seed"]
    if seed >= 0:
        torch.manual_seed(seed)
        numpy.random.seed(seed)
        torch.set_num_threads(1)

    sem_kind = args["setup_sem"]
    if sem_kind == "chain":
        setup_str = "chain_hidden={}_hetero={}_scramble={}".format(
            args["setup_hidden"], args["setup_hetero"], args["setup_scramble"])
    elif sem_kind == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    registry = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }
    if args["methods"] == "all":
        selected = registry
    else:
        selected = {}
        for name in args["methods"].split(","):
            selected[name] = registry[name]

    # Sample every repetition up front: (model, environments) pairs.
    experiments = []
    for _ in range(args["n_reps"]):
        if sem_kind != "chain":
            raise NotImplementedError
        model = ChainEquationModel(args["dim"],
                                   hidden=args["setup_hidden"],
                                   scramble=args["setup_scramble"],
                                   hetero=args["setup_hetero"])
        envs = [
            model(args["n_samples"], .2),
            model(args["n_samples"], 2.),
            model(args["n_samples"], 5.),
        ]
        experiments.append((model, envs))

    report = []
    for model, envs in experiments:
        report.append("{} SEM {} {:.5f} {:.5f}".format(
            setup_str, pretty(model.solution()), 0, 0))
        for name, ctor in selected.items():
            estimate = ctor(envs, args).solution()
            err_causal, err_noncausal = errors(model.solution(), estimate)
            report.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, name, pretty(estimate), err_causal, err_noncausal))
    return report
def run_experiment(args):
    """Sample SEM environments, fit each selected method, and collect errors.

    Args:
        args: dict of experiment options.
            - Common keys: ``"seed"``, ``"setup_sem"`` (``"chain"``,
              ``"simple"`` or ``"icp"``), ``"methods"`` (comma-separated
              names or ``"all"``), ``"n_reps"``, ``"dim"``, ``"n_samples"``
              and ``"env_list"``.
            - Chain setup: also reads ``"setup_ones"``, ``"setup_hidden"``,
              ``"setup_hetero"`` and ``"setup_scramble"``. ``"env_list"`` is
              a comma-separated string of float environment values.
            - Simple setup: also reads ``"k"``, ``"env_shuffle"`` and
              ``"env_rat"``. ``"env_list"`` is an environment count, and
              ``"env_rat"`` is a colon-separated list of integer ratios used
              to split ``"n_samples"`` across environments.
            The dict is also forwarded to every method constructor.

    Returns:
        tuple ``(all_solutions, sol_dict)``:
        - ``all_solutions`` holds, per repetition, one ``SEM`` report line
          plus one line per method.
        - ``sol_dict`` maps each method name to ``[solution, ecaus,
          enoncaus]`` as flattened numpy arrays. Later repetitions overwrite
          earlier ones.

    Raises:
        NotImplementedError: if ``args["setup_sem"]`` is not one of the three
            known setups, or is ``"icp"`` with ``args["n_reps"] > 0``.
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        torch.set_num_threads(1)
        # NOTE(review): Python's `random` is seeded with a fixed 1 rather
        # than args["seed"] -- confirm this is intended.
        random.seed(1)

    if args["setup_sem"] == "chain":
        setup_str = "chain_ones={}_hidden={}_hetero={}_scramble={}".format(
            args["setup_ones"], args["setup_hidden"], args["setup_hetero"],
            args["setup_scramble"])
    elif args["setup_sem"] == "simple":
        setup_str = ""
    elif args["setup_sem"] == "icp":
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }
    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {m: all_methods[m] for m in args["methods"].split(',')}

    all_sems = []
    all_solutions = []
    all_environments = []
    # Kept aligned with all_sems so each SEM's methods receive the
    # environment orderings sampled for that same SEM.
    all_env_orders = []
    for _ in range(args["n_reps"]):
        if args["setup_sem"] == "chain":
            sem = ChainEquationModel(args["dim"],
                                     ones=args["setup_ones"],
                                     hidden=args["setup_hidden"],
                                     scramble=args["setup_scramble"],
                                     hetero=args["setup_hetero"])
            env_list = [float(e) for e in args["env_list"].split(",")]
            environments = [sem(args["n_samples"], e) for e in env_list]
            # NOTE(review): the chain SEM yields no per-environment ordering;
            # None is forwarded to the method constructors -- confirm they
            # accept it.
            env_orders = None
        elif args["setup_sem"] == "simple":
            sem = SEM_X1YX2X3(args["dim"], args["k"], args["env_shuffle"])
            env_list = range(int(args["env_list"]))
            ratios = list(map(int, args["env_rat"].split(':')))
            n = args["n_samples"]
            total = sum(ratios)
            # Share of n allotted to each environment, rounded up.
            n_samples = [math.ceil(ni * 1.0 / total * n) for ni in ratios]
            print("sample in envs ", n_samples)
            environments = []
            env_orders = []
            for e in env_list:
                # The SEM call returns the environment's data in its first
                # two items and its ordering in the third.
                res = sem(n_samples[e], e)
                environments.append(res[:2])
                env_orders.append(res[2])
        else:
            raise NotImplementedError
        all_sems.append(sem)
        all_environments.append(environments)
        all_env_orders.append(env_orders)

    sol_dict = {}
    for sem, environments, env_orders in zip(all_sems, all_environments,
                                             all_env_orders):
        sem_solution, sem_scramble = sem.solution()
        solutions = [
            "{} SEM {} {:.5f} {:.5f}".format(setup_str, pretty(sem_solution), 0, 0)
        ]
        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args, env_orders)
            method_solution = sem_scramble @ method.solution()
            err_causal, err_noncausal, ecaus, enoncaus = errors(
                sem_solution, method_solution)
            solutions.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str, method_name, pretty(method_solution),
                err_causal, err_noncausal))
            sol_dict[method_name] = [
                method_solution.detach().view(-1).numpy(),
                ecaus.detach().view(-1).numpy(),
                enoncaus.detach().view(-1).numpy(),
            ]
        all_solutions += solutions
    return all_solutions, sol_dict