Exemple #1
0
def data_generator(rng):
    """Yield supervised training samples forever from auto-played episodes.

    An ``Env`` is created with the seed ``rng.randint(1 << 31)`` and replayed
    episode after episode (``reset`` + ``prepare``).  Each ``env.step_auto()``
    call is converted into one or more 11-tuples laid out as::

        (state, last_out_cards,
         passive_decision, passive_bomb, passive_response,
         active_decision, active_response, seq_length,
         minor_response, minor_type, mode)

    ``mode`` tells which target slot carries the label for that sample; the
    remaining target slots are 0:

    ====  ================  ==============================================
    mode  labelled slot     label
    ====  ================  ==============================================
    0     passive_decision  1 = QUADRIC played on a non-QUADRIC category,
                            2 = BIGBANG, 0 = EMPTY, 3 = anything else
    1     passive_bomb      ``intention[0] - 3``
    2     passive_response  ``intention[0] - 3`` (clamped to be >= 0)
    3     active_decision   ``category_idx - 1``
    4     active_response   ``intention[0] - 3``
    5     seq_length        ``seq_length - 1`` (only when
                            ``get_seq_length`` returns a value)
    6     minor_response    ``Card.char2value_3_17(target) - 3`` with
                            ``minor_type`` 1 for THREE_TWO /
                            THREE_TWO_LINE and 0 otherwise
    ====  ================  ==============================================

    Args:
        rng: Random source exposing ``randint(high)``; only used once, to seed
            the environment.

    Yields:
        tuple: One 11-element sample as described above.  This generator never
        terminates on its own.
    """
    env = Env(rng.randint(1 << 31))

    while True:
        env.reset()
        env.prepare()
        r = 0
        # An episode lasts for as long as step_auto() keeps returning r == 0.
        while r == 0:
            # Capture the position *before* the move is played.
            last_cards_value = env.get_last_outcards()
            last_out_cards = Card.val2onehot60(last_cards_value)
            last_category_idx = env.get_last_outcategory_idx()
            curr_cards_char = to_char(env.get_curr_handcards())
            # Active mode is used when there are no last-played cards.
            is_active = last_cards_value.size == 0

            s = env.get_state_prob()
            intention, r, category_idx = env.step_auto()

            # NOTE(review): category 14 is skipped entirely; the meaning of
            # this value is defined outside this block - confirm.
            if category_idx == 14:
                continue
            minor_cards_targets = pick_minor_targets(category_idx,
                                                     to_char(intention))

            if is_active:
                seq_length = get_seq_length(category_idx, intention)

                # Active decision labels are offset by one.
                active_decision_input = category_idx - 1
                active_response_input = intention[0] - 3
                yield s, last_out_cards, 0, 0, 0, active_decision_input, 0, 0, 0, 0, 3
                yield s, last_out_cards, 0, 0, 0, 0, active_response_input, 0, 0, 0, 4

                if seq_length is not None:
                    # Sequence length labels are offset by one.
                    seq_length_input = seq_length - 1
                    yield s, last_out_cards, 0, 0, 0, 0, 0, seq_length_input, 0, 0, 5
            elif category_idx == Category.QUADRIC.value and category_idx != last_category_idx:
                passive_bomb_input = intention[0] - 3
                yield s, last_out_cards, 1, 0, 0, 0, 0, 0, 0, 0, 0
                yield s, last_out_cards, 0, passive_bomb_input, 0, 0, 0, 0, 0, 0, 1
            elif category_idx == Category.BIGBANG.value:
                yield s, last_out_cards, 2, 0, 0, 0, 0, 0, 0, 0, 0
            elif category_idx != Category.EMPTY.value:
                # The response label is the absolute card index rather than
                # one relative to the last cards (1st Feb: the relative shift
                # was hard for the network to learn).
                passive_response_input = intention[0] - 3
                if passive_response_input < 0:
                    print("something bad happens")
                    passive_response_input = 0
                yield s, last_out_cards, 3, 0, 0, 0, 0, 0, 0, 0, 0
                yield s, last_out_cards, 0, 0, passive_response_input, 0, 0, 0, 0, 0, 2
            else:
                yield s, last_out_cards, 0, 0, 0, 0, 0, 0, 0, 0, 0

            if minor_cards_targets is None:
                continue

            # Minor (kicker) cards: one sample per minor card, each seen from
            # a state with the main cards and all earlier minors discarded.
            main_cards = pick_main_cards(category_idx, to_char(intention))
            handcards = curr_cards_char.copy()
            state = s.copy()
            for main_card in main_cards:
                handcards.remove(main_card)

            # The return value is ignored, so this presumably edits `state`
            # in place - confirm against discard_onehot_from_s_60.
            discard_onehot_from_s_60(state, Card.char2onehot60(main_cards))

            is_pair = category_idx in (Category.THREE_TWO.value,
                                       Category.THREE_TWO_LINE.value)
            minor_type = 1 if is_pair else 0
            for target in minor_cards_targets:
                target_val = Card.char2value_3_17(target) - 3
                # Yield a copy so that later discards do not alter a sample
                # that has already been handed out.
                yield state.copy(), last_out_cards, 0, 0, 0, 0, 0, 0, target_val, minor_type, 6

                cards = [target]
                handcards.remove(target)
                if is_pair:
                    if target in handcards:
                        handcards.remove(target)
                        cards.append(target)
                    else:
                        # A pair minor needs a second copy of the card; dump
                        # the context and carry on with a single card.
                        print('something wrong...')
                        print('minor', target)
                        print('main_cards', main_cards)
                        print('handcards', handcards)
                        print('intention', intention)
                        print('category_idx', category_idx)

                # Remove the cards just chosen from the running state.
                discard_onehot_from_s_60(state, Card.char2onehot60(cards))
Exemple #2
0
def play_one_episode(env, func):
    """Play one auto-stepped episode and score ``func`` against each move.

    The environment is reset and prepared, then advanced with
    ``env.step_auto()`` until it returns a non-zero ``r``.  For every step the
    predictor ``func`` is queried and the argmax of each relevant output is
    compared with the label derived from the move that was actually played.

    Args:
        env: Environment exposing ``reset``, ``prepare``,
            ``get_last_outcards``, ``get_last_outcategory_idx``,
            ``get_curr_handcards``, ``get_state_prob`` and ``step_auto``.
        func: Callable taking ``[state, last_out_cards, minor_type]`` (state
            and last cards reshaped to ``(1, -1)``) and returning seven
            outputs in the order: passive decision, passive bomb, passive
            response, active decision, active response, sequence length,
            minor response.

    Returns:
        list: Seven ``StatCounter`` objects, indexed in the same order as the
        outputs of ``func``.  Each is fed 1 when the argmax matched the label
        and 0 otherwise.
    """
    env.reset()
    env.prepare()
    r = 0
    stats = [StatCounter() for _ in range(7)]

    def _predict(state, last_cards, minor_type_input):
        # Single place that shapes the inputs for the predictor.
        return func([
            state.reshape(1, -1),
            last_cards.reshape(1, -1),
            minor_type_input
        ])

    def _hit(target, prob):
        # 1 if the most probable class equals the label, else 0.
        return int(target == np.argmax(prob))

    while r == 0:
        # Capture the position *before* the move is played.
        last_cards_value = env.get_last_outcards()
        last_out_cards = Card.val2onehot60(last_cards_value)
        last_category_idx = env.get_last_outcategory_idx()
        curr_cards_char = to_char(env.get_curr_handcards())
        # Active mode is used when there are no last-played cards.
        is_active = last_cards_value.size == 0

        s = env.get_state_prob()
        intention, r, category_idx = env.step_auto()

        # NOTE(review): category 14 is skipped entirely; the meaning of this
        # value is defined outside this block - confirm.
        if category_idx == 14:
            continue
        minor_cards_targets = pick_minor_targets(category_idx,
                                                 to_char(intention))

        # NOTE(review): the non-minor queries pass a zero vector of length
        # s.shape[0] as the minor-type input, while the minor query below
        # passes a length-1 array.  Kept exactly as before - confirm which
        # shape func expects.
        if not is_active:
            # Work out the labels first; None means "head not scored".
            passive_bomb_input = None
            passive_response_input = None
            if category_idx == Category.QUADRIC.value and category_idx != last_category_idx:
                passive_decision_input = 1
                passive_bomb_input = intention[0] - 3
            elif category_idx == Category.BIGBANG.value:
                passive_decision_input = 2
            elif category_idx != Category.EMPTY.value:
                passive_decision_input = 3
                # The response label is the absolute card index rather than
                # one relative to the last cards (1st Feb: the relative shift
                # was hard for the network to learn).
                passive_response_input = intention[0] - 3
                if passive_response_input < 0:
                    print("something bad happens")
                    passive_response_input = 0
            else:
                passive_decision_input = 0

            # Every passive case uses the same inputs, so query once.
            passive_decision_prob, passive_bomb_prob, passive_response_prob, _, _, _, _ = _predict(
                s, last_out_cards, np.zeros([s.shape[0]]))
            stats[0].feed(_hit(passive_decision_input, passive_decision_prob))
            if passive_bomb_input is not None:
                stats[1].feed(_hit(passive_bomb_input, passive_bomb_prob))
            if passive_response_input is not None:
                stats[2].feed(
                    _hit(passive_response_input, passive_response_prob))
        else:
            seq_length = get_seq_length(category_idx, intention)

            # Active decision labels are offset by one.
            active_decision_input = category_idx - 1
            active_response_input = intention[0] - 3
            _, _, _, active_decision_prob, active_response_prob, active_seq_prob, _ = _predict(
                s, last_out_cards, np.zeros([s.shape[0]]))

            stats[3].feed(_hit(active_decision_input, active_decision_prob))
            stats[4].feed(_hit(active_response_input, active_response_prob))

            if seq_length is not None:
                # Sequence length labels are offset by one.
                seq_length_input = seq_length - 1
                stats[5].feed(_hit(seq_length_input, active_seq_prob))

        if minor_cards_targets is None:
            continue

        # Minor (kicker) cards: score one prediction per minor card, each made
        # from a state with the main cards and all earlier minors discarded.
        main_cards = pick_main_cards(category_idx, to_char(intention))
        handcards = curr_cards_char.copy()
        state = s.copy()
        for main_card in main_cards:
            handcards.remove(main_card)

        # The return value is ignored, so this presumably edits `state` in
        # place - confirm against discard_onehot_from_s_60.
        discard_onehot_from_s_60(state, Card.char2onehot60(main_cards))

        is_pair = category_idx in (Category.THREE_TWO.value,
                                   Category.THREE_TWO_LINE.value)
        minor_type = 1 if is_pair else 0
        for target in minor_cards_targets:
            target_val = Card.char2value_3_17(target) - 3
            _, _, _, _, _, _, minor_response_prob = _predict(
                state.copy(), last_out_cards, np.array([minor_type]))
            stats[6].feed(_hit(target_val, minor_response_prob))

            cards = [target]
            handcards.remove(target)
            if is_pair:
                if target in handcards:
                    handcards.remove(target)
                    cards.append(target)
                else:
                    # A pair minor needs a second copy of the card; log the
                    # context and carry on with a single card.  The values go
                    # through %s placeholders: passing them as bare extra
                    # arguments makes the log record fail to format.
                    logger.warning('something wrong...')
                    logger.warning('minor %s', target)
                    logger.warning('main_cards %s', main_cards)
                    logger.warning('handcards %s', handcards)

            # Remove the cards just chosen from the running state.
            discard_onehot_from_s_60(state, Card.char2onehot60(cards))
    return stats