示例#1
0
    def run(self, fetch_list=None, *args, **kwargs):
        """Wrapper around ``Executor.run`` that also drives registered hooks.

        Without hooks this forwards straight to ``self._exe.run`` on
        ``self._program.train_program``.

        With hooks:

        * each hook's ``before_run(self._state)`` result is appended after the
          caller's fetch list;
        * everything is flattened and fetched in one executor call;
        * each result is passed through ``self._merge_result``;
        * the results are regrouped;
        * each hook receives its own group via ``after_run(result, state)``.

        Args:
            fetch_list: variables (objects exposing ``.name``) or variable
                names to fetch for the caller. ``None`` (the default) is
                treated as an empty list.
            *args: forwarded positionally to ``self._exe.run`` after the
                program.
            **kwargs: forwarded by keyword to ``self._exe.run``.

        Returns:
            The executor output for the caller's ``fetch_list``. In the hook
            path, the hook fetches are stripped off before returning.

        Raises:
            StopException: if any hook's ``should_stop(self._state)`` is true
                after the step has run.
        """
        # `None` sentinel instead of a shared mutable `[]` default; the
        # executor still receives an empty list exactly as before.
        if fetch_list is None:
            fetch_list = []

        if self._hooks:
            # Slot 0 holds the caller's fetches; slot i + 1 belongs to hook i.
            nested = [fetch_list]
            nested.extend(h.before_run(self._state) for h in self._hooks)
            flat, schema = util.flatten(nested)
            # The executor is given names; accept variable objects as well.
            names = [
                f if isinstance(f, six.string_types) else f.name for f in flat
            ]
            res = self._exe.run(self._program.train_program,
                                *args,
                                fetch_list=names,
                                **kwargs)
            res = [self._merge_result(r) for r in res]
            # Restore the per-slot grouping recorded by util.flatten.
            res = util.unflatten(res, schema)
            ret, hook_results = res[0], res[1:]
            for r, h in zip(hook_results, self._hooks):
                h.after_run(r, self._state)

            # NOTE(review): self._state is not advanced when this raises —
            # confirm that skipping the step counter on stop is intended.
            if any(h.should_stop(self._state) for h in self._hooks):
                raise StopException('hook call stop')
        else:
            ret = self._exe.run(self._program.train_program,
                                *args,
                                fetch_list=fetch_list,
                                **kwargs)
        self._state = self._state.next()
        return ret
示例#2
0
 def after_run(self, res_list, state):
     """Feed each fetched result group to its matching metric.

     ``res_list`` is regrouped with ``util.unflatten`` using ``self.schema``.
     The i-th group is then passed to ``update`` on the i-th entry of
     ``self.metrics``. Pairing stops at the shorter of the two sequences.
     ``state`` is not used here.
     """
     grouped = util.unflatten(res_list, self.schema)
     pairs = zip(grouped, self.metrics)
     for value, metric in pairs:
         metric.update(value)