Exemplo n.º 1
0
Arquivo: flags.py Projeto: liziniu/RLX
    def finalize(cls):
        """Resolve the run's log directory, snapshot the source tree, and freeze flags.

        Steps performed:

        1. If ``cls.log_dir`` is unset, derive it as ``logs/<run_id>``. When
           ``cls.run_id`` is also unset, the run id is built from
           ``cls.algorithm``, ``cls.env.id``, ``cls.seed`` and the current local
           time. The chosen directory is stored back on ``cls.log_dir``.
        2. Create the directory (and any parents) if needed.
        3. Check that ``cls.TRPO.rollout_samples`` is a multiple of
           ``cls.env.num_env``.
        4. If the working directory contains a ``.git`` entry, copy the source
           into ``<log_dir>/src/`` with ``git checkout-index``. On Linux the
           current commit hash is also stored in ``cls.commit``, and
           ``git diff HEAD`` is saved to ``<log_dir>/diff.patch``. The git steps
           are retried up to 10 times, one second apart.
        5. Write ``cls.as_dict()`` to ``<log_dir>/config.yml``, attach the text
           log and CSV progress sinks to ``logger``, and freeze the flags via
           ``cls.set_frozen()``.

        Raises:
            AssertionError: if the rollout sample count is not divisible by the
                number of environments (only when assertions are enabled).
            RuntimeError: if the git snapshot still fails after 10 attempts; the
                last underlying exception is chained as ``__cause__``.
        """
        log_dir = cls.log_dir
        if log_dir is None:
            run_id = cls.run_id
            if run_id is None:
                run_id = '{}-{}-{}-{}'.format(
                    cls.algorithm, cls.env.id, cls.seed,
                    time.strftime('%Y-%m-%d-%H-%M-%S'))

            log_dir = os.path.join("logs", run_id)
            cls.log_dir = log_dir

        # exist_ok avoids the exists()/makedirs() race when runs start concurrently.
        os.makedirs(log_dir, exist_ok=True)

        assert cls.TRPO.rollout_samples % cls.env.num_env == 0

        if os.path.exists('.git'):
            n_trials = 10
            last_error = None
            # Retries are presumably meant to ride out transient git failures
            # (e.g. another run holding .git/index.lock) -- TODO confirm.
            for _ in range(n_trials):
                try:
                    if sys.platform == 'linux':
                        cls.commit = check_output(
                            ['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
                        # checkout-index copies from the index, so stage the working
                        # tree first to include modified and new files.
                        # NOTE: this mutates the repository's index as a side effect.
                        check_output(['git', 'add', '.'])
                        check_output([
                            'git', 'checkout-index', '-a', '-f',
                            '--prefix={}/src/'.format(log_dir)
                        ])
                        # Run git before opening the file so a failure does not
                        # leave behind a truncated diff.patch.
                        diff = check_output(
                            ['git', '--no-pager', 'diff', 'HEAD']).decode('utf-8')
                        with open(os.path.join(log_dir, 'diff.patch'), 'w',
                                  encoding='utf-8') as patch_file:
                            patch_file.write(diff)
                    else:
                        # Non-Linux: only copy the index contents; cls.commit is
                        # not recorded and no diff.patch is written.
                        check_output([
                            'git', 'checkout-index', '-a',
                            '--prefix={}/src/'.format(log_dir)
                        ])
                    break
                except Exception as e:
                    last_error = e
                    print(e)
                    print('Try again...')
                time.sleep(1)
            else:
                raise RuntimeError(
                    'Failed after {} trials.'.format(n_trials)) from last_error

        config = cls.as_dict()
        with open(os.path.join(log_dir, 'config.yml'), 'w') as config_file:
            yaml.dump(config, config_file, default_flow_style=False)
        # logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
        logger.add_sink(FileSink(os.path.join(log_dir, 'log.txt')))
        logger.add_csvwriter(CSVWriter(os.path.join(log_dir, 'progress.csv')))
        logger.info("log_dir = %s", log_dir)

        cls.set_frozen()
Exemplo n.º 2
0
    def finalize(cls):
        """Resolve the run's log directory, snapshot the source tree, and freeze flags.

        If ``cls.log_dir`` is unset it becomes ``<cls.ckpt.base>/<run_id>``,
        where ``run_id`` defaults to the current local time
        (``%Y-%m-%d_%H-%M-%S``). The result is stored back on ``cls.log_dir``.
        The directory is then created if needed. The current commit hash is
        stored in ``cls.commit``, and the working tree is staged and copied into
        ``<log_dir>/src/``. ``cls.as_dict()`` is written to
        ``<log_dir>/config.yml`` and ``git diff HEAD`` to
        ``<log_dir>/diff.patch``. Finally a JSON log sink is attached to
        ``logger`` and the flags are frozen via ``cls.set_frozen()``.

        Raises:
            RuntimeError: if the git snapshot still fails after 60 attempts
                (one second apart); the last ``CalledProcessError`` is chained
                as ``__cause__``.
            CalledProcessError: if ``git diff HEAD`` fails (it is not retried).
        """
        log_dir = cls.log_dir
        if log_dir is None:
            run_id = cls.run_id
            if run_id is None:
                run_id = time.strftime('%Y-%m-%d_%H-%M-%S')

            log_dir = os.path.join(cls.ckpt.base, run_id)
            cls.log_dir = log_dir

        # exist_ok avoids the exists()/makedirs() race when runs start concurrently.
        os.makedirs(log_dir, exist_ok=True)

        n_trials = 60
        last_error = None
        # Retries are presumably meant to ride out transient git failures
        # (e.g. another run holding .git/index.lock) -- TODO confirm.
        for _ in range(n_trials):
            try:
                cls.commit = check_output(['git', 'rev-parse',
                                           'HEAD']).decode('utf-8').strip()
                # checkout-index copies from the index, so stage the working tree
                # first to include modified and new files.
                # NOTE: this mutates the repository's index as a side effect.
                check_output(['git', 'add', '.'])
                check_output([
                    'git', 'checkout-index', '-a', '-f',
                    f'--prefix={log_dir}/src/'
                ])
                break
            except CalledProcessError as e:
                # Keep the failure visible instead of silently discarding it.
                last_error = e
                print(e)
                print('Try again...')
            time.sleep(1)
        else:
            raise RuntimeError(f'Failed after {n_trials} trials.') from last_error

        config = cls.as_dict()
        with open(os.path.join(log_dir, 'config.yml'), 'w') as config_file:
            yaml.dump(config, config_file, default_flow_style=False)

        # Run git before opening the file so a failure does not leave behind a
        # truncated diff.patch; write with the same encoding used to decode it.
        diff = check_output(['git', '--no-pager', 'diff',
                             'HEAD']).decode('utf-8')
        with open(os.path.join(log_dir, 'diff.patch'), 'w',
                  encoding='utf-8') as patch_file:
            patch_file.write(diff)

        logger.add_sink(FileSink(os.path.join(log_dir, 'log.json')))
        logger.info("log_dir = %s", log_dir)

        cls.set_frozen()