def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
                                feed_order):
    """Run the training loop through the parallel executor.

    Everything happens inside ``self._prog_and_scope_guard()``.

    Args:
        num_epochs: forwarded unchanged to ``_train_by_any_executor``.
        event_handler: forwarded unchanged to ``_train_by_any_executor``.
        reader: data reader; it is wrapped with
            ``DataFeeder.decorate_reader(..., multi_devices=True)`` before
            being handed to the training loop.
        feed_order: passed together with ``self.train_program`` to
            ``build_feed_var_list`` to obtain the feed variables.
    """
    with self._prog_and_scope_guard():
        parallel_exe = self._get_or_create_parallel_executor()

        # Resolve the feed variables and build a feeder bound to this place.
        feed_vars = build_feed_var_list(self.train_program, feed_order)
        batch_feeder = data_feeder.DataFeeder(
            feed_list=feed_vars, place=self.place)

        # The reader is decorated for multi-device feeding.
        multi_device_reader = batch_feeder.decorate_reader(
            reader, multi_devices=True)

        self._train_by_any_executor(
            event_handler, parallel_exe, num_epochs, multi_device_reader)
def _test_by_executor(self, reader, feed_order, fetch_list):
    """Run ``self.test_program`` over every batch and average the results.

    For each batch produced by ``reader()``, the program is executed with
    ``fetch_list`` and the first element of every fetched output is added
    to a per-target running total. The totals are finally divided by the
    number of batches.

    Args:
        reader: callable returning an iterable of data batches.
        feed_order: passed together with ``self.test_program`` to
            ``build_feed_var_list`` to obtain the feed variables.
        fetch_list: targets fetched from each run; the result has one
            entry per target.

    Returns:
        A list of per-target averages over all batches.

    Note:
        If ``reader()`` yields no batch and ``fetch_list`` is non-empty,
        the final division is by zero and ``ZeroDivisionError`` is raised.
    """
    with executor.scope_guard(self.scope):
        feed_vars = build_feed_var_list(self.test_program, feed_order)
        batch_feeder = data_feeder.DataFeeder(
            feed_list=feed_vars, place=self.place)
        exe = executor.Executor(self.place)

        totals = [0] * len(fetch_list)
        num_batches = 0
        for batch in reader():
            fetched = exe.run(
                program=self.test_program,
                feed=batch_feeder.feed(batch),
                fetch_list=fetch_list)
            # Only the first element of each fetched output is accumulated.
            totals = [
                running + out[0] for running, out in zip(totals, fetched)
            ]
            num_batches += 1

        averages = [total / num_batches for total in totals]
        return averages
def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
    """Run the training loop through a plain ``Executor``.

    Everything happens inside ``self._prog_and_scope_guard()``.

    Args:
        num_epochs: forwarded unchanged to ``_train_by_any_executor``.
        event_handler: forwarded unchanged to ``_train_by_any_executor``.
        reader: data reader; it is wrapped with
            ``DataFeeder.decorate_reader(..., multi_devices=False)`` before
            being handed to the training loop.
        feed_order: passed together with ``self.train_program`` to
            ``build_feed_var_list`` to obtain the feed variables.
    """
    with self._prog_and_scope_guard():
        # Resolve the feed variables and build a feeder bound to this place.
        feed_vars = build_feed_var_list(self.train_program, feed_order)
        batch_feeder = data_feeder.DataFeeder(
            feed_list=feed_vars, place=self.place)

        single_exe = executor.Executor(self.place)

        # Single-device path: the reader is decorated without multi-device
        # splitting.
        single_device_reader = batch_feeder.decorate_reader(
            reader, multi_devices=False)

        self._train_by_any_executor(
            event_handler, single_exe, num_epochs, single_device_reader)
def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
    """Train with a plain ``Executor``, emitting epoch and step events.

    Everything happens inside ``self._prog_and_scope_guard()``.

    Args:
        num_epochs: number of passes made over ``reader()``.
        event_handler: callable invoked with ``BeginEpochEvent``,
            ``BeginStepEvent``, ``EndStepEvent`` and ``EndEpochEvent``
            instances around each epoch and each step.
        reader: callable returning an iterable of data batches; it is
            called once per epoch.
        feed_order: list of variable names looked up in the global block
            of ``self.train_program``, or ``None`` to use every variable
            in that block whose ``is_data`` attribute is truthy.
    """
    with self._prog_and_scope_guard():
        exe = executor.Executor(self.place)

        if feed_order is None:
            # ``values()`` is available on both Python 2 and Python 3
            # dicts, whereas ``itervalues()`` only exists on Python 2.
            feed_var_list = [
                var
                for var in self.train_program.global_block().vars.values()
                if hasattr(var, 'is_data') and var.is_data
            ]
        else:
            feed_var_list = [
                self.train_program.global_block().var(var_name)
                for var_name in feed_order
            ]

        feeder = data_feeder.DataFeeder(
            feed_list=feed_var_list, place=self.place)

        for epoch_id in range(num_epochs):
            event_handler(BeginEpochEvent(epoch_id))
            for step_id, data in enumerate(reader()):
                event_handler(BeginStepEvent(epoch_id, step_id))
                # No program is passed explicitly; presumably the default
                # one installed by _prog_and_scope_guard is run -- confirm.
                exe.run(feed=feeder.feed(data), fetch_list=[])
                event_handler(EndStepEvent(epoch_id, step_id))
            event_handler(EndEpochEvent(epoch_id))