예제 #1
0
 def network(self):
     """Build the model's network and create its optimizer.

     Instantiates the model from ``self.config``, creates its feed
     variables (kept on ``self.input_data``), builds the network on them
     (the result is kept on ``self.metrics``), records the model's
     inference target variable, logs the ``CPU_NUM`` environment variable
     and finally creates the optimizer from ``get_strategy(self.config)``.
     """
     net_model = get_model(self.config)

     self.input_data = net_model.create_feeds()
     self.metrics = net_model.net(self.input_data)
     self.inference_target_var = net_model.inference_target_var

     cpu_num = os.getenv("CPU_NUM")
     logger.info("cpu_num: {}".format(cpu_num))

     strategy = get_strategy(self.config)
     net_model.create_optimizer(strategy)
    def init_network(self):
        """Build the model network and record everything the trainer needs.

        Steps, in order:

        1. Instantiate the model with ``get_model(self.config)`` and keep it
           on ``self.model``.
        2. Create the feed variables (``self.input_data``) and build the
           network on them; the return value of ``model.net`` is stored in
           ``self.metrics``.
        3. Record the model's inference feed/target variables.
        4. If the model has an ``all_vars`` attribute, write the variable
           names (one per line) to ``all_vars.txt`` in the current working
           directory.
        5. When ``runner.need_prune`` is set, replace the inference feed/target
           variables with the model's ``prune_feed_vars``/``prune_target_var``.
        6. When ``runner.need_train_dump`` / ``runner.need_infer_dump`` are
           set, record the model's dump fields/params (``[]`` when the model
           does not define them).
        7. Publish ``model.thread_stat_var_names`` as
           ``self.config['stat_var_names']`` and copy the model's metric list
           and metric types.
        8. Log ``CPU_NUM`` and create the optimizer from
           ``get_strategy(self.config)``.

        Runner flags are read with ``self.config.get(key, False)``, so a
        missing key means the feature is off.
        """
        model = get_model(self.config)
        # Keep the model on the instance too. Every attribute below is read
        # from this same object, the one whose ``net()`` is built here.
        self.model = model
        self.input_data = model.create_feeds()
        self.metrics = model.net(self.input_data)
        self.inference_feed_vars = model.inference_feed_vars
        self.inference_target_var = model.inference_target_var
        if hasattr(model, "all_vars"):
            with open("all_vars.txt", 'w+') as f:
                f.write('\n'.join([var.name for var in model.all_vars]))

        # Flags come from this instance's config (not a bare ``config`` name,
        # which this method never defines).
        if self.config.get("runner.need_prune", False):
            # DSSM prune net: use the pruned feed/target variables for
            # inference instead of the full ones.
            self.inference_feed_vars = model.prune_feed_vars
            self.inference_target_var = model.prune_target_var
        if self.config.get("runner.need_train_dump", False):
            self.train_dump_fields = getattr(model, "train_dump_fields", [])
            self.train_dump_params = getattr(model, "train_dump_params", [])
        if self.config.get("runner.need_infer_dump", False):
            self.infer_dump_fields = getattr(model, "infer_dump_fields", [])

        self.config['stat_var_names'] = model.thread_stat_var_names
        self.metric_list = model.metric_list
        self.metric_types = model.metric_types

        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        model.create_optimizer(get_strategy(self.config))
예제 #3
0
    def network(self):
        """Build the network and register its statistic/metric variables.

        Creates ``self.model`` from ``self.config``, builds its feeds
        (``self.input_data``) and network (result in ``self.metrics``), then
        derives:

        * ``self.config['stat_var_names']``: the de-duplicated names of
          ``auc_stat_list[2]``, ``auc_stat_list[3]`` and every variable in
          the model's ``metric_list``, in first-seen order;
        * ``self.metric_list``: all AUC stat variables followed by the
          model's metric variables;
        * ``self.metric_types``: ``"int64"`` for each AUC stat variable and
          ``"float32"`` for each metric variable, aligned index-by-index with
          ``self.metric_list``.

        Finally creates the optimizer from ``get_strategy(self.config)``.
        """
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds()
        self.metrics = self.model.net(self.input_data)
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))

        auc_stats = list(self.model.auc_stat_list)
        model_metrics = list(self.model.metric_list)

        # NOTE(review): indices 2 and 3 presumably select the accumulated
        # AUC statistic variables -- confirm against the model definition.
        thread_stat_var_names = [auc_stats[2].name, auc_stats[3].name]
        thread_stat_var_names += [var.name for var in model_metrics]
        # De-duplicate while keeping first-seen order. Iterating a ``set`` of
        # strings gives an order that depends on hash randomization and can
        # therefore differ between processes/runs.
        self.config['stat_var_names'] = list(
            dict.fromkeys(thread_stat_var_names))

        self.metric_list = auc_stats + model_metrics
        self.metric_types = (["int64"] * len(auc_stats) +
                             ["float32"] * len(model_metrics))
        self.model.create_optimizer(get_strategy(self.config))