def finalize(cls):
    """Resolve the run's log directory, snapshot the source tree, and freeze the config.

    If ``cls.log_dir`` is unset, it is derived as ``logs/<run_id>``, where
    ``run_id`` falls back to ``<algorithm>-<env.id>-<seed>-<timestamp>``; the
    resolved path is written back to ``cls.log_dir`` and created on disk.

    When the working directory contains a ``.git`` entry, the source tree is
    copied into ``<log_dir>/src/`` with ``git checkout-index``.  On Linux the
    HEAD commit is additionally stored in ``cls.commit``, every working-tree
    change is staged (``git add .`` -- note this mutates the user's index) so
    new files are included, and ``git diff HEAD`` is saved verbatim to
    ``<log_dir>/diff.patch``.

    Finally the config is dumped to ``config.yml``, text and CSV sinks are
    attached to ``logger`` and ``cls.set_frozen()`` is called.

    Raises:
        ValueError: if ``cls.TRPO.rollout_samples`` is not a multiple of
            ``cls.env.num_env``.
        RuntimeError: if the git snapshot fails on all 10 attempts; the last
            underlying error is chained as ``__cause__``.
    """
    from subprocess import CalledProcessError

    log_dir = cls.log_dir
    if log_dir is None:
        run_id = cls.run_id
        if run_id is None:
            run_id = '{}-{}-{}-{}'.format(
                cls.algorithm, cls.env.id, cls.seed,
                time.strftime('%Y-%m-%d-%H-%M-%S'))
        log_dir = os.path.join("logs", run_id)
        cls.log_dir = log_dir
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(log_dir, exist_ok=True)

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if cls.TRPO.rollout_samples % cls.env.num_env != 0:
        raise ValueError(
            'TRPO.rollout_samples ({}) must be a multiple of env.num_env ({})'
            .format(cls.TRPO.rollout_samples, cls.env.num_env))

    if os.path.exists('.git'):
        last_error = None
        for _ in range(10):
            try:
                if sys.platform == 'linux':
                    cls.commit = check_output(
                        ['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
                    check_output(['git', 'add', '.'])
                    check_output([
                        'git', 'checkout-index', '-a', '-f',
                        '--prefix={}/src/'.format(log_dir)
                    ])
                    # Store git's raw bytes: no UTF-8 decode / locale re-encode
                    # round trip that could fail on non-UTF-8 content or locales.
                    diff = check_output(['git', '--no-pager', 'diff', 'HEAD'])
                    with open(os.path.join(log_dir, 'diff.patch'), 'wb') as f:
                        f.write(diff)
                else:
                    check_output([
                        'git', 'checkout-index', '-a',
                        '--prefix={}/src/'.format(log_dir)
                    ])
                break
            except (CalledProcessError, OSError) as e:
                # Presumably transient (e.g. another git process holding the
                # index lock); wait and retry.  Other errors propagate at once.
                last_error = e
                print(e)
                print('Try again...')
                time.sleep(1)
        else:
            raise RuntimeError('Failed after 10 trials.') from last_error

    # Build the dict before opening the file so a failure leaves no empty file.
    config = cls.as_dict()
    with open(os.path.join(log_dir, 'config.yml'), 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.add_sink(FileSink(os.path.join(log_dir, 'log.txt')))
    logger.add_csvwriter(CSVWriter(os.path.join(log_dir, 'progress.csv')))
    logger.info("log_dir = %s", log_dir)
    cls.set_frozen()
def finalize(cls):
    """Resolve the run's log directory, snapshot the source tree, and freeze the config.

    If ``cls.log_dir`` is unset it becomes ``<cls.ckpt.base>/<run_id>``, where
    ``run_id`` falls back to a ``%Y-%m-%d_%H-%M-%S`` timestamp; the resolved
    path is written back to ``cls.log_dir`` and created on disk.

    The HEAD commit is stored in ``cls.commit``, every working-tree change is
    staged (``git add .`` -- note this mutates the user's index) and the tree
    is copied into ``<log_dir>/src/`` with ``git checkout-index``.  The output
    of ``git diff HEAD`` is saved verbatim to ``<log_dir>/diff.patch``, the
    config is dumped to ``config.yml``, a ``log.json`` sink is attached to
    ``logger`` and ``cls.set_frozen()`` is called.

    Raises:
        RuntimeError: if the git commands keep failing for 60 attempts (about
            one second apart); the last ``CalledProcessError`` is chained as
            ``__cause__``.
    """
    log_dir = cls.log_dir
    if log_dir is None:
        run_id = cls.run_id
        if run_id is None:
            run_id = time.strftime('%Y-%m-%d_%H-%M-%S')
        log_dir = os.path.join(cls.ckpt.base, run_id)
        cls.log_dir = log_dir
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(log_dir, exist_ok=True)

    last_error = None
    for _ in range(60):
        try:
            cls.commit = check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
            check_output(['git', 'add', '.'])
            check_output([
                'git', 'checkout-index', '-a', '-f', f'--prefix={log_dir}/src/'
            ])
            # Captured inside the retry loop so it gets the same retry handling
            # as the other git calls; written to disk below.
            diff = check_output(['git', '--no-pager', 'diff', 'HEAD'])
            break
        except CalledProcessError as e:
            # Presumably transient (e.g. a concurrent git process holding the
            # index lock); report it instead of swallowing it, then retry.
            last_error = e
            print(f'{e}; retrying...')
            time.sleep(1)
    else:
        raise RuntimeError('Failed after 60 trials.') from last_error

    # Build the dict before opening the file so a failure leaves no empty file.
    config = cls.as_dict()
    with open(os.path.join(log_dir, 'config.yml'), 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False)
    # Raw bytes: no UTF-8 decode / locale re-encode round trip that could fail.
    with open(os.path.join(log_dir, 'diff.patch'), 'wb') as f:
        f.write(diff)

    logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
    logger.info("log_dir = %s", log_dir)
    cls.set_frozen()