Beispiel #1
0
    # import the corresponding configuration, neural networks and envs
    if args.case == 'classic_control':
        from config.classic_control import run_config
    elif args.case == 'box2d':
        from config.box2d import run_config
    elif args.case == 'mujoco':
        from config.mujoco import run_config
    elif args.case == 'dm_control':
        from config.dm_control import run_config
    else:
        raise Exception('Invalid --case option.')

    # set config as per arguments
    run_config.set_config(args)
    log_base_path = make_results_dir(run_config.exp_path, args)
    if args.use_wandb:
        os.makedirs(args.wandb_dir, exist_ok=True)

    # set-up logger
    init_logger(log_base_path)
    logging.getLogger('root').info('cmd args:{}'.format(' '.join(
        sys.argv[1:])))  # log command line arguments.

    try:
        if args.opr == 'train':
            if args.use_wandb:
                os.makedirs(args.wandb_dir, exist_ok=True)
                os.environ['WANDB_DIR'] = str(args.wandb_dir)
                import wandb
Beispiel #2
0
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # import the corresponding configuration, neural networks and envs
    if args.case == 'atari':
        from config.atari import muzero_config
    elif args.case == 'box2d':
        from config.classic_control import muzero_config  # just using same config as classic_control for now
    elif args.case == 'classic_control':
        from config.classic_control import muzero_config
    else:
        raise Exception('Invalid --case option')

    # set config as per arguments
    exp_path = muzero_config.set_config(args)
    exp_path, log_base_path = make_results_dir(exp_path, args)

    # set-up logger
    init_logger(log_base_path)

    try:
        if args.opr == 'train':
            summary_writer = SummaryWriter(exp_path, flush_secs=10)
            train(muzero_config, summary_writer)

        elif args.opr == 'test':
            assert os.path.exists(muzero_config.model_path), 'model not found at {}'.format(muzero_config.model_path)
            model = muzero_config.get_uniform_network().to('cpu')
            model.load_state_dict(torch.load(muzero_config.model_path, map_location=torch.device('cpu')))
            test_score = test(muzero_config, model, args.test_episodes, device='cpu', render=args.render,
                              save_video=True)
Beispiel #3
0
    parser.add_argument(
        '--log_suffix',
        type=str,
        default='',
        help=
        'Log Suffix Attached to the resulting directory (default: %(default)s)'
    )

    # Process arguments
    args = parser.parse_args()
    args.device = 'cuda' if (
        not args.no_cuda) and torch.cuda.is_available() else 'cpu'

    # create relative paths to store results
    exp_path = make_results_dir(
        os.path.join(args.result_dir, args.env, 'seed_{}'.format(args.seed)),
        args)

    # set-up logger and tensorboard
    init_logger(os.path.join(exp_path, args.opr + '.log'))
    logger = logging.getLogger()
    summary_writer = SummaryWriter(exp_path, flush_secs=10)

    # seed the NumPy and PyTorch random number generators
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    try:
        # import the corresponding configuration, neural networks and envs
        if args.case == 'atari':
            from config.atari import muzero_config