def samples_from_arrays(structures, rewards=None, batch_index=None,
                        metadata=None):
    """Makes a generator of Samples from fields.

    This is a generator function, so nothing in the body (including the
    length validation below) runs until the first item is requested.

    Args:
        structures: Iterable of structures (1-D np array or list).
        rewards: Iterable of float rewards. If None, each corresponding
            Sample is given a reward of None.
        batch_index: Either a single integer (a Python int or any other
            `numbers.Integral`, e.g. a NumPy integer scalar), in which case
            all Samples created by this function are given this batch_index,
            or an iterable of ints, one for each corresponding structure.
            None is treated as 0.
        metadata: Sized iterable of metadata objects, one per structure, to
            store in the Samples. If None, each Sample gets metadata of None.

    Yields:
        One Sample per structure, in order.

    Raises:
        ValueError: If structures and rewards, or metadata and rewards, have
            different lengths. Raised on first iteration, not at call time.
    """
    # Function-scope import so the block does not rely on a file-level import.
    import numbers

    structures = utils.to_array(structures)
    num_structures = len(structures)

    if metadata is None:
        metadata = [None] * num_structures

    if rewards is None:
        rewards = [None] * num_structures
    else:
        rewards = utils.to_array(rewards)

    if num_structures != len(rewards):
        raise ValueError(
            'Structures and rewards must be same length. Are %s and %s' %
            (num_structures, len(rewards)))

    if len(metadata) != len(rewards):
        raise ValueError(
            'Metadata and rewards must be same length. Are %s and %s' %
            (len(metadata), len(rewards)))

    if batch_index is None:
        batch_index = 0
    # `numbers.Integral` also matches NumPy integer scalars, which a plain
    # `isinstance(..., int)` check rejects; those would otherwise reach `zip`
    # as a non-iterable scalar and raise TypeError.
    if isinstance(batch_index, numbers.Integral):
        batch_indices = [batch_index] * num_structures
    else:
        # Caller supplied per-structure indices. Its length is not checked;
        # `zip` stops at the shortest input.
        batch_indices = batch_index

    # Distinct loop names so the `batch_index` parameter is not rebound.
    for structure, reward, index, meta in zip(
            structures, rewards, batch_indices, metadata):
        yield Sample(
            structure=structure,
            reward=reward,
            batch_index=index,
            metadata=meta)
def _record_to_dict(record):
    """Converts a record into a dict of array-converted fields.

    Every value in `record` is passed through `utils.to_array`. If a
    'batch_index' entry is present, it is then converted to a Python int.

    Args:
        record: Mapping-like object exposing `.items()`.

    Returns:
        A new dict with the same keys as `record`.
    """
    converted = {}
    for field_name, field_value in record.items():
        converted[field_name] = utils.to_array(field_value)

    if 'batch_index' not in converted:
        return converted

    # Replace the array-converted batch index with a plain Python int.
    converted['batch_index'] = int(converted['batch_index'])
    return converted