Exemple #1
0
def load_fmri_data_from_lmdb(lmdb_filename,
                             fmri_files=None,
                             fmri_data_clean=None,
                             write_frequency=10):
    """Load every fMRI matrix stored in an LMDB file.

    Every entry except the ``__keys__`` index is deserialized with
    tensorpack's ``loads`` and cast to float32.

    Args:
        lmdb_filename: path of the LMDB file (opened with ``subdir=False``).
        fmri_files, fmri_data_clean, write_frequency: accepted for backward
            compatibility; not used by this function.

    Returns:
        (matrix_dict, fmri_sub_name): a list of float32 numpy arrays, one per
        stored key, and a parallel list of names built from each key's
        directory: its last component with ``'tfMRI'`` replaced by the
        third-from-last component.
    """
    import lmdb
    import os
    # Select the msgpack serializer before importing tensorpack's serialize
    # module (NOTE(review): presumably read at import time, so this has no
    # effect if the module was already imported elsewhere — confirm).
    os.environ['TENSORPACK_SERIALIZE'] = 'msg'
    os.environ['TENSORPACK_ONCE_SERIALIZE'] = 'msg'
    from tensorpack.utils.serialize import loads

    print('loading data from file: %s' % lmdb_filename)
    matrix_dict = []
    fmri_sub_name = []

    lmdb_env = lmdb.open(lmdb_filename, subdir=False)
    try:
        # The optional '__keys__' index is only used for an informational
        # count; a missing or unreadable index is not an error.
        try:
            with lmdb_env.begin() as lmdb_txn:
                listed_fmri_files = loads(lmdb_txn.get(b'__keys__'))
            listed_fmri_files = [l.decode("utf-8") for l in listed_fmri_files]
            print('Stored fmri data from files:')
            print(len(listed_fmri_files))
        except Exception:
            print('Search each key for every fmri file...')

        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                if key == b'__keys__':
                    continue
                # Naming scheme changed when data moved to the projects
                # directory: <...>/<subject>/<x>/<run dir containing tfMRI>.
                parts = Path(os.path.dirname(key.decode("utf-8"))).parts
                subname_info = str(parts[-3])
                fmri_sub_name.append(parts[-1].replace('tfMRI', subname_info))
                data = loads(value).astype('float32', casting='same_kind')
                matrix_dict.append(np.array(data))
    finally:
        lmdb_env.close()

    return matrix_dict, fmri_sub_name
Exemple #2
0
def load_fmri_data_from_lmdb(lmdb_filename, Trial_Num=1200):
    """Load fMRI matrices and subject names from an LMDB file.

    For resting-state files (``'REST'`` occurs in ``lmdb_filename``), entries
    that are missing (deserialize to ``None``) or whose last axis is not
    ``Trial_Num`` long are skipped with a warning.

    Args:
        lmdb_filename: path of the LMDB file (opened with ``subdir=False``).
        Trial_Num: required length of the last axis for REST data.

    Returns:
        (matrix_dict, fmri_sub_name, sub_name):
        ``matrix_dict`` is ``np.array`` of the kept matrices;
        ``fmri_sub_name`` names each kept matrix;
        ``sub_name`` holds the basename prefix (before the first ``_``) of
        every non-index key, including skipped ones.
    """
    print('loading data from file: %s' % lmdb_filename)
    matrix_dict = []
    fmri_sub_name = []
    # NOTE(review): sub_name is appended for every key, even skipped ones, so
    # it can be longer than fmri_sub_name/matrix_dict — confirm intended.
    sub_name = []
    # 'REST' has no '_', so this equals checking each '_'-separated token.
    is_rest = 'REST' in lmdb_filename

    lmdb_env = lmdb.open(lmdb_filename, subdir=False)
    try:
        # The '__keys__' index is only used for an informational count.
        try:
            with lmdb_env.begin() as lmdb_txn:
                listed_fmri_files = loads(lmdb_txn.get(b'__keys__'))
            listed_fmri_files = [l.decode("utf-8") for l in listed_fmri_files]
            print('Stored fmri data from files:')
            print(len(listed_fmri_files))
        except Exception:
            print('Search each key for every fmri file...')

        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                if key == b'__keys__':
                    continue
                key_str = key.decode("utf-8")
                pathsub = Path(os.path.dirname(key_str))
                sub_name.append(os.path.basename(key_str).split('_')[0])
                run_tokens = pathsub.parts[-1].split('_')
                if is_rest:
                    fmri_sub_name.append(pathsub.parts[-3] + '_' +
                                         run_tokens[-2][-1] + '_' +
                                         run_tokens[-1])
                else:
                    fmri_sub_name.append(pathsub.parts[-3] + '_' +
                                         run_tokens[-1])
                data = loads(value)
                if is_rest and (data is None or data.shape[-1] != Trial_Num):
                    # A missing entry is reported as having 0 trials instead
                    # of crashing on ``None.shape``.
                    n_trials = 0 if data is None else data.shape[-1]
                    print('fmri data shape mis-matching between subjects...')
                    print('Check subject:  %s with only %d Trials \n' %
                          (fmri_sub_name[-1], n_trials))
                    del fmri_sub_name[-1]
                else:
                    matrix_dict.append(np.array(data))
    finally:
        lmdb_env.close()

    matrix_dict = np.array(matrix_dict)
    return matrix_dict, fmri_sub_name, sub_name
Exemple #3
0
    def run(self):
        """Simulator client loop; never returns.

        Sends ``(identity, state, reward, isOver)`` on a PUSH socket connected
        to ``self.c2s``. It then waits for an action on a DEALER socket
        connected to ``self.s2c`` and applies it with ``player.step``.
        """
        player = self._build_player()
        ctx = zmq.Context()

        push_sock = ctx.socket(zmq.PUSH)
        push_sock.setsockopt(zmq.IDENTITY, self.identity)
        push_sock.set_hwm(2)
        push_sock.connect(self.c2s)

        dealer_sock = ctx.socket(zmq.DEALER)
        dealer_sock.setsockopt(zmq.IDENTITY, self.identity)
        dealer_sock.connect(self.s2c)

        obs = player.reset()
        reward, isOver = 0, False
        while True:
            # Report the state reached by the previous action together with
            # its reward/isOver; after an episode ends, obs is already the
            # first state of the next episode. This tuple is not the one
            # stored in the memory buffer.
            payload = dumps((self.identity, obs, reward, isOver))
            push_sock.send(payload, copy=False)
            reply = dealer_sock.recv(copy=False)
            action = loads(reply.bytes)
            obs, reward, isOver, _ = player.step(action)
            if isOver:
                obs = player.reset()
Exemple #4
0
    def run(self):
        """Simulator client loop (with death signal enabled); never returns.

        Each iteration pushes ``(identity, state, reward, isOver)`` to
        ``self.c2s``, receives an action from ``self.s2c`` and steps the player.
        """
        enable_death_signal()
        player = self._build_player()
        context = zmq.Context()

        def _connect(sock_type, address, hwm=None):
            # Create a socket tagged with this client's identity and connect it.
            sock = context.socket(sock_type)
            sock.setsockopt(zmq.IDENTITY, self.identity)
            if hwm is not None:
                sock.set_hwm(hwm)
            sock.connect(address)
            return sock

        c2s = _connect(zmq.PUSH, self.c2s, hwm=2)
        s2c = _connect(zmq.DEALER, self.s2c)

        state = player.reset()
        reward, isOver = 0, False
        while True:
            # State/reward/isOver produced by the last action; if the episode
            # ended, state is already the next episode's first state. Not the
            # same tuple that goes into the memory buffer.
            c2s.send(dumps((self.identity, state, reward, isOver)), copy=False)
            action = loads(s2c.recv(copy=False))
            state, reward, isOver, _ = player.step(action)
            if isOver:
                state = player.reset()
def perf_from_log(log_fn):
    """Return performance numbers parsed from a training log, with caching.

    Args:
        log_fn: a stdout file ``xxx/stdout/triali/stdout.txt``.

    Returns:
        ``(min_ve, multi_add * 2. * 1e-9, min_ve_epoch)`` parsed from the log
        (NOTE(review): the middle value is presumably GFLOPs — confirm), a
        previously cached copy of that tuple, or ``(2.0, -1.0, -1)`` when the
        log does not exist. A cached value is ignored when ``FORCE_LOAD`` is
        set or when it is falsy.
    """
    # One cache file per log directory; '/' is flattened so the name is a
    # single file inside cache_dir.
    cache_fn = os.path.join(cache_dir,
                            os.path.dirname(log_fn).replace('/', '__'))
    if os.path.exists(cache_fn):
        with open(cache_fn, 'rb') as fin:
            ss = fin.read()
        # A corrupt or foreign cache file is treated as a cache miss
        # (previously this left ``ret`` unbound and raised UnboundLocalError).
        ret = None
        try:
            ret = loads(ss)
        except Exception:
            pass
        if ret and not FORCE_LOAD:
            return ret

    if not os.path.exists(log_fn):
        return 2.0, -1.0, -1

    min_ve, min_ve_epoch = val_err_from_log(log_fn)
    multi_add, _n_params = multi_add_from_log(log_fn)
    ret = (min_ve, multi_add * 2. * 1e-9, min_ve_epoch)
    with open(cache_fn, 'wb') as fout:
        fout.write(dumps(ret))
    return ret
def get_data(path, isTrain, stat_file):
    """Build the normalized, batched dataflow over an LMDB file.

    Args:
        path: LMDB file loaded with ``LMDBSerializer.load``.
        isTrain: when True, datapoints are shuffled and the flow is wrapped
            in ``MultiProcessRunnerZMQ(ds, 1)``.
        stat_file: file holding serialized ``(mean, std)`` used to compute
            ``(x - mean) / std`` through ``MapDataComponent``.

    Returns:
        The resulting dataflow, batched with ``TIMITBatch(ds, BATCH)``.
    """
    ds = LMDBSerializer.load(path, shuffle=isTrain)
    # Close the stats file deterministically instead of leaking the handle.
    with open(stat_file, 'rb') as f:
        mean, std = serialize.loads(f.read())
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
        ds = MultiProcessRunnerZMQ(ds, 1)
    return ds
Exemple #7
0
 def f():
     # Receive one message from sim2exp_socket and enqueue it without
     # blocking; any failure to enqueue is logged, not raised.
     payload = loads(self.sim2exp_socket.recv(copy=False).bytes)
     print('{}: received msg'.format(self.agent_name))
     try:
         self.queue.put_nowait(payload)
     except Exception:
         logger.info('put queue failed!')
Exemple #8
0
 def request_click(bbox):
     # Ask the manager to click the centre of bbox, shifted by the window
     # origin plus fixed offsets (+6, +46) — NOTE(review): presumably window
     # border/title-bar size; confirm. Returns the deserialized reply.
     left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
     x = (left + right) // 2 + self.window_rect[0] + 6
     y = (top + bottom) // 2 + self.window_rect[1] + 46
     request = [self.name, SimulatorManager.MSG_TYPE.CLICK, [x, y]]
     sim2mgr_socket.send(dumps(request))
     reply = mgr2sim_socket.recv(copy=False)
     return loads(reply.bytes)
Exemple #9
0
def get_data(path, isTrain, stat_file):
    """Build the normalized, batched dataflow over an LMDB file.

    Args:
        path: LMDB file read with ``LMDBDataPoint``.
        isTrain: when True, datapoints are shuffled and the flow is wrapped
            in ``PrefetchDataZMQ(ds, 1)``.
        stat_file: file holding serialized ``(mean, std)`` used to compute
            ``(x - mean) / std`` through ``MapDataComponent``.

    Returns:
        The resulting dataflow, batched with ``TIMITBatch(ds, BATCH)``.
    """
    ds = LMDBDataPoint(path, shuffle=isTrain)
    # Close the stats file deterministically instead of leaking the handle.
    with open(stat_file, 'rb') as f:
        mean, std = serialize.loads(f.read())
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 1)
    return ds
Exemple #10
0
def get_data(path, isTrain, stat_file):
    """Build the normalized, batched dataflow over an LMDB file.

    Args:
        path: LMDB file read with ``LMDBDataPoint``.
        isTrain: when True, datapoints are shuffled and the flow is wrapped
            in ``PrefetchDataZMQ(ds, 1)``.
        stat_file: file holding serialized ``(mean, std)`` used to compute
            ``(x - mean) / std`` through ``MapDataComponent``.

    Returns:
        The resulting dataflow, batched with ``TIMITBatch(ds, BATCH)``.
    """
    ds = LMDBDataPoint(path, shuffle=isTrain)
    # Close the stats file deterministically instead of leaking the handle.
    with open(stat_file, 'rb') as f:
        mean, std = serialize.loads(f.read())
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 1)
    return ds
Exemple #11
0
    def _recv(self):
        """Receive one datapoint and adopt its first element as the new size.

        Overrides ``_recv`` so the dataflow size (``self._size``) follows the
        value carried in ``dp[0]`` of each received datapoint.

        WARNING: do not use this with `nr_proc > 1`"""
        from tensorpack.utils.serialize import loads
        dp = loads(self.socket.recv(copy=False))
        self._size = dp[0]
        return dp
Exemple #12
0
def ImageDecode(ds):
    """Decode a serialized pair of encoded images.

    ``ds`` is a 2-element ``(key, serialized)`` pair. ``serialized`` is
    deserialized with ``loads``, and its first two entries are decoded with
    ``cv2.imdecode``: the first as color, the second as grayscale.

    Returns:
        ``(color_image, grayscale_image)``.
    """
    _, payload = ds
    encoded = loads(payload)
    color = cv2.imdecode(encoded[0], cv2.IMREAD_COLOR)
    gray = cv2.imdecode(encoded[1], cv2.IMREAD_GRAYSCALE)
    return color, gray
Exemple #13
0
    def desearialize_data_point(dp):
        # dp is (index, serialized payload); the payload must hold exactly
        # one item, the image id.
        idx = int(dp[0])  # index from 0 to N
        payload = loads(dp[1])
        assert len(payload) == 1, len(payload)
        img_id, = payload
        return (img_id, idx) if return_index else img_id
 def get_frame():
     """Receive one batch of points from ``sok`` and build a ``Frame``.

     Increments the global counter ``cnt``. Received points are multiplied by
     100 (NOTE(review): unit conversion factor — confirm), drawn as spheres
     of size 3 (the first one with radius 10), and joined by cylinders from
     ``build_cylinder_from_3dpts``.
     """
     global cnt
     cnt += 1
     data = loads(sok.recv(copy=False).bytes)
     data = data * 100
     # Function-call form: valid on Python 3 and prints the same on Python 2.
     print(data)
     spheres = [Sphere(3, pos) for pos in data]
     spheres[0].radius = 10
     cyls = build_cylinder_from_3dpts(data)
     return Frame(spheres, cyls)
Exemple #15
0
 def desearialize_data_point(dp):
     # dp is (index, serialized (encoded image, label)).
     idx = int(dp[0])  # index from 0 to N
     payload = loads(dp[1])
     assert len(payload) == 2, len(payload)
     img_encoded, label = payload
     # Optional relabeling table, indexed by datapoint index.
     if new_labels is not None:
         label = new_labels[idx]
     if not return_index:
         return img_encoded, label
     return img_encoded, label, idx
Exemple #16
0
 def run(self):
     # Server loop: receive (ident, state, reward, isOver) from clients and
     # dispatch to _process_msg until the zmq context is terminated.
     self.clients = defaultdict(self.ClientState)
     try:
         while True:
             ident, state, reward, isOver = loads(
                 self.c2s_socket.recv(copy=False))
             client = self.clients[ident]
             # Record the identity the first time this client is seen.
             if client.ident is None:
                 client.ident = ident
             # maybe check history and warn about dead client?
             self._process_msg(client, state, reward, isOver)
     except zmq.ContextTerminated:
         logger.info("[Simulator] Context was terminated.")
Exemple #17
0
 def run(self):
     # Server loop: receive (ident, state, reward, isOver) from clients and
     # dispatch to _process_msg until the zmq context is terminated.
     self.clients = defaultdict(self.ClientState)
     try:
         while True:
             frame = self.c2s_socket.recv(copy=False)
             ident, state, reward, isOver = loads(frame.bytes)
             client = self.clients[ident]
             # Record the identity the first time this client is seen.
             if client.ident is None:
                 client.ident = ident
             # maybe check history and warn about dead client?
             self._process_msg(client, state, reward, isOver)
     except zmq.ContextTerminated:
         logger.info("[Simulator] Context was terminated.")
Exemple #18
0
 def run(self):
     # Server loop: receive an 11-field message per client step and forward
     # its fields to _process_msg until the zmq context is terminated.
     self.clients = defaultdict(self.ClientState)
     try:
         while True:
             frame = self.c2s_socket.recv(copy=False)
             (ident, role_id, prob_state, all_state, last_cards, first_st,
              mask, minor_type, mode, reward, isOver) = loads(frame.bytes)
             client = self.clients[ident]
             # Record the identity the first time this client is seen.
             if client.ident is None:
                 client.ident = ident
             # maybe check history and warn about dead client?
             self._process_msg(client, role_id, prob_state, all_state,
                               last_cards, first_st, mask, minor_type, mode,
                               reward, isOver)
     except zmq.ContextTerminated:
         logger.info("[Simulator] Context was terminated.")
Exemple #19
0
    def run(self):
        """Server loop: consume client states and emit actions; never returns."""
        self.clients = defaultdict(self.ClientState)
        while True:
            ident, state, reward, isOver = loads(
                self.c2s_socket.recv(copy=False).bytes)
            # TODO check history and warn about dead client
            client = self.clients[ident]

            # A client's first message only carries a valid state; for later
            # messages reward/isOver complete the previous memory entry.
            if len(client.memory) > 0:
                client.memory[-1].reward = reward
                finish = self._on_episode_over if isOver else self._on_datapoint
                finish(ident)
            # feed state and return action
            self._on_state(state, ident)
Exemple #20
0
        def f():
            # Receive [sim_name, agent_name, *inputs] and queue the inputs on
            # that agent's predictor. When the result is ready it is sent back
            # to the simulator, addressed by its utf-8 encoded name.
            msg = loads(self.sim2coord_socket.recv(copy=False).bytes)
            sim_name, agent_name = msg[0], msg[1]

            def cb(outputs):
                try:
                    output = outputs.result()
                except CancelledError:
                    logger.info("{} cancelled.".format(sim_name))
                    return
                sim_id = sim_name.encode('utf-8')
                print('coordinator sending', sim_id, output[0].shape)
                self.coord2sim_socket.send_multipart([sim_id, dumps(output[0])])

            self.predictors[agent_name].put_task(msg[2:], cb)
Exemple #21
0
def is_mark_failure(log_dir):
    """Check whether a run's stop-mark file indicates failure.

    Returns:
        (is_failed, content):

        - ``(True, None)`` if the stop-mark file does not exist;
        - ``(False, raw_bytes)`` if the file content cannot be deserialized
          with ``loads``;
        - ``(ret == 'failed_meow', ret)`` otherwise, where ``ret`` is the
          deserialized content.
    """
    fn = stop_mark_fn(log_dir, is_interrupted=False)
    if not os.path.exists(fn):
        return True, None
    with open(fn, 'rb') as fin:
        ss = fin.read()
    try:
        ret = loads(ss)
    except Exception:
        # Not a serialized object; hand back the raw bytes. (A bare except
        # here used to swallow KeyboardInterrupt/SystemExit as well.)
        return False, ss
    return ret == 'failed_meow', ret
    def run(self):
        """Simulator client loop using player.action/current_state; never returns."""
        player = self._build_player()
        context = zmq.Context()

        def _connect(sock_type, address, hwm=None):
            # Create a socket tagged with this client's identity and connect it.
            sock = context.socket(sock_type)
            sock.setsockopt(zmq.IDENTITY, self.identity)
            if hwm is not None:
                sock.set_hwm(hwm)
            sock.connect(address)
            return sock

        c2s = _connect(zmq.PUSH, self.c2s, hwm=2)
        s2c = _connect(zmq.DEALER, self.s2c)

        state = player.current_state()
        reward, isOver = 0, False
        while True:
            c2s.send(dumps((self.identity, state, reward, isOver)), copy=False)
            action = loads(s2c.recv(copy=False).bytes)
            reward, isOver = player.action(action)
            state = player.current_state()
Exemple #23
0
    def run(self):
        """Server loop: consume client states and emit actions until the zmq
        context is terminated."""
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                frame = self.c2s_socket.recv(copy=False)
                ident, state, reward, isOver = loads(frame.bytes)
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # A client's first message only carries a valid state; later
                # ones complete the previous memory entry with reward/isOver.
                if len(client.memory) > 0:
                    client.memory[-1].reward = reward
                    finish = self._on_episode_over if isOver else self._on_datapoint
                    finish(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")
Exemple #24
0
    def run(self):
        """Simulator client loop using player.action/current_state; never returns."""
        player = self._build_player()
        ctx = zmq.Context()

        push_sock = ctx.socket(zmq.PUSH)
        push_sock.setsockopt(zmq.IDENTITY, self.identity)
        push_sock.set_hwm(2)
        push_sock.connect(self.c2s)

        dealer_sock = ctx.socket(zmq.DEALER)
        dealer_sock.setsockopt(zmq.IDENTITY, self.identity)
        dealer_sock.connect(self.s2c)

        obs = player.current_state()
        reward, isOver = 0, False
        while True:
            payload = dumps((self.identity, obs, reward, isOver))
            push_sock.send(payload, copy=False)
            reply = dealer_sock.recv(copy=False)
            reward, isOver = player.action(loads(reply.bytes))
            obs = player.current_state()
Exemple #25
0
    def run(self):
        """Server loop: consume (ident, state, action, reward, isOver) messages
        and emit actions until the zmq context is terminated."""
        self.clients = defaultdict(self.ClientState)
        try:
            while True:
                frame = self.c2s_socket.recv(copy=False)
                ident, state, action, reward, isOver = loads(frame.bytes)
                # TODO check history and warn about dead client
                client = self.clients[ident]

                # A client's first message only carries a valid state; later
                # ones complete the previous memory entry.
                if len(client.memory) > 0:
                    last = client.memory[-1]
                    last.reward = reward
                    last.action = action
                    finish = self._on_episode_over if isOver else self._on_datapoint
                    finish(ident)
                # feed state and return action
                self._on_state(state, ident)
        except zmq.ContextTerminated:
            logger.info("[Simulator] Context was terminated.")
  def __iter__(self):
    """Yield ``[key, value]`` pairs read through ``self._txn``.

    Unshuffled: walk the LMDB cursor in order, skipping the ``__keys__``
    index entry.
    Shuffled: split ``self.keys`` into ``self._size // self.batch_from_disk``
    contiguous chunks (at least one), shuffle the chunk order with
    ``self.rng`` and read keys chunk by chunk (presumably so disk reads stay
    mostly sequential — confirm).

    The whole iteration holds ``self._guard``.
    """
    with self._guard:
      if not self._shuffle:
        c = self._txn.cursor()
        while c.next():
          k, v = c.item()
          if k != b'__keys__':
            yield [k, v]
      else:
        # numpy.array_split rejects 0 sections, which truncating
        # self._size / self.batch_from_disk yields for datasets smaller than
        # one disk batch.
        n_batches = max(1, int(self._size // self.batch_from_disk))
        batched_keys = numpy.array_split(self.keys, n_batches)
        self.rng.shuffle(batched_keys)
        for k in itertools.chain.from_iterable(batched_keys):
          yield [k, self._txn.get(k)]
Exemple #27
0
    def run(self):
        """Simulator client loop; never returns.

        Pushes ``(identity, state, reward, is_over)`` to ``self.c2s``. It then
        receives an action on ``self.s2c`` and applies it with
        ``player.step``, resetting the player when an episode ends.
        """
        player = self._build_player()
        context = zmq.Context()
        c2s_socket = context.socket(zmq.PUSH)
        c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
        c2s_socket.set_hwm(10)
        c2s_socket.connect(self.c2s)

        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        s2c_socket.connect(self.s2c)

        st = player.reset()
        r, is_over = 0, False

        while True:
            c2s_socket.send(dumps((self.identity, st, r, is_over)), copy=False)
            action = loads(s2c_socket.recv(copy=False).bytes)
            st, r, is_over, _ = player.step(action)
            if is_over:
                # Report the first observation of the new episode, not the
                # terminal observation of the finished one.
                st = player.reset()
    def _eval(self):
        """Run COCO evaluation and report the scores to the trainer monitors.

        Two modes, selected by ``cfg.TRAINER``:

        * ``'replicated'``: each (dataflow, predictor) pair is evaluated with
          ``eval_coco`` in its own worker thread. All threads share one tqdm
          bar whose total is the summed ``size()`` of the dataflows.
        * otherwise: this process evaluates ``self.dataflow`` locally. The
          serialized results of all processes are then gathered through one
          TF session run (NOTE(review): presumably a Horovod allgather built
          elsewhere — confirm). Only rank 0 continues past the gather.

        The merged results are written as JSON to
        ``<logger dir>/outputs<global_step>.json`` and scored with
        ``print_evaluation_scores``. Each score is put as a scalar on the
        trainer monitors; scoring failures are logged, not raised.
        """
        if cfg.TRAINER == 'replicated':
            with ThreadPoolExecutor(max_workers=self.num_predictor, thread_name_prefix='EvalWorker') as executor, \
                    tqdm.tqdm(total=sum([df.size() for df in self.dataflows])) as pbar:
                futures = []
                for dataflow, pred in zip(self.dataflows, self.predictors):
                    futures.append(
                        executor.submit(eval_coco, dataflow, pred, pbar))
                # Concatenate the per-worker result lists; fut.result()
                # re-raises any exception raised in a worker.
                all_results = list(
                    itertools.chain(*[fut.result() for fut in futures]))
        else:
            local_results = eval_coco(self.dataflow, self.predictor)
            # Serialize the local results into a flat uint8 buffer so they
            # can be fed into the graph as a tensor.
            results_as_arr = np.frombuffer(dumps(local_results),
                                           dtype=np.uint8)
            # NOTE(review): every rank must execute this run before ranks > 0
            # return below; presumably it is a collective op — confirm.
            sizes, concat_arrs = tf.get_default_session().run(
                [self.string_lens, self.concat_results],
                feed_dict={self.local_result_tensor: results_as_arr})
            if hvd.rank() > 0:
                return
            all_results = []
            start = 0
            # Split the concatenated buffer back into per-process chunks
            # using the returned byte lengths and deserialize each chunk.
            for size in sizes:
                substr = concat_arrs[start:start + size]
                results = loads(substr.tobytes())
                all_results.extend(results)
                start = start + size

        output_file = os.path.join(logger.get_logger_dir(),
                                   'outputs{}.json'.format(self.global_step))
        with open(output_file, 'w') as f:
            json.dump(all_results, f)
        try:
            scores = print_evaluation_scores(output_file)
            for k, v in scores.items():
                self.trainer.monitors.put_scalar(k, v)
        except Exception:
            logger.exception("Exception in COCO evaluation.")
Exemple #29
0
def load_fmri_data_from_lmdb(lmdb_filename, modality='MOTOR'):
    """Load fMRI matrices from an LMDB file, building a merged file if needed.

    If ``modality == 'ALLTasks'`` and ``lmdb_filename`` does not exist yet,
    a merged LMDB is first written there. It is built from every sibling
    file matching the basename of ``lmdb_filename`` with ``modality``
    replaced by ``*``; files whose name prefix (before the first ``_``) is
    ``ALLTasks`` are skipped. All key/value pairs are copied into the new
    file, committing every 100 entries.

    Returns:
        (matrix_dict, fmri_sub_name): a list of float32 numpy arrays, one per
        stored key except ``__keys__``, and a parallel list of names built
        from each key's directory: its last component with ``'tfMRI'``
        replaced by the third-from-last component.
    """
    print('loading data from file: %s' % lmdb_filename)
    matrix_dict = []
    fmri_sub_name = []
    if not os.path.isfile(lmdb_filename) and modality == 'ALLTasks':
        print("Loading fMRI data from all tasks and merge into one lmdb file:",
              lmdb_filename)
        # NOTE(review): an interrupted merge leaves a partial file behind,
        # and later calls will then skip the merge — confirm acceptable.
        lmdb_env = lmdb.open(lmdb_filename,
                             subdir=False,
                             readonly=False,
                             map_size=int(1e12) * 2,
                             meminit=False,
                             map_async=True)
        write_frequency = 100  # commit after this many copied entries
        try:
            pathout = Path(os.path.dirname(lmdb_filename))
            pattern = os.path.basename(lmdb_filename).replace(modality, '*')
            for lmdb_mod in sorted(pathout.glob(pattern)):
                mod_name = os.path.basename(lmdb_mod).split('_')[0]
                if mod_name == 'ALLTasks':
                    continue
                print('Loading data for modality', mod_name)
                lmdb_txn = lmdb_env.begin(write=True)
                lmdb_mod_env = lmdb.open(str(lmdb_mod),
                                         subdir=False,
                                         readonly=True)
                try:
                    with lmdb_mod_env.begin() as lmdb_mod_txn:
                        for idx, (key, value) in enumerate(
                                lmdb_mod_txn.cursor()):
                            lmdb_txn.put(key, value)
                            if (idx + 1) % write_frequency == 0:
                                lmdb_txn.commit()
                                lmdb_txn = lmdb_env.begin(write=True)
                    lmdb_txn.commit()
                except BaseException:
                    # Do not leave a dangling write transaction behind.
                    lmdb_txn.abort()
                    raise
                finally:
                    lmdb_mod_env.close()
                lmdb_env.sync()
        finally:
            lmdb_env.close()

    lmdb_env = lmdb.open(lmdb_filename, subdir=False)
    try:
        # The '__keys__' index is only used for an informational count.
        try:
            with lmdb_env.begin() as lmdb_txn:
                listed_fmri_files = loads(lmdb_txn.get(b'__keys__'))
            listed_fmri_files = [l.decode("utf-8") for l in listed_fmri_files]
            print('Stored fmri data from files:')
            print(len(listed_fmri_files))
        except Exception:
            print('Search each key for every fmri file...')

        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                if key == b'__keys__':
                    continue
                # Naming scheme changed when data moved to the projects
                # directory.
                parts = Path(os.path.dirname(key.decode("utf-8"))).parts
                subname_info = str(parts[-3])
                fmri_sub_name.append(parts[-1].replace('tfMRI', subname_info))
                data = loads(value).astype('float32', casting='same_kind')
                matrix_dict.append(np.array(data))
    finally:
        lmdb_env.close()

    return matrix_dict, fmri_sub_name
Exemple #30
0
    def run(self):
        """Card-game simulator client driven by a remote predictor; never returns.

        Connects a PUSH socket to ``self.c2s`` and a DEALER socket to
        ``self.s2c``, both tagged with ``self.identity``. Every game is dealt
        with the fixed ``init_cards`` via ``player.prepare_manual``.

        When the current player has ``role_id == 2``, its move is built step
        by step through a ``SubState``. Each sub-step sends the sub-state to
        ``self.c2s`` and applies the action received on ``self.s2c``, and the
        final ``st.intention`` is played with ``player.step_manual``. Other
        roles play via ``player.step_auto()``; their episode ends when the
        returned reward is non-zero.
        """
        player = self._build_player()
        context = zmq.Context()
        c2s_socket = context.socket(zmq.PUSH)
        c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
        c2s_socket.set_hwm(10)
        c2s_socket.connect(self.c2s)

        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        s2c_socket.connect(self.s2c)

        player.reset()
        # Fixed deal reused for every game: values 0..20.
        init_cards = np.arange(21)
        # init_cards = np.append(init_cards[::4], init_cards[1::4])
        player.prepare_manual(init_cards)
        r, is_over = 0, False
        while True:
            all_state, role_id, curr_handcards_value, last_cards_value, last_category = \
                player.get_state_all_cards(), player.get_role_ID(), player.get_curr_handcards(), player.get_last_outcards(), player.get_last_outcategory_idx()
            # after taking the last action, get to this state and get this reward/isOver.
            # If isOver, get to the next-episode state immediately.
            # This tuple is not the same as the one put into the memory buffer
            # True when there are no last played cards to respond to.
            is_active = (last_cards_value.size == 0)
            # all_state is processed as three 60-wide segments. Each segment
            # is decoded with Card.onehot2char and turned into an action mask,
            # constrained by the last played cards unless active. The three
            # float32 masks are then flattened into one vector.
            all_state = np.stack([
                get_mask(
                    Card.onehot2char(all_state[i * 60:(i + 1) * 60]),
                    action_space,
                    None if is_active else to_char(last_cards_value)).astype(
                        np.float32) for i in range(3)
            ]).reshape(-1)
            # Action mask computed from the last played cards alone.
            last_state = get_mask(to_char(last_cards_value), action_space,
                                  None).astype(np.float32)

            if role_id == 2:
                st = SubState(
                    ACT_TYPE.PASSIVE if last_cards_value.size > 0
                    else ACT_TYPE.ACTIVE, all_state,
                    to_char(curr_handcards_value), last_cards_value,
                    last_category)
                if last_cards_value.size > 0:
                    assert last_category > 0
                # first_st is True only for the first sub-step of this move.
                first_st = True
                while not st.finished:
                    c2s_socket.send(dumps(
                        (self.identity, role_id,
                         st.state, st.all_state, last_state, first_st,
                         st.get_mask(), st.minor_type, st.mode, r, is_over)),
                                    copy=False)
                    first_st = False
                    action = loads(s2c_socket.recv(copy=False).bytes)
                    # logger.info('received action {}'.format(action))
                    # print(action)
                    st.step(action)

                # print(st.intention)
                assert st.card_type != -1
                r, is_over, category_idx = player.step_manual(st.intention)
            else:
                _, r, _ = player.step_auto()
                is_over = (r != 0)
            if is_over:
                # print('{} over with reward {}'.format(self.identity, r))
                # logger.info('{} over with reward {}'.format(self.identity, r))
                # sys.stdout.flush()
                # Start the next game with the same fixed deal.
                player.reset()
                player.prepare_manual(init_cards)
Exemple #31
0
 def request_screen():
     # Send a SCREEN request to the manager and return its deserialized reply.
     request = [self.name, SimulatorManager.MSG_TYPE.SCREEN, []]
     sim2mgr_socket.send(dumps(request))
     reply = mgr2sim_socket.recv(copy=False)
     return loads(reply.bytes)
Exemple #32
0
 def request_unlock():
     # Send an UNLOCK request to the manager and return its deserialized reply.
     request = [self.name, SimulatorManager.MSG_TYPE.UNLOCK, []]
     sim2mgr_socket.send(dumps(request))
     reply = mgr2sim_socket.recv(copy=False)
     return loads(reply.bytes)
Exemple #33
0
 def mapper(data):
     # data[1] deserializes to (encoded image, label); decode as color,
     # apply augs, and return (augmented image, label).
     encoded, label = loads(data[1])
     image = augs.augment(cv2.imdecode(encoded, cv2.IMREAD_COLOR))
     return image, label
Exemple #34
0
    def run(self):
        """Card-game simulator client with an LSTM state carried by the server.

        Connects a PUSH socket to ``self.c2s`` and a DEALER socket to
        ``self.s2c``, both tagged with ``self.identity``. Never returns.

        For roles in ``ROLE_IDS_TO_TRAIN``, each turn sends the observation,
        action mask and current ``lstm_state`` to ``self.c2s``. The reply on
        ``self.s2c`` is ``(action_idx, lstm_state)``, and the chosen action
        from ``action_space`` is played with ``player.step_manual``. Other
        roles play via ``player.step_auto()``; their episode ends when the
        returned reward is non-zero. ``lstm_state`` (2048 zeros) is reset at
        every new game.
        """
        player = self._build_player()
        context = zmq.Context()
        c2s_socket = context.socket(zmq.PUSH)
        c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
        c2s_socket.set_hwm(10)
        c2s_socket.connect(self.c2s)

        s2c_socket = context.socket(zmq.DEALER)
        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
        s2c_socket.connect(self.s2c)

        player.reset()
        # init_cards = np.arange(52)
        # init_cards = np.append(init_cards[::4], init_cards[1::4])
        # player.prepare_manual(init_cards)
        player.prepare()
        r, is_over = 0, False
        lstm_state = np.zeros([1024 * 2])
        while True:
            role_id = player.get_role_ID()
            if role_id in ROLE_IDS_TO_TRAIN:
                prob_state, all_state, curr_handcards_value, last_cards_value, last_category = \
                    player.get_state_prob(), player.get_state_all_cards(), player.get_curr_handcards(), player.get_last_outcards(), player.get_last_outcategory_idx()
                # Prepend the 60-wide encoding of the current hand to prob_state.
                prob_state = np.concatenate(
                    [Card.val2onehot60(curr_handcards_value), prob_state])
                # after taking the last action, get to this state and get this reward/isOver.
                # If isOver, get to the next-episode state immediately.
                # This tuple is not the same as the one put into the memory buffer

                # True when there are no last played cards to respond to.
                is_active = False if last_cards_value.size > 0 else True
                mask = get_mask(
                    to_char(curr_handcards_value), action_space,
                    None if is_active else to_char(last_cards_value))
                # NOTE(review): index 0 is presumably the "pass" action,
                # forbidden when active — confirm.
                if is_active:
                    mask[0] = 0
                last_two_cards = player.get_last_two_cards()
                last_two_cards_onehot = np.concatenate([
                    Card.val2onehot60(last_two_cards[0]),
                    Card.val2onehot60(last_two_cards[1])
                ])
                c2s_socket.send(dumps(
                    (self.identity, role_id, prob_state, all_state,
                     last_two_cards_onehot, mask, 0 if is_active else 1,
                     lstm_state, r, is_over)),
                                copy=False)
                # The reply carries the chosen action index and the updated
                # LSTM state to send back on the next turn.
                action_idx, lstm_state = loads(
                    s2c_socket.recv(copy=False).bytes)

                r, is_over, _ = player.step_manual(
                    to_value(action_space[action_idx]))
            else:
                _, r, _ = player.step_auto()
                is_over = (r != 0)
            if is_over:
                # print('{} over with reward {}'.format(self.identity, r))
                # logger.info('{} over with reward {}'.format(self.identity, r))
                # sys.stdout.flush()
                # New game: fresh deal and a zeroed LSTM state.
                player.reset()
                player.prepare()
                lstm_state = np.zeros([1024 * 2])
Exemple #35
0
 def f():
     # Block until a message arrives on sim2mgr_socket, then enqueue it.
     raw = self.sim2mgr_socket.recv(copy=False).bytes
     self.queue.put(loads(raw))