示例#1
0
    def gather_results(self, workers, log_queue, test_split=False):
        '''
        collect trial logs from the workers until all of them have exited

        Each item taken from log_queue is a (log_entry, trial_uid, model_path)
        tuple. The entry is merged into the JSON file at self.results_path under
        results[basename(model_path)][eval_type][trial_uid], where eval_type is
        'subgoal' or 'task' depending on self.args.subgoals. After each entry
        self.num_trials_done is incremented, the results file is rewritten via
        eval_util.save_with_backup and model_util.update_log is called.

        workers: iterable of objects with an exitcode attribute
            (None while the worker is still running)
        log_queue: queue the workers put their log tuples into
        test_split: if True, the "result has changed" warning and the
            success rate printout are skipped
            (NOTE(review): presumably test logs carry no 'success' key - confirm)

        The method returns nothing; it blocks until every worker has exited
        and the queue has been drained.
        '''
        results = {}
        eval_type = 'subgoal' if self.args.subgoals else 'task'
        lock = filelock.FileLock(self.results_path + '.lock')
        eta = ETA(self.num_trials, scope=32)
        while True:
            # empty() is used rather than qsize() because qsize() raises
            # NotImplementedError for multiprocessing queues on some platforms
            # (NOTE(review): assumes log_queue is a multiprocessing-style queue)
            if not log_queue.empty():
                # there is a new log entry available, process it
                log_entry, trial_uid, model_path = log_queue.get()
                # load old results (if available); another process may have
                # updated the file since the previous iteration
                with lock:
                    if os.path.exists(self.results_path):
                        with open(self.results_path, 'r') as results_file:
                            results = json.load(results_file)

                eval_epoch = os.path.basename(model_path)
                # update the old results with the new log entry
                trial_logs = results.setdefault(
                    eval_epoch, {}).setdefault(eval_type, {})
                if trial_uid in trial_logs and not test_split:
                    success_prev = trial_logs[trial_uid]['success']
                    success_curr = log_entry['success']
                    if success_prev != success_curr:
                        print(colored(
                            'WARNING: trial {} result has changed from {} to {}'.format(
                                trial_uid,
                                'success' if success_prev else 'fail',
                                'success' if success_curr else 'fail'), 'yellow'))
                trial_logs[trial_uid] = log_entry

                # print updated results
                self.num_trials_done += 1
                eta.numerator = self.num_trials_done
                if not test_split:
                    successes = [log['success'] for log in trial_logs.values()]
                    # time.gmtime(None) means "now", so an ETA that is not yet
                    # available would be printed as the current time of day
                    # (NOTE(review): assumes eta.eta_seconds can be None before
                    # enough samples were collected - confirm against ETA)
                    eta_seconds = eta.eta_seconds
                    if eta_seconds is None:
                        eta_str = '--:--:--'
                    else:
                        eta_str = time.strftime(
                            '%H:%M:%S', time.gmtime(eta_seconds))
                    print(colored(
                        '{:4d}/{} trials are done (current SR = {:.1f}), ETA = {}, elapsed = {}'.format(
                            self.num_trials_done,
                            self.num_trials,
                            100 * sum(successes) / len(successes),
                            eta_str,
                            time.strftime('%H:%M:%S', time.gmtime(eta.elapsed))),
                        'green'))
                # make a backup copy of results file before writing
                eval_util.save_with_backup(results, self.results_path, lock)
                # update info.json file
                model_util.update_log(
                    self.args.dout, stage='eval',
                    update='increase', progress=1)
                # keep draining the queue without sleeping: sleeping after
                # every entry would cap the throughput at one entry per second
                continue

            # check whether all workers have exited (exitcode == None means they are still running)
            all_terminated = all(
                worker.exitcode is not None for worker in workers)
            # the queue is re-checked after the exit codes so that entries put
            # by a worker right before it exited are not lost
            if all_terminated and log_queue.empty():
                if self.num_trials_left > 0:
                    print(colored('WARNING: only {}/{} trials were evaluated'.format(
                        self.num_trials_done, self.num_trials), 'red'))
                # our mission is over
                break
            # nothing to process yet, wait before polling again
            time.sleep(1)
        print(colored('Evaluation is complete', 'green'))