示例#1
0
    def generate_random_feeder(self,
                               config,
                               use_feed_fetch=True,
                               feeder_adapter=None):
        """Build the TensorFlow graph for ``config`` and return a feeder adapter.

        Args:
            config: An ``api_param.APIConfig`` describing the API under test.
            use_feed_fetch: Stored (OR-ed with ``config.name == "fetch"``) as
                ``self._need_fetch``. Must be True when ``feeder_adapter`` was
                created by a framework other than TensorFlow.
            feeder_adapter: Optional existing ``feeder.FeederAdapter``. When
                its ``framework`` is ``"tensorflow"`` the graph is not rebuilt.

        Returns:
            A new ``feeder.FeederAdapter("tensorflow", ...)`` built from
            ``self._feed_dict`` (ordered by ``self.feed_list``) when
            ``feeder_adapter`` is None; otherwise ``feeder_adapter`` unchanged.

        Raises:
            ValueError: If ``config`` is None or not an ``APIConfig``, or if a
                non-TensorFlow ``feeder_adapter`` is given together with
                ``use_feed_fetch=False``.
        """
        if config is None or not isinstance(config, api_param.APIConfig):
            raise ValueError(
                "Argument \"config\" must be set to an instance of APIConfig.")

        foreign_adapter = (feeder_adapter is not None and
                           feeder_adapter.framework != "tensorflow")
        # Explicit check instead of ``assert``: argument validation must not be
        # stripped under ``python -O``, and ValueError matches the check above.
        # The message names the real framework because any non-tensorflow
        # adapter (e.g. "pytorch") reaches this branch, not only "paddle".
        if foreign_adapter and not use_feed_fetch:
            raise ValueError(
                "Argument use_feed_fetch must be True when feeder_adapter is "
                "initialized by {}.".format(feeder_adapter.framework))

        # Rebuild the graph unless the adapter already came from TensorFlow.
        if feeder_adapter is None or feeder_adapter.framework != "tensorflow":
            self._need_feed = config.name == "feed"
            self._need_fetch = use_feed_fetch or config.name == "fetch"
            self._feed_spec = feeder.copy_feed_spec(config.feed_spec)
            self._feed_dict = {}

            # Name-mangled flag, reset before every graph build.
            self.__backward = False
            # Presumably populates self.feed_list and self._feed_dict, since
            # both are read below right after this call.
            self.build_graph(config=config)

        if feeder_adapter is None:
            # Internal invariant: one generated value per feed variable.
            assert len(self._feed_dict) == len(self.feed_list)
            feed_list = [self._feed_dict[var] for var in self.feed_list]
            return feeder.FeederAdapter("tensorflow", config.feed_spec,
                                        feed_list)
        return feeder_adapter
示例#2
0
    def generate_random_feeder(self,
                               config,
                               use_feed_fetch=True,
                               feeder_adapter=None):
        """Build the Paddle program for ``config`` and return a feeder adapter.

        Args:
            config: An ``api_param.APIConfig`` describing the API to benchmark.
            use_feed_fetch: Stored (OR-ed with ``config.name == "fetch"``) as
                ``self._need_fetch``.
            feeder_adapter: Optional existing ``feeder.FeederAdapter``. When
                its ``framework`` is ``"paddle"`` no program is rebuilt.

        Returns:
            ``feeder_adapter`` unchanged when one is supplied; otherwise a new
            ``feeder.FeederAdapter("paddle", ...)`` holding the value of
            ``self._feed_dict`` for each variable in ``self.feed_vars``.

        Raises:
            ValueError: If ``config`` is None or not an ``APIConfig``.
        """
        config_ok = (config is not None and
                     isinstance(config, api_param.APIConfig))
        if not config_ok:
            raise ValueError(
                "Argument \"config\" must be set to an instance of APIConfig.")

        # A paddle-made adapter means the program was already built for it.
        reuse_program = (feeder_adapter is not None and
                         feeder_adapter.framework == "paddle")
        if not reuse_program:
            self._need_feed = config.name == "feed"
            self._need_fetch = use_feed_fetch or config.name == "fetch"
            self._feed_spec = feeder.copy_feed_spec(config.feed_spec)
            self._feed_dict = {}

            self.__backward = False
            self.main_program = fluid.Program()
            self.startup_program = fluid.Program()
            # Everything build_program creates lands in these two programs.
            with fluid.program_guard(self.main_program, self.startup_program):
                self.build_program(config=config)

        if feeder_adapter is not None:
            return feeder_adapter

        feed_values = [self._feed_dict[var] for var in self.feed_vars]
        return feeder.FeederAdapter("paddle", config.feed_spec, feed_values)
示例#3
0
    def generate_random_feeder(self,
                               config,
                               use_feed_fetch=True,
                               feeder_adapter=None):
        """Build (and prune) the Paddle program for ``config``; return a feeder.

        Args:
            config: An ``api_param.APIConfig`` describing the API to benchmark.
            use_feed_fetch: Stored (OR-ed with ``config.name == "fetch"``) as
                ``self._need_fetch``.
            feeder_adapter: Optional existing ``feeder.FeederAdapter``. When
                its ``framework`` is ``"paddle"`` no program is rebuilt.

        Returns:
            ``feeder_adapter`` unchanged when one is supplied; otherwise a new
            ``feeder.FeederAdapter("paddle", ...)`` holding the value of
            ``self._feed_dict`` for each variable in ``self.feed_vars``.

        Raises:
            ValueError: If ``config`` is None or not an ``APIConfig``.
        """
        config_ok = (config is not None and
                     isinstance(config, api_param.APIConfig))
        if not config_ok:
            raise ValueError(
                "Argument \"config\" must be set to an instance of APIConfig.")

        # A paddle-made adapter means the program was already built for it.
        reuse_program = (feeder_adapter is not None and
                         feeder_adapter.framework == "paddle")
        if not reuse_program:
            self._need_feed = config.name == "feed"
            self._need_fetch = use_feed_fetch or config.name == "fetch"
            self._feed_spec = feeder.copy_feed_spec(config.feed_spec)
            self._feed_dict = {}

            self.__backward = False
            self.main_program = fluid.Program()
            self.startup_program = fluid.Program()
            with fluid.program_guard(self.main_program, self.startup_program):
                self.build_program(config=config)

            # For backward benchmarks the program has the shape
            #   xxx -> shape -> fill_constant -> xxx_grad
            # The extra fill_constant CUDA kernel makes the traced times larger
            # than the real cost, while TensorFlow can optimize the execution
            # of fill_constant automatically. self._prune() therefore moves the
            # fill_constant op from main_program to startup_program for the
            # current benchmark; the execution strategy is to be optimized in
            # the future.
            self._prune(config)

        if feeder_adapter is not None:
            return feeder_adapter

        feed_values = [self._feed_dict[var] for var in self.feed_vars]
        return feeder.FeederAdapter("paddle", config.feed_spec, feed_values)
示例#4
0
 def generate_random_feeder(self, config):
     """Wrap this object's generated feed values in a PyTorch feeder adapter.

     Args:
         config: Object whose ``feed_spec`` is passed through to the adapter.

     Returns:
         ``feeder.FeederAdapter("pytorch", config.feed_spec,
         self._generated_feed_values)``; the values are presumably produced
         earlier by this benchmark (not visible here).
     """
     adapter_cls = feeder.FeederAdapter
     return adapter_cls("pytorch", config.feed_spec,
                        self._generated_feed_values)