예제 #1
0
def load_ac_agent(h5file):
    """Rebuild an ``ACAgent`` from an HDF5 file or group.

    The network is read from the ``model`` group.  The encoder is looked up
    by the ``name`` and ``board_sz`` attributes of the ``encoder`` group.

    Args:
        h5file: an open h5py file/group containing ``model`` and ``encoder``.

    Returns:
        An ``ACAgent`` wrapping the loaded network and encoder.
    """
    network = load_model_from_hdf5_group(h5file['model'])
    encoder_attrs = h5file['encoder'].attrs
    raw_name = encoder_attrs['name']
    # The stored name may come back as bytes rather than str; normalise it
    # before the by-name encoder lookup.
    name = raw_name if isinstance(raw_name, str) else raw_name.decode('ascii')
    encoder = get_encoder_by_name(name, encoder_attrs['board_sz'])
    return ACAgent(network, encoder)
예제 #2
0
파일: zero.py 프로젝트: armandli/rlgames
def load_zero_agent(h5file, eval_rounds=1600, exploration_factor=2.):
    """Rebuild a ``ZeroAgent`` from an HDF5 file or group.

    Args:
        h5file: an open h5py file/group containing a ``model`` group and an
            ``encoder`` group whose attributes hold ``name`` and ``board_sz``.
        eval_rounds: forwarded unchanged to ``ZeroAgent``.
        exploration_factor: forwarded unchanged to ``ZeroAgent``.

    Returns:
        A ``ZeroAgent`` built from the stored network and encoder.
    """
    network = load_model_from_hdf5_group(h5file['model'])
    encoder_attrs = h5file['encoder'].attrs
    raw_name = encoder_attrs['name']
    # Accept either str or bytes for the stored encoder name.
    if isinstance(raw_name, str):
        name = raw_name
    else:
        name = raw_name.decode('ascii')
    encoder = get_encoder_by_name(name, encoder_attrs['board_sz'])
    return ZeroAgent(network, encoder, eval_rounds, exploration_factor)
예제 #3
0
def main():
    """Convert a model file from width/height board metadata to ``board_sz``.

    Reads the model and encoder metadata from ``args.path`` and writes them
    to ``args.out``. The ``board_width``/``board_height`` encoder attributes
    are replaced by a single ``board_sz`` attribute.

    Raises:
        ValueError: if the input file does not exist, or if the stored board
            is not square. A single ``board_sz`` cannot describe a
            non-square board.
    """
    args = parse_args()
    if not os.path.isfile(args.path):
        raise ValueError('File {} does not exist!'.format(args.path))
    with h5py.File(args.path, 'r') as h5file:
        encoder_attrs = h5file['encoder'].attrs
        encoder_name = encoder_attrs['name']
        print(encoder_name)
        board_width = encoder_attrs['board_width']
        board_height = encoder_attrs['board_height']
        print('board width {} and height {}'.format(board_width, board_height))
        # board_sz can only represent a square board. Validate before
        # touching args.out: opening it with 'w' would truncate or create a
        # file that would be left incomplete.
        if board_width != board_height:
            raise ValueError(
                'Cannot convert non-square board ({} x {}) to board_sz'.format(
                    board_width, board_height))
        model = load_model_from_hdf5_group(h5file['model'])
        with h5py.File(args.out, 'w') as outfile:
            encoder_group = outfile.create_group('encoder')
            encoder_group.attrs['name'] = encoder_name
            encoder_group.attrs['board_sz'] = board_width
            model_group = outfile.create_group('model')
            save_model_to_hdf5_group(model, model_group)
    print('Model saved to {}'.format(args.out))
예제 #4
0
def load_policy_agent(h5file):
    """Rebuild a ``PolicyAgent`` from an HDF5 file or group.

    Args:
        h5file: an open h5py file/group containing a ``model`` group and an
            ``encoder`` group whose attributes hold ``name`` and ``board_sz``.

    Returns:
        A ``PolicyAgent`` wrapping the loaded network and encoder.
    """
    model = load_model_from_hdf5_group(h5file['model'])
    encoder_name = h5file['encoder'].attrs['name']
    # The stored name may be read back as bytes. Decode it as the sibling
    # loaders do, so the by-name encoder lookup receives a str.
    if not isinstance(encoder_name, str):
        encoder_name = encoder_name.decode('ascii')
    board_sz = h5file['encoder'].attrs['board_sz']
    encoder = get_encoder_by_name(encoder_name, board_sz)
    return PolicyAgent(model, encoder)