예제 #1
0
    def __init__(self,
                 model,
                 model_params,
                 load_data=True,
                 debug=False,
                 batch_size=64):
        """Initialise the set-modeling task.

        Forwards the constructor arguments to the parent class under the
        task name "TaskSetModeling", then builds the prior distribution and
        the beta scheduler from ``self.model_params``.

        Args:
            model: Model object passed through to the parent constructor.
            model_params: Parameter mapping passed through to the parent
                constructor; this method reads it back via
                ``self.model_params``.
            load_data: Forwarded to the parent constructor.
            debug: Forwarded to the parent constructor.
            batch_size: Forwarded to the parent constructor.
        """
        super().__init__(model, model_params,
                         load_data=load_data, debug=debug,
                         batch_size=batch_size, name="TaskSetModeling")

        # allow_default=False: no fallback value is offered for this key.
        # Presumably get_param_val reports an error tagged with
        # error_location when the key is missing -- confirm in get_param_val.
        prior_config = get_param_val(self.model_params,
                                     "prior_distribution",
                                     allow_default=False,
                                     error_location="TaskSetModeling - init")
        self.prior_distribution = create_prior_distribution(prior_config)

        # "beta" is indexed directly, so it must be present in model_params.
        beta_config = self.model_params["beta"]
        self.beta_scheduler = create_scheduler(beta_config, "beta")

        # Accumulators for logged quantities: three lists plus a scalar beta
        # slot (presumably filled in elsewhere during the run -- confirm).
        self.summary_dict = {"log_prob": [], "ldj": [], "z": [], "beta": 0}
예제 #2
0
    def __init__(self,
                 model,
                 model_params,
                 load_data=True,
                 debug=False,
                 batch_size=64):
        """Initialise the graph-coloring task.

        Forwards the constructor arguments to the parent class under the
        task name "TaskGraphColoring", then builds the prior distribution and
        the beta/gamma schedulers from ``self.model_params``.

        Args:
            model: Model object passed through to the parent constructor.
            model_params: Parameter mapping passed through to the parent
                constructor; this method reads it back via
                ``self.model_params``.
            load_data: Forwarded to the parent constructor.
            debug: Forwarded to the parent constructor.
            batch_size: Forwarded to the parent constructor.
        """
        super().__init__(model, model_params,
                         load_data=load_data, debug=debug,
                         batch_size=batch_size, name="TaskGraphColoring")

        # An empty dict is passed as the fallback for a missing prior config.
        prior_config = get_param_val(self.model_params,
                                     "prior_distribution", {})
        self.prior_distribution = create_prior_distribution(prior_config)

        # "beta" and "gamma" are indexed directly, so both must be present.
        beta_config = self.model_params["beta"]
        self.beta_scheduler = create_scheduler(beta_config, "beta")
        gamma_config = self.model_params["gamma"]
        self.gamma_scheduler = create_scheduler(gamma_config, "gamma")

        # Accumulators for logged quantities: three lists plus scalar beta and
        # gamma slots (presumably filled in elsewhere during the run -- confirm).
        self.summary_dict = {
            "log_prob": [],
            "ldj": [],
            "z": [],
            "beta": 0,
            "gamma": 0,
        }
        # No checkpoint yet; presumably assigned later -- verify with callers.
        self.checkpoint_path = None
예제 #3
0
 def _create_layers(self):
     """Build the model layers from dataset statistics and model params.

     Copies the dataset-class node/edge counts onto ``self``, creates the
     prior distribution from the optional "prior_distribution" parameter
     (falling back to an empty dict), and then calls the encoding-layer and
     step-flow builders, in that order.
     """
     data_cls = self.dataset_class
     # Global model parameters, taken from the dataset class.
     self.max_num_nodes = data_cls.max_num_nodes()
     self.num_node_types = data_cls.num_node_types()
     self.num_edge_types = data_cls.num_edge_types()
     self.num_max_neighbours = data_cls.num_max_neighbours()

     # The prior distribution is needed here for the edges, so it is created
     # before any of the layer builders run.
     prior_config = get_param_val(self.model_params,
                                  "prior_distribution",
                                  default_val={})
     self.prior_distribution = create_prior_distribution(prior_config)

     # Encoding layers first, then the flow steps.
     self._create_encoding_layers()
     self._create_step_flows()