def init_models(train_obs_mean, architecture, device): """Args: train_obs_mean: tensor of shape [obs_dim] architecture: linear_1, linear_2, linear_3 or non_linear device: torch.device Returns: generative_model, inference_network """ if architecture[:len('linear')] == 'linear': num_stochastic_layers = int(architecture[-1]) generative_model = models.GenerativeModel( num_stochastic_layers=num_stochastic_layers, num_deterministic_layers=0, device=device, train_obs_mean=train_obs_mean) inference_network = models.InferenceNetwork( num_stochastic_layers=num_stochastic_layers, num_deterministic_layers=0, device=device, train_obs_mean=train_obs_mean) elif architecture == 'non_linear': generative_model = models.GenerativeModel( num_stochastic_layers=1, num_deterministic_layers=2, device=device, train_obs_mean=train_obs_mean) inference_network = models.InferenceNetwork( num_stochastic_layers=1, num_deterministic_layers=2, device=device, train_obs_mean=train_obs_mean) if device.type == 'cuda': generative_model.cuda() inference_network.cuda() return generative_model, inference_network
def init_models(pcfg_path):
    """Build the PCFG models.

    Returns:
        generative_model, inference_network, true_generative_model
    """
    grammar, true_production_probs = read_pcfg(pcfg_path)
    # The "true" model shares the grammar but is fixed to the ground-truth
    # production probabilities read from the PCFG file.
    return (models.GenerativeModel(grammar),
            models.InferenceNetwork(grammar),
            models.GenerativeModel(grammar, true_production_probs))
def init(num_data, num_dim, true_cluster_cov, device):
    """Build the clustering models.

    Returns:
        (generative_model, inference_network, true_generative_model)
    """
    # Standard-normal prior: zero mean, identity covariance.
    prior_loc = torch.zeros(num_dim, device=device)
    prior_cov = torch.eye(num_dim, device=device)
    generative_model = models.GenerativeModel(
        num_data, prior_loc, prior_cov, device).to(device)
    # Ground-truth model additionally fixes the cluster covariance.
    true_generative_model = models.GenerativeModel(
        num_data, prior_loc, prior_cov, device, true_cluster_cov).to(device)
    inference_network = models.InferenceNetwork(num_data, num_dim).to(device)
    return generative_model, inference_network, true_generative_model
def init_models(args):
    """Returns: generative_model, inference_network, true_generative_model"""
    def _make_generative(mixture_logits):
        # Both generative models differ only in their mixture logits.
        return models.GenerativeModel(
            mixture_logits,
            softmax_multiplier=args.softmax_multiplier,
            device=args.device).to(device=args.device)

    generative_model = _make_generative(args.init_mixture_logits)
    true_generative_model = _make_generative(args.true_mixture_logits)
    inference_network = models.InferenceNetwork(
        args.num_mixtures, args.relaxed_one_hot, args.temperature,
        args.device).to(device=args.device)
    return generative_model, inference_network, true_generative_model
def load_models(model_folder='.', iteration=None, load_mws_memory=False):
    """Load a saved generative model and inference network from disk.

    Args:
        model_folder: directory containing gen{suffix}.pt / inf{suffix}.pt
        iteration: optional label appended to the checkpoint filenames;
            None loads the unsuffixed checkpoint
        load_mws_memory: also load mws_mem{suffix}.pkl

    Returns:
        (generative_model, inference_network), or a 3-tuple that also
        contains mws_memory when load_mws_memory is True. Every element is
        None when no checkpoint exists in model_folder.
    """
    suffix = '' if iteration is None else iteration
    generative_model_path = os.path.join(
        model_folder, 'gen{}.pt'.format(suffix))
    inference_network_path = os.path.join(
        model_folder, 'inf{}.pt'.format(suffix))
    if not os.path.exists(generative_model_path):
        # Bug fix: the original returned a 2-tuple here even when
        # load_mws_memory was True, so 3-way unpacking at the call site
        # raised ValueError instead of receiving (None, None, None).
        return (None, None, None) if load_mws_memory else (None, None)
    args = load_object(get_args_path(model_folder))
    generative_model = models.GenerativeModel(
        args.init_mixture_logits,
        softmax_multiplier=args.softmax_multiplier,
        device=args.device).to(device=args.device)
    inference_network = models.InferenceNetwork(
        args.num_mixtures, args.relaxed_one_hot, args.temperature,
        args.device).to(device=args.device)
    generative_model.load_state_dict(torch.load(generative_model_path))
    print_with_time('Loaded from {}'.format(generative_model_path))
    inference_network.load_state_dict(torch.load(inference_network_path))
    print_with_time('Loaded from {}'.format(inference_network_path))
    if load_mws_memory:
        mws_memory_path = os.path.join(
            model_folder, 'mws_mem{}.pkl'.format(suffix))
        mws_memory = load_object(mws_memory_path)
        return generative_model, inference_network, mws_memory
    return generative_model, inference_network
def load_models(model_folder='.'):
    """Load the PCFG generative model and inference network checkpoints.

    Returns:
        generative_model, inference_network
    """
    gen_path = os.path.join(model_folder, 'gen.pt')
    inf_path = os.path.join(model_folder, 'inf.pt')
    # NOTE(review): f.read() may include a trailing newline in pcfg_path —
    # verify the pcfg_path.txt writer strips it.
    with open(os.path.join(model_folder, 'pcfg_path.txt')) as f:
        pcfg_path = f.read()
    grammar, _ = read_pcfg(pcfg_path)
    generative_model = models.GenerativeModel(grammar)
    generative_model.load_state_dict(torch.load(gen_path))
    print_with_time('Loaded from {}'.format(gen_path))
    inference_network = models.InferenceNetwork(grammar)
    inference_network.load_state_dict(torch.load(inf_path))
    print_with_time('Loaded from {}'.format(inf_path))
    return generative_model, inference_network
def init(run_args, device):
    """Build all training state for a run.

    Returns:
        generative_model, inference_network, optimizer, memory, stats
        where memory is None unless run_args.algorithm contains 'mws'.
    """
    generative_model = models.GenerativeModel(
        run_args.num_primitives,
        run_args.initial_max_curve,
        run_args.big_arcs,
        run_args.p_lstm_hidden_size,
        run_args.num_rows,
        run_args.num_cols,
        run_args.num_arcs,
        run_args.likelihood,
        run_args.p_uniform_mixture,
        use_alphabet=run_args.condition_on_alphabet,
    ).to(device)
    inference_network = models.InferenceNetwork(
        run_args.num_primitives,
        run_args.q_lstm_hidden_size,
        run_args.num_rows,
        run_args.num_cols,
        run_args.num_arcs,
        run_args.obs_embedding_dim,
        run_args.q_uniform_mixture,
        use_alphabet=run_args.condition_on_alphabet,
    ).to(device)
    optimizer = init_optimizer(generative_model, inference_network, 1)
    # Eight independent empty lists (a generator keeps them distinct;
    # [[]] * 8 would alias a single list).
    stats = Stats(*([] for _ in range(8)))
    memory = None
    if "mws" in run_args.algorithm:
        memory = init_memory(
            run_args.num_train_data,
            run_args.memory_size,
            generative_model.num_arcs,
            generative_model.num_primitives,
            device,
        )
    return generative_model, inference_network, optimizer, memory, stats