コード例 #1
0
def from_fname(fname):
    """Build an ``mdp.Domain`` (with its model attached) from a parsed file.

    The file is opened with ``_open``, parsed with ``parse_mdp`` and the
    parsed start / transition / reward arrays are stored on freshly created
    distribution objects, which are bundled into ``domain.model``.

    Args:
        fname: name handed to ``_open``.

    Returns:
        The populated ``mdp.Domain``.

    Raises:
        ValueError: if the file declares ``cost`` values, or if any reward
            entry differs from the mean taken along the last axis of ``R``.
    """
    with _open(fname) as handle:
        parsed = parse_mdp(handle.read())

    if parsed.values == 'cost':
        raise ValueError('I do not know how to handle `cost` values.')

    # TODO I think this should not be mean but something else..
    last_axis_mean = parsed.R.mean(axis=-1, keepdims=True)
    if np.any(last_axis_mean != parsed.R):
        raise ValueError(
            'I cannot handle rewards which depend on observations.')

    sspace = indextools.DomainSpace(parsed.states)
    aspace = indextools.DomainSpace(parsed.actions)
    domain = mdp.Domain(sspace, aspace, gamma=parsed.discount)

    # Fall back to a uniform distribution over states when the file gives
    # no explicit start distribution.
    if parsed.start is not None:
        initial = parsed.start
    else:
        initial = np.ones(sspace.nelems) / sspace.nelems

    # Both arrays get their first two axes exchanged.  The implicit-mode
    # einsum 'jik' orders its output labels alphabetically (ijk), which for
    # a three-axis array is the same axis swap as the swapaxes call.
    transitions = np.swapaxes(parsed.T, 0, 1)
    rewards = np.einsum('jik', parsed.R)

    s0model = mdp.State0Distribution(domain)
    s1model = mdp.State1Distribution(domain)
    rmodel = mdp.RewardDistribution(domain)

    s0model.array = initial
    s1model.array = transitions
    rmodel.array = rewards

    domain.model = mdp.Model(s0model, s1model, rmodel)
    return domain
コード例 #2
0
    def setUp(self):
        """Create four index spaces and the Softmax model under test.

        ``s1`` and ``s2`` are passed to ``Softmax`` positionally; ``s3`` and
        ``s4`` are passed together as its ``cond`` tuple.
        """
        self.s1, self.s2, self.s3, self.s4 = (
            indextools.BoolSpace(),
            indextools.DomainSpace('abc'),
            indextools.RangeSpace(4),
            indextools.DomainSpace('defgh'),
        )

        conditions = (self.s3, self.s4)
        self.model = Softmax(self.s1, self.s2, cond=conditions)
コード例 #3
0
    def from_fname(env, fname):
        """Load a controller from an ``fsc`` resource and wrap it in ``FSC_File``.

        Args:
            env: passed through unchanged to ``FSC_File``.
            fname: resource name handed to ``data.resource_open``.

        Returns:
            ``FSC_File(env, node_space, start_node, A, T)`` where
            ``start_node`` is the argmax of the parsed start array and ``T``
            is the parsed transition array with its first two axes swapped.
        """
        with data.resource_open(fname, 'fsc') as stream:
            parsed = parse_fsc(stream.read())

        # TODO(original author): "stadt.." -- meaning unclear; possibly about
        # the start array being collapsed to a single node below.  Confirm.
        node_space = indextools.DomainSpace(parsed.nodes)
        start_node = parsed.start.argmax()
        action_table = parsed.A
        transition_table = np.swapaxes(parsed.T, 0, 1)
        return FSC_File(env, node_space, start_node, action_table,
                        transition_table)
コード例 #4
0
def Gridworld():
    """Assemble the gridworld ``mdp.Environment`` and attach its model.

    States are ``State((i, j))`` for every ``i, j`` in ``range(n)``; actions
    are the four compass directions.  ``gamma`` is fixed to 1 and ``n`` is
    stored on the environment.

    NOTE(review): ``n`` is not defined inside this function, so it must be
    resolved from an enclosing/module scope -- confirm where it comes from.

    Returns:
        The environment, with ``env.model`` set.
    """
    # TODO avoid naive way!  or maybe not?  I want positions are tuples!
    grid_states = [
        State((row, col)) for row in range(n) for col in range(n)
    ]
    sspace = indextools.DomainSpace(grid_states)
    # No custom ``istr`` is installed on either space; per the original
    # note, the string form of the state values is relied on instead.

    aspace = indextools.DomainSpace(('north', 'south', 'east', 'west'))

    env = mdp.Environment(sspace, aspace)
    env.gamma = 1
    env.n = n

    # Each sub-model is constructed from the environment, in this order.
    initial_model = Gridworld_S0Model(env)
    transition_model = Gridworld_S1Model(env)
    reward_model = Gridworld_RModel(env)
    env.model = mdp.Model(env, initial_model, transition_model, reward_model)

    return env
コード例 #5
0
    def from_fname(env, fname):
        """Load a structured controller from an ``fss`` resource.

        Node labels are used as-is when the first parsed node is a string;
        otherwise every node ``n`` is relabelled ``'node_<n>'``.

        Args:
            env: passed through unchanged to ``FSC_Structured``.
            fname: resource name handed to ``data.resource_open``.

        Returns:
            ``FSC_Structured(env, node_space, start, A, N)`` built from the
            parsed file.
        """
        with data.resource_open(fname, 'fss') as stream:
            parsed = parse_fss(stream.read())

        # TODO check that all actions are used
        # TODO check that structure fully connected
        # TODO check that the actions are the same

        already_named = isinstance(parsed.nodes[0], str)
        labels = (parsed.nodes if already_named else tuple(
            f'node_{node}' for node in parsed.nodes))
        node_space = indextools.DomainSpace(labels)

        return FSC_Structured(env, node_space, parsed.start, parsed.A,
                              parsed.N)
コード例 #6
0
def generate_pomdp(g_pos=7, num_plate=4, gamma=0.95):
    """Generate the "plates" POMDP description, print it and write it to disk.

    The text is assembled line by line in ``pomdp_strs``: a comment header,
    the preamble (discount, values, states, actions, observations), the
    start states, then the ``T:`` (transition), ``O:`` (observation) and
    ``R:`` (reward) sections.  After each of those three sections is built,
    a comment with its line count is inserted back into the header.

    The result is printed to stdout and written to
    ``'../../../pomdps/plates.pomdp'`` (relative to the current working
    directory).

    Relies on names defined outside this function: ``is_start_state``,
    ``is_state_valid``, ``is_obs_valid``, ``all_states_fmt``,
    ``all_obs_fmt``, ``state_fmt``, ``action_fmt``, ``obs_fmt`` and ``copy``.

    Args:
        g_pos: size of the ``g_pos`` range in the state and observation
            spaces; must be greater than ``num_plate``.
        num_plate: upper end of the ``p_pos`` range (which starts at 1);
            must be positive.
        gamma: discount written to the file; must satisfy ``0 < gamma <= 1``.

    Returns:
        None.

    Raises:
        AssertionError: if the argument checks fail (note these checks are
            skipped when Python runs with ``-O``).
    """
    assert (g_pos > num_plate)
    assert (num_plate > 0)
    assert (0 < gamma <= 1)

    # SPACES
    angle = ['D', 'N', 'U']
    angle_space = indextools.DomainSpace(angle)

    state_space = indextools.JointNamedSpace(
        g_pos=indextools.RangeSpace(g_pos),
        theta=angle_space,
        p_pos=indextools.RangeSpace(
            1, num_plate +
            1)  # Plate starts from index 1 to num_plate inclusively.
    )
    # Only a filtered subset of the joint space is used; the predicates are
    # defined elsewhere in the module.
    start_states = [s for s in state_space.elems if is_start_state(s)]
    states = [s for s in state_space.elems if is_state_valid(s)]

    action_list = ['U', 'D', 'G']
    action_space = indextools.DomainSpace(action_list)
    actions = list(action_space.elems)

    # Observations carry g_pos and theta only (no p_pos component).
    obs_space = indextools.JointNamedSpace(g_pos=indextools.RangeSpace(g_pos),
                                           theta=angle_space)
    observations = [o for o in obs_space.elems if is_obs_valid(o, num_plate)]

    # PREAMBLE
    pomdp_strs = []

    pomdp_strs.append(f'# Plates Env POMDP\n')
    pomdp_strs.append(
        f'# This specific file was generated with the following parameters:\n')
    pomdp_strs.append(f'# g_pos: {g_pos}\n')
    pomdp_strs.append(f'# num_plate: {num_plate}\n')
    pomdp_strs.append(f'# gamma: {gamma}\n')
    pomdp_strs.append(f'# angle: {angle}\n')
    pomdp_strs.append(f'# Number of all actions: {len(actions)}\n')
    pomdp_strs.append(f'# Number of all states: {len(states)}\n')
    pomdp_strs.append(f'# Number of all observations: {len(observations)}\n')
    header_end_idx = len(
        pomdp_strs)  # Inserting position for the comments generated later.
    pomdp_strs.append('\n')

    pomdp_strs.append(f'discount: {gamma}\n')
    pomdp_strs.append('values: reward\n')
    pomdp_strs.append(f'states: {all_states_fmt(states)}\n')
    pomdp_strs.append(f'actions: {" ".join(action_fmt(a) for a in actions)}\n')
    pomdp_strs.append(f'observations: {all_obs_fmt(observations)}\n')

    # START
    pomdp_strs.append('\n')
    pomdp_strs.append(
        f'start include: {" ".join(state_fmt(s) for s in start_states)}\n')

    # TRANSITIONS
    pomdp_strs.append('\n')
    block_start_idx = len(
        pomdp_strs)  # For counting the state transition probabilities.
    for a in actions:
        # If grasps then resets.
        # A single wildcard line covers every state for 'G'.
        if a.value == 'G':
            pomdp_strs.append(f'T: {action_fmt(a)} : * reset\n')
            continue

        for s in states:
            if a.value == 'U':
                # Can still move up.
                if s.g_pos.value < g_pos - 1:
                    # Work on a copy so the source state is what gets
                    # formatted as the "from" state below.
                    s_next = copy(s)
                    s_next.g_pos.value += 1

                    # s_next.g_pos between [1, the top plate].
                    if s_next.g_pos.value <= s.p_pos.value:
                        s_next.theta.value = 'D'

                    # s_next.g_pos > the top plate.
                    else:
                        s_next.theta.value = 'N'

                    pomdp_strs.append(
                        f'T: {action_fmt(a)} : {state_fmt(s)} : {state_fmt(s_next)} 1.0\n'
                    )

                # Cannot move up then resets.
                else:
                    pomdp_strs.append(
                        f'T: {action_fmt(a)} : {state_fmt(s)} reset\n')

            elif a.value == 'D':
                # Can still move down.
                if s.g_pos.value > 0:
                    s_next = copy(s)
                    s_next.g_pos.value -= 1

                    # At s_next.g_pos = 0, the angle is neutral.
                    if s_next.g_pos.value == 0:
                        s_next.theta.value = 'N'

                    # s_next.g_pos between [1, the top plate].
                    elif s_next.g_pos.value <= s.p_pos.value:
                        s_next.theta.value = 'U'

                    # s_next.g_pos > the top plate.
                    else:
                        s_next.theta.value = 'N'

                    pomdp_strs.append(
                        f'T: {action_fmt(a)} : {state_fmt(s)} : {state_fmt(s_next)} 1.0\n'
                    )

                # Cannot move down then resets.
                else:
                    pomdp_strs.append(
                        f'T: {action_fmt(a)} : {state_fmt(s)} reset\n')

    # Number of lines emitted in the T: section (the 'G' wildcard line and
    # the reset lines are counted too), reported in the header.
    block_size = len(pomdp_strs) - block_start_idx
    pomdp_strs.insert(
        header_end_idx,
        f'# Number of state transition probabilities: {block_size}\n')

    # OBSERVATIONS
    pomdp_strs.append('\n')
    # Taken after the header insert above, so the shifted indices are
    # already accounted for.
    block_start_idx = len(
        pomdp_strs)  # For counting the observation probabilities.
    # Default every observation probability to 0; the matching ones are
    # overridden with 1.0 below.  This default line is included in the count.
    pomdp_strs.append("O: * : * : * 0.0\n")
    for a in actions:
        for s in states:
            for o in observations:
                # An observation matches a state when g_pos and theta agree.
                if o.g_pos.value == s.g_pos.value and o.theta.value == s.theta.value:
                    pomdp_strs.append(
                        f'O: {action_fmt(a)} : {state_fmt(s)} : {obs_fmt(o)} 1.0\n'
                    )

    # Advance the insert position so this comment lands after the
    # previously inserted one.
    header_end_idx += 1
    block_size = len(pomdp_strs) - block_start_idx
    pomdp_strs.insert(
        header_end_idx,
        f'# Number of observation probabilities: {block_size}\n')

    # REWARDS
    pomdp_strs.append('\n')
    block_start_idx = len(pomdp_strs)  # For counting the immediate rewards.
    for a in actions:
        # Only 'G' is rewarded, and only in states where g_pos equals p_pos.
        if a.value == 'G':
            for s in states:
                if s.g_pos.value == s.p_pos.value:
                    pomdp_strs.append(
                        f'R: {action_fmt(a)} : {state_fmt(s)} : * : * 1.0\n')

    header_end_idx += 1
    block_size = len(pomdp_strs) - block_start_idx
    pomdp_strs.insert(header_end_idx,
                      f'# Number of immediate rewards: {block_size}\n')

    # Print.
    pomdp_text = "".join(pomdp_strs)
    print(pomdp_text)

    # Writes to a pomdp file.
    # NOTE(review): the path is hard-coded and relative to the current
    # working directory, not to this source file.
    with open('../../../pomdps/plates.pomdp', 'w') as f:
        f.write(pomdp_text)
コード例 #7
0
    return o.value


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Shopping')
    parser.add_argument('n', type=int, default=None)
    # parser.add_argument('--episodic', action='store_true')
    parser.add_argument('--gamma', type=float, default=0.99)
    config = parser.parse_args()

    assert config.n >= 1
    assert 0 < config.gamma <= 1

    ncells = 2 + 4 * config.n
    cell_space = indextools.RangeSpace(ncells)
    heaven_space = indextools.DomainSpace(['left', 'right'])

    state_space = indextools.JointNamedSpace(
        heaven=heaven_space, cell=cell_space
    )

    actions = ['N', 'S', 'E', 'W']
    action_space = indextools.DomainSpace(actions)

    obs = [f'o{i}' for i in range(cell_space.nelems - 1)] + ['left', 'right']
    obs_space = indextools.DomainSpace(obs)

    print(
        """# A robot will be rewarded +1 for attaining heaven in one
# if it accidently reaches hell it will get -1
# Problem is attributed to Sebastian Thrun but first appeared in Geffner
コード例 #8
0
def main():
    """Generate the "books" POMDP description, print it and save it.

    Command-line options: ``--g-pos`` (default 20), ``--theta`` (default 3),
    ``--book`` (default 16) and ``--gamma`` (default 0.95).  ``--theta`` is
    only validated here; nothing else in this function reads it.

    The generated text is printed to stdout and written to ``books.pomdp``
    in the current working directory.  ``is_start_state``,
    ``is_state_valid``, ``is_obs_valid``, ``state_fmt``, ``action_fmt``,
    ``obs_fmt`` and ``copy`` are taken from the enclosing module.

    Raises:
        AssertionError: if the parsed options fail the checks below (these
            checks are skipped when Python runs with ``-O``).
    """
    cli = argparse.ArgumentParser(description='Plate')
    for flag, kind, fallback in (
        ('--g-pos', int, 20),
        ('--theta', int, 3),
        ('--book', int, 16),
        ('--gamma', float, 0.95),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    config = cli.parse_args()

    assert config.g_pos > config.book
    assert config.theta > 0
    assert config.book > 0
    assert 0 < config.gamma <= 1

    # SPACES
    angle_space = indextools.DomainSpace(['XL', 'L', 'N', 'R', 'XR'])

    state_space = indextools.JointNamedSpace(
        g_pos=indextools.RangeSpace(config.g_pos),
        theta=angle_space,
        # Book positions run from 1 to config.book.
        b_pos=indextools.RangeSpace(1, config.book + 1),
    )
    start_states = [
        elem for elem in state_space.elems if is_start_state(elem)
    ]
    states = [
        elem for elem in state_space.elems
        if is_state_valid(elem, config.book)
    ]

    action_space = indextools.DomainSpace(['Right', 'Left', 'Grasp'])
    actions = list(action_space.elems)

    # Observations carry g_pos and theta only (no b_pos component).
    obs_space = indextools.JointNamedSpace(
        g_pos=indextools.RangeSpace(config.g_pos),
        theta=angle_space,
    )
    observations = [
        elem for elem in obs_space.elems if is_obs_valid(elem, config.book)
    ]

    # PREAMBLE
    out = []
    emit = out.append

    out.extend([
        '# This specific file was generated with parameters:\n',
        f'# {config}\n',
        '\n',
        f'discount: {config.gamma}\n',
        'values: reward\n',
    ])

    states_txt = " ".join(state_fmt(s) for s in states)
    emit(f'states: {states_txt}\n')
    actions_txt = " ".join(action_fmt(a) for a in actions)
    emit(f'actions: {actions_txt}\n')
    observations_txt = " ".join(obs_fmt(o) for o in observations)
    emit(f'observations: {observations_txt}\n')

    # START
    emit('\n')
    start_txt = " ".join(state_fmt(s) for s in start_states)
    emit(f'start include: {start_txt}\n')

    # TRANSITIONS
    emit('\n')
    for a in actions:
        verb = a.value

        # Grasping resets from every state: one wildcard line suffices.
        if verb == 'Grasp':
            emit(f'T: {action_fmt(a)} : * reset\n')
            continue

        for s in states:
            if verb == 'Right':
                # Already at the rightmost position: reset.
                if not s.g_pos.value < config.g_pos - 1:
                    emit(f'T: {action_fmt(a)} : {state_fmt(s)} reset\n')
                    continue

                nxt = copy(s)
                nxt.g_pos.value += 1
                reached = nxt.g_pos.value

                if reached > config.book:
                    # Past the rightmost book position.
                    nxt.theta.value = 'N'
                elif reached == s.b_pos.value:
                    # Landed exactly on the book.
                    nxt.theta.value = 'XL'
                else:
                    nxt.theta.value = 'L'

                emit(
                    f'T: {action_fmt(a)} : {state_fmt(s)} : {state_fmt(nxt)} 1.0\n'
                )

            elif verb == 'Left':
                # Already at position 0: reset.
                if not s.g_pos.value > 0:
                    emit(f'T: {action_fmt(a)} : {state_fmt(s)} reset\n')
                    continue

                nxt = copy(s)
                nxt.g_pos.value -= 1
                reached = nxt.g_pos.value

                if reached == 0 or reached > config.book:
                    # Position 0, or past the rightmost book position.
                    nxt.theta.value = 'N'
                elif reached == s.b_pos.value:
                    # Landed exactly on the book.
                    nxt.theta.value = 'XR'
                else:
                    nxt.theta.value = 'R'

                emit(
                    f'T: {action_fmt(a)} : {state_fmt(s)} : {state_fmt(nxt)} 1.0\n'
                )

    # OBSERVATIONS
    emit('\n')
    # Default everything to 0, then override the matching pairs with 1.0.
    emit("O: * : * : * 0.0\n")
    for a in actions:
        for s in states:
            for o in observations:
                # An observation matches when g_pos and theta both agree.
                if (o.g_pos.value != s.g_pos.value
                        or o.theta.value != s.theta.value):
                    continue
                emit(
                    f'O: {action_fmt(a)} : {state_fmt(s)} : {obs_fmt(o)} 1.0\n'
                )

    # REWARDS
    emit('\n')
    for a in actions:
        # Only grasping is rewarded, and only where g_pos equals b_pos.
        if a.value != 'Grasp':
            continue
        for s in states:
            if s.g_pos.value == s.b_pos.value:
                emit(f'R: {action_fmt(a)} : {state_fmt(s)} : * : * 1.0\n')

    pomdp_text = "".join(out)
    print(pomdp_text)

    # Writes to a pomdp file.
    with open('books.pomdp', 'w') as sink:
        sink.write(pomdp_text)
コード例 #9
0
ファイル: bumps1d.v0.py プロジェクト: zhihanyang2022/drqn
def main():
    """Print the bumps-1d POMDP description to stdout.

    Command-line options: ``--cart`` (default 10, the largest cart position)
    and ``--gamma`` (default 0.95).  The output consists of a comment
    header, the preamble, the start states and the ``T:`` / ``O:`` / ``R:``
    sections.

    ``check_valid``, ``sfmt``, ``afmt``, ``ofmt`` and ``copy`` are taken
    from the enclosing module.
    """
    parser = argparse.ArgumentParser(description='Plate')
    parser.add_argument('--cart', type=int, default=10)
    parser.add_argument('--gamma', type=float, default=0.95)
    config = parser.parse_args()

    angle_space = indextools.DomainSpace(['L', 'Z', 'R'])

    cart_pos_min = 0
    cart_pos_max = config.cart

    min_bump_distance = 2
    max_bump_distance = int(cart_pos_max / 2)

    pos_space = indextools.JointNamedSpace(
        cart=indextools.RangeSpace(cart_pos_min, cart_pos_max + 1),
        angle=angle_space,
        lbump=indextools.RangeSpace(cart_pos_min + 1, cart_pos_max),
        rbump=indextools.RangeSpace(cart_pos_min + 1, cart_pos_max),
    )
    state_space = indextools.JointNamedSpace(pos=pos_space)

    def check_state(s):
        # Validity is delegated to the module-level predicate.
        return check_valid(s, min_bump_distance, max_bump_distance)

    def valid_states():
        # Walk the state space afresh on every call, keeping valid states.
        return (s for s in state_space.elems if check_state(s))

    all_states = list(valid_states())
    start_states = [
        s for s in state_space.elems
        if s.pos.angle.value == 'Z' and check_state(s)
    ]

    actions = ['LS', 'LH', 'RS', 'RH']
    action_space = indextools.DomainSpace(actions)

    # Observations carry the cart position and angle only (no bumps).
    obs_pos_space = indextools.JointNamedSpace(
        cart=indextools.RangeSpace(cart_pos_min, cart_pos_max + 1),
        angle=angle_space,
    )
    obs_space = indextools.JointNamedSpace(pos=obs_pos_space)
    all_obs = list(obs_space.elems)

    print('# This specific file was generated with parameters:')
    print(f'# {config}')
    print(f'# observations: {len(all_obs)}')
    print(f'# actions: {len(actions)}')
    print(f'# states: {len(all_states)}')
    print()
    print(f'discount: {config.gamma}')
    print('values: reward')

    print(f'states: {" ".join(sfmt(s) for s in valid_states())}')
    print(f'actions: {" ".join(afmt(a) for a in action_space.elems)}')
    print(f'observations: {" ".join(ofmt(o) for o in obs_space.elems)}')

    # START
    print()
    print(f'start include: {" ".join(sfmt(s) for s in start_states)}')

    # TRANSITIONS
    print()

    # action value -> (cart step, hard move?)
    moves = {
        'LS': (-1, False),
        'LH': (-1, True),
        'RS': (+1, False),
        'RH': (+1, True),
    }

    for a in action_space.elems:
        move = moves.get(a.value)
        if move is None:
            continue
        step, hard = move
        leftward = step < 0
        # Angle label used when a bump is involved: 'R' for leftward moves,
        # 'L' for rightward ones.
        lean = 'R' if leftward else 'L'
        # Angles that make a hard move reset when a bump is one cell ahead.
        ahead_angles = ('Z', 'L') if leftward else ('Z', 'R')

        for s in valid_states():
            nxt = copy(s)

            if leftward:
                movable = s.pos.cart.value >= cart_pos_min + 1
            else:
                movable = s.pos.cart.value <= cart_pos_max - 1

            # At the edge of the track: reset.
            if not movable:
                print(f'T: {afmt(a)}: {sfmt(s)} reset')
                continue

            nxt.pos.cart.value += step

            cart = s.pos.cart.value
            bumps = (s.pos.rbump.value, s.pos.lbump.value)

            if hard:
                # Reset when a bump is at the cart's cell while the angle is
                # `lean`, or one cell ahead while the angle is in
                # `ahead_angles`; otherwise the cart moves and the angle is
                # left as copied.
                blocked = ((cart in bumps and s.pos.angle.value == lean)
                           or (cart + step in bumps
                               and s.pos.angle.value in ahead_angles))
                if blocked:
                    print(f'T: {afmt(a)}: {sfmt(s)} reset')
                else:
                    print(f'T: {afmt(a)}: {sfmt(s)}: {sfmt(nxt)} 1.0')
            else:
                # Soft move: the resulting angle is `lean` when a bump sits
                # one cell ahead, and 'Z' otherwise.
                nxt.pos.angle.value = lean if cart + step in bumps else 'Z'
                print(f'T: {afmt(a)}: {sfmt(s)}: {sfmt(nxt)} 1.0')

    # OBSERVATIONS
    print()
    for a in action_space.elems:
        for s in valid_states():
            for o in obs_space.elems:
                # An observation matches when cart and angle both agree.
                if (o.pos.cart.value != s.pos.cart.value
                        or o.pos.angle.value != s.pos.angle.value):
                    continue
                print(f'O: {afmt(a)}: {sfmt(s)}: {ofmt(o)} 1.0')

    # REWARDS
    print()
    for a in action_space.elems:
        # Only 'RH' is rewarded, and only relative to the right bump.
        if a.value != 'RH':
            continue
        for s in valid_states():
            cart = s.pos.cart.value
            facing = s.pos.angle.value

            # Right bump at the cart's cell with angle 'L'.
            if cart == s.pos.rbump.value and facing == 'L':
                print(f'R: {afmt(a)}: {sfmt(s)}: *: * 1.0')

            # Right bump one cell to the right with angle 'Z' or 'R'.
            if cart + 1 == s.pos.rbump.value and facing in ('Z', 'R'):
                print(f'R: {afmt(a)}: {sfmt(s)}: *: * 1.0')