示例#1
0
def write_record(dataset, record):
    """Serialize one record of a dataset into a tf.train.SequenceExample string.

    Scalar / vector values go into the context `features`; values declared as
    alternatives (per util.alternatives_type) go into `feature_lists`, one
    Feature per alternative. Values of type 'model' are skipped entirely.

    Args:
        dataset: dataset object exposing `values` (name -> declared type
            string) and `vocabularies` (vocabulary-typed value names).
        record: dict mapping value names to Python/numpy values; for
            alternatives types the entry is an iterable of per-alternative
            values.

    Returns:
        The serialized SequenceExample as a bytes string.
    """
    features = dict()
    feature_lists = dict()
    for value_name, value_type in dataset.values.items():
        value_type, alts = util.alternatives_type(value_type=value_type)
        if value_type == 'model':
            # models are not serialized into records
            continue
        if value_type == 'int':
            if alts:
                feature_lists[value_name] = tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=(value,))) for value in record[value_name]])
            else:
                features[value_name] = tf.train.Feature(int64_list=tf.train.Int64List(value=(record[value_name],)))
        elif value_type == 'float':
            if alts:
                feature_lists[value_name] = tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=(value,))) for value in record[value_name]])
            else:
                features[value_name] = tf.train.Feature(float_list=tf.train.FloatList(value=(record[value_name],)))
        elif value_type == 'vector(int)' or value_type in dataset.vocabularies:
            if alts:
                feature_lists[value_name] = tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=value)) for value in record[value_name]])
            else:
                features[value_name] = tf.train.Feature(int64_list=tf.train.Int64List(value=record[value_name]))
        elif value_type == 'vector(float)':
            if alts:
                feature_lists[value_name] = tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten())) for value in record[value_name]])
            else:
                features[value_name] = tf.train.Feature(float_list=tf.train.FloatList(value=record[value_name].flatten()))
        elif value_type == 'world':
            if alts:
                # BUG FIX: this FeatureList was previously assigned into the
                # context `features` dict, unlike every other alternatives
                # branch — a FeatureList is not a valid context Feature and
                # the value never reached the sequence side.
                feature_lists[value_name] = tf.train.FeatureList(feature=[tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten())) for value in record[value_name]])
            else:
                features[value_name] = tf.train.Feature(float_list=tf.train.FloatList(value=record[value_name].flatten()))
    record = tf.train.SequenceExample(context=tf.train.Features(feature=features), feature_lists=tf.train.FeatureLists(feature_list=feature_lists))
    serialized_record = record.SerializeToString()
    return serialized_record
示例#2
0
def read_record(dataset, serialized_record):
    """Parse a serialized tf.train.SequenceExample back into tensors.

    Builds parsing specs from the dataset's declared value types: values
    marked as alternatives become fixed-length *sequence* features, all
    others become fixed-length *context* features. Unrecognized types are
    ignored.

    Args:
        dataset: dataset object exposing `values`, `vocabularies`,
            `vector_shape(value_name)` and `world_shape()`.
        serialized_record: scalar string tensor holding one serialized
            SequenceExample.

    Returns:
        A pair (context tensors dict, sequence tensors dict).
    """
    context_spec = dict()
    sequence_spec = dict()
    for name, declared_type in dataset.values.items():
        base_type, has_alternatives = util.alternatives_type(value_type=declared_type)
        # Resolve the per-entry shape and dtype for this value type.
        if base_type == 'int':
            shape, dtype = (), tf.int64
        elif base_type == 'float':
            shape, dtype = (), tf.float32
        elif base_type == 'vector(int)' or base_type in dataset.vocabularies:
            shape, dtype = dataset.vector_shape(value_name=name), tf.int64
        elif base_type == 'vector(float)':
            shape, dtype = dataset.vector_shape(value_name=name), tf.float32
        elif base_type == 'world':
            shape, dtype = dataset.world_shape(), tf.float32
        else:
            continue  # unknown type: no parsing spec
        if has_alternatives:
            sequence_spec[name] = tf.FixedLenSequenceFeature(shape=shape, dtype=dtype)
        else:
            context_spec[name] = tf.FixedLenFeature(shape=shape, dtype=dtype)
    context_tensors, sequence_tensors = tf.parse_single_sequence_example(
        serialized=serialized_record,
        context_features=context_spec,
        sequence_features=sequence_spec)
    return context_tensors, sequence_tensors
示例#3
0
def batch_records(dataset, mode, batch_size):
    """Build a batched tensor dict from the dataset's TFRecord input pipeline.

    implicit include_model=False
    implicit alternatives=False

    Reads parsed records via read_records, resolves alternatives down to a
    single choice per record (random when shuffling, first otherwise),
    batches them, then applies optional pixel noise to 'world' values and
    casts integer-typed values to int32.

    queue runners need to be initialized:

        with tf.Session() as session:
            coordinator = tf.train.Coordinator()
            queue_threads = tf.train.start_queue_runners(sess=session, coord=coordinator)

            # session calls, for instance:
            batch = session.run(fetches=generated)

            coordinator.request_stop()
            coordinator.join(threads=queue_threads)
    """

    with tf.variable_scope(name_or_scope='tf-records'):
        records, sequence_records = read_records(dataset=dataset, mode=mode)
        # Shuffled path: either the dataset is not a pre-loaded one, or it
        # explicitly asks for random sampling.
        if not isinstance(dataset, LoadedDataset) or dataset.random_sampling:
            if 'alternatives' in records:
                # Draw a uniform index in [0, num_alternatives) per record:
                # floor(num_alternatives * U[0,1)) cast back to int32.
                sample = tf.cast(x=tf.floor(x=tf.multiply(
                    x=tf.cast(x=records['alternatives'], dtype=tf.float32),
                    y=tf.random_uniform(shape=()))),
                                 dtype=tf.int32)
                # Replace each sequence value with the sampled alternative,
                # then drop the bookkeeping count from the batch.
                for value_name, sequence_record in sequence_records.items():
                    records[value_name] = sequence_record[sample]
                records.pop('alternatives')
            batch = tf.train.shuffle_batch(tensors=records,
                                           batch_size=batch_size,
                                           capacity=(batch_size * 50),
                                           min_after_dequeue=(batch_size * 10),
                                           num_threads=1)
        else:
            if 'alternatives' in records:
                # Deterministic path: always take the first alternative.
                for value_name, sequence_record in sequence_records.items():
                    records[value_name] = sequence_record[0]
                records.pop('alternatives')
            batch = tf.train.batch(tensors=records,
                                   batch_size=batch_size,
                                   num_threads=1,
                                   capacity=(batch_size * 50))
        # Post-processing per value (only values overwritten, no keys
        # added/removed, so iterating `batch` while assigning is safe).
        for value_name in batch:
            value_type, _ = util.alternatives_type(
                value_type=dataset.values[value_name])
            if dataset.pixel_noise_stddev is not None and dataset.pixel_noise_stddev > 0.0 and value_type == 'world':
                # Additive truncated-Gaussian pixel noise, result clipped
                # back into [0, 1] (assumes world values are normalized to
                # that range — confirm against the dataset writer).
                noise = tf.truncated_normal(shape=((batch_size, ) +
                                                   dataset.world_shape()),
                                            mean=0.0,
                                            stddev=dataset.pixel_noise_stddev)
                batch[value_name] = tf.clip_by_value(t=(batch[value_name] +
                                                        noise),
                                                     clip_value_min=0.0,
                                                     clip_value_max=1.0)
            elif value_type == 'int' or value_type == 'vector(int)' or value_type in dataset.vocabularies:
                # Records store ints as int64; downstream expects int32.
                batch[value_name] = tf.cast(x=batch[value_name],
                                            dtype=tf.int32)
        return batch
示例#4
0
    # NOTE(review): this is the interior of a function/script body whose
    # header is outside this chunk; `args` is presumably an argparse
    # namespace — confirm against the enclosing definition.

    # Warn (and optionally ask for confirmation) when the estimated shard
    # size exceeds ~500MB, assuming one float-ish unit per world cell.
    if args.instances * util.product(dataset.world_shape()) > 5e8:  # > 500MB
        sys.stdout.write('{time} warning: shard size is {size}MB '.format(
            time=datetime.now().strftime('%H:%M:%S'),
            size=int(args.instances * util.product(dataset.world_shape()) /
                     1e6)))
        sys.stdout.flush()
        if args.yes:
            # --yes: auto-confirm, echo the answer for the log.
            sys.stdout.write('y\n')
        elif util.negative_response(sys.stdin.readline()[:-1]):
            # User declined ([:-1] strips the trailing newline).
            exit(0)

    if args.features:
        # Register pretrained-model feature vectors as additional dataset
        # values, mirroring the alternatives-ness of each 'world' value.
        from pretrained import PretrainedModel
        pretrained_model = PretrainedModel(image_shape=dataset.world_shape())
        # list(...) copy: dataset.values is mutated inside the loop.
        for value_name, value_type in list(dataset.values.items()):
            value_type, alts = util.alternatives_type(value_type=value_type)
            if value_type == 'world':
                if alts:
                    dataset.values[value_name +
                                   '_features'] = 'alternatives(vector(float))'
                else:
                    dataset.values[value_name + '_features'] = 'vector(float)'
                dataset.vectors[value_name +
                                '_features'] = pretrained_model.features_shape

    specification = dataset.specification()
    if args.archive:
        specification['archive'] = args.archive
    if args.delay_pixel_noise and dataset.pixel_noise_stddev > 0.0:
        # Record the noise level in the spec but disable it at generation
        # time, so noise is applied later (at read time) instead.
        specification['pixel_noise_stddev'] = dataset.pixel_noise_stddev
        dataset.pixel_noise_stddev = 0.0