# Code example #1
# 0
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    """Convert *parallel_conf* into a parsed ``record_util.OFRecord``.

    The configuration is stringified with ``str()`` and passed to the
    internal OneFlow binding of the same name. The text that binding returns
    is parsed with ``text_format.Parse`` into a freshly constructed
    ``record_util.OFRecord``, and the parsed message is returned.

    Judging by the name, the record presumably maps machine ids to device-id
    lists. NOTE(review): confirm this against the binding.
    """
    conf_text = str(parallel_conf)
    internal_api = oneflow._oneflow_internal
    record_text = internal_api.GetMachine2DeviceIdListOFRecordFromParallelConf(conf_text)
    empty_record = record_util.OFRecord()
    return text_format.Parse(record_text, empty_record)
# Code example #2
# 0
        value = [value]
    if not six.PY2:
        if isinstance(value[0], str):
            value = [x.encode() for x in value]
    return ofrecord.Feature(bytes_list=ofrecord.BytesList(value=value))


# Write three synthetic OFRecords to ./dataset/part-0.
#
# Each record holds:
#   "images": obserations (28 * 28) random floats in [0, 1)
#   "labels": a single random integer in [0, 9]
# Each record is written as an 8-byte little-endian signed length header,
# followed by the serialized record bytes.
import os

obserations = 28 * 28  # values per image; name kept as-is in case other code references it

# Create the output directory if needed so open() doesn't fail on a fresh checkout.
os.makedirs("./dataset", exist_ok=True)

# The context manager guarantees the file is flushed and closed even if
# serialization or a feature builder raises partway through.
with open("./dataset/part-0", "wb") as f:
    for _ in range(3):
        image = [random.random() for _ in range(obserations)]
        label = [random.randint(0, 9)]

        topack = {
            "images": float_feature(image),
            "labels": int64_feature(label),
        }

        ofrecord_features = ofrecord.OFRecord(feature=topack)
        serilizedBytes = ofrecord_features.SerializeToString()

        # Derive the header from the bytes actually written so the header
        # and the payload can never disagree.
        length = len(serilizedBytes)

        # "<q" pins the header to exactly 8 bytes, little-endian, independent
        # of the writer's native byte order.
        f.write(struct.pack("<q", length))
        f.write(serilizedBytes)

print("Done!")