Example #1
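Merges num_shards_to_merge temporary shards named 0.npz, 1.npz, ... from out_tmp_dir, concatenates and jointly shuffles all input and target arrays (including selfBonusScoreN), and writes them out in whole batches of batch_size rows to a ZLIB-compressed TFRecord, plus a sidecar .json recording num_rows and num_batches. It relies on module-level definitions of keys, joint_shuffle, tfrecordio, and the TFRecord writer classes.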
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size):
  #print("Merging shards for output file: %s (%d shards to merge)" % (filename,num_shards_to_merge))
  tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
  record_writer = TFRecordWriter(filename,tfoptions)

  binaryInputNCHWPackeds = []
  globalInputNCs = []
  policyTargetsNCMoves = []
  globalTargetsNCs = []
  scoreDistrNs = []
  selfBonusScoreNs = []
  valueTargetsNCHWs = []

  for input_idx in range(num_shards_to_merge):
    shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
    #print("Merge loading shard: %d (mem usage %dMB)" % (input_idx,memusage_mb()))

    npz = np.load(shard_filename)
    assert(set(npz.keys()) == set(keys))

    binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
    globalInputNC = npz["globalInputNC"]
    policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
    globalTargetsNC = npz["globalTargetsNC"]
    scoreDistrN = npz["scoreDistrN"].astype(np.float32)
    selfBonusScoreN = npz["selfBonusScoreN"].astype(np.float32)
    valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

    binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
    globalInputNCs.append(globalInputNC)
    policyTargetsNCMoves.append(policyTargetsNCMove)
    globalTargetsNCs.append(globalTargetsNC)
    scoreDistrNs.append(scoreDistrN)
    selfBonusScoreNs.append(selfBonusScoreN)
    valueTargetsNCHWs.append(valueTargetsNCHW)

  ###
  #WARNING - if adding anything here, also add it to joint_shuffle below!
  ###
  #print("Merge concatenating... (mem usage %dMB)" % memusage_mb())
  binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
  globalInputNC = np.concatenate(globalInputNCs)
  policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
  globalTargetsNC = np.concatenate(globalTargetsNCs)
  scoreDistrN = np.concatenate(scoreDistrNs)
  selfBonusScoreN = np.concatenate(selfBonusScoreNs)
  valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

  #print("Merge shuffling... (mem usage %dMB)" % memusage_mb())
  joint_shuffle((binaryInputNCHWPacked,globalInputNC,policyTargetsNCMove,globalTargetsNC,scoreDistrN,selfBonusScoreN,valueTargetsNCHW))

  #print("Merge writing in batches...")
  num_rows = binaryInputNCHWPacked.shape[0]
  #Just truncate and lose the batch at the end, it's fine
  num_batches = num_rows // batch_size
  for i in range(num_batches):
    start = i*batch_size
    stop = (i+1)*batch_size

    example = tfrecordio.make_tf_record_example(
      binaryInputNCHWPacked,
      globalInputNC,
      policyTargetsNCMove,
      globalTargetsNC,
      scoreDistrN,
      selfBonusScoreN,
      valueTargetsNCHW,
      start,
      stop
    )
    record_writer.write(example.SerializeToString())

  jsonfilename = os.path.splitext(filename)[0] + ".json"
  with open(jsonfilename,"w") as f:
    json.dump({"num_rows":num_rows,"num_batches":num_batches},f)

  #print("Merge done %s (%d rows)" % (filename, num_batches * batch_size))

  record_writer.close()
  return num_batches * batch_size
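A minimal sketch of how this version might be called, assuming the module-level names it uses (keys, joint_shuffle, tfrecordio, and the TFRecord writer classes) are already imported; all paths and sizes below are illustrative, not from the source:

rows_written = merge_shards(
    filename="shuffled/train0.tfrecord",  # ZLIB-compressed TFRecord output (hypothetical path)
    num_shards_to_merge=16,               # expects 0.npz .. 15.npz in out_tmp_dir
    out_tmp_dir="shuffled/tmp",           # directory holding the temporary shards
    batch_size=256,                       # rows per serialized tf.train.Example
)
print("kept %d rows after truncating to whole batches" % rows_written)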
Example #2
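Same merge as above, minus the selfBonusScoreN target. Each shard is loaded inside a with-block so the .npz file handle is closed promptly, and the batch count is rounded down to a multiple of ensure_batch_multiple before writing.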
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple):
    tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
    record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []

    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert (set(npz.keys()) == set(keys))

            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

            binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
            globalInputNCs.append(globalInputNC)
            policyTargetsNCMoves.append(policyTargetsNCMove)
            globalTargetsNCs.append(globalTargetsNC)
            scoreDistrNs.append(scoreDistrN)
            valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    #WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    joint_shuffle((binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                   globalTargetsNC, scoreDistrN, valueTargetsNCHW))

    num_rows = binaryInputNCHWPacked.shape[0]
    #Just truncate and lose the batch at the end, it's fine
    num_batches = (
        num_rows //
        (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    for i in range(num_batches):
        start = i * batch_size
        stop = (i + 1) * batch_size

        example = tfrecordio.make_tf_record_example(
            binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
            globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
        record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    record_writer.close()
    return num_batches * batch_size
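The batch-count computation is the main subtle difference from Example #1: rounding down to a multiple of ensure_batch_multiple can drop more than one trailing batch. A quick check with made-up numbers:

num_rows = 10000
batch_size = 256
ensure_batch_multiple = 4
num_batches = (num_rows // (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
assert num_batches == 36                 # 10000 // 1024 = 9, then 9 * 4 = 36
assert num_batches * batch_size == 9216  # rows kept; the remaining 784 rows are dropped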
Example #3
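Adds an output_npz switch: instead of TFRecord batches, the merged and shuffled arrays can be written as a single compressed .npz. It also reseeds NumPy from os.urandom, which keeps forked worker processes from inheriting identical RNG state, and replaces the in-place joint_shuffle with joint_shuffle_take_first_n, with shape assertions before and after the shuffle.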
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple, output_npz):
    np.random.seed(
        [int.from_bytes(os.urandom(4), byteorder='little') for i in range(5)])

    if output_npz:
        record_writer = None
    else:
        tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
        record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []

    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert (set(npz.keys()) == set(keys))

            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

            binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
            globalInputNCs.append(globalInputNC)
            policyTargetsNCMoves.append(policyTargetsNCMove)
            globalTargetsNCs.append(globalTargetsNC)
            scoreDistrNs.append(scoreDistrN)
            valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    #WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    num_rows = binaryInputNCHWPacked.shape[0]
    assert (globalInputNC.shape[0] == num_rows)
    assert (policyTargetsNCMove.shape[0] == num_rows)
    assert (globalTargetsNC.shape[0] == num_rows)
    assert (scoreDistrN.shape[0] == num_rows)
    assert (valueTargetsNCHW.shape[0] == num_rows)

    [
        binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
        globalTargetsNC, scoreDistrN, valueTargetsNCHW
    ] = (joint_shuffle_take_first_n(num_rows, [
        binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
        globalTargetsNC, scoreDistrN, valueTargetsNCHW
    ]))

    assert (binaryInputNCHWPacked.shape[0] == num_rows)
    assert (globalInputNC.shape[0] == num_rows)
    assert (policyTargetsNCMove.shape[0] == num_rows)
    assert (globalTargetsNC.shape[0] == num_rows)
    assert (scoreDistrN.shape[0] == num_rows)
    assert (valueTargetsNCHW.shape[0] == num_rows)

    #Just truncate and lose the batch at the end, it's fine
    num_batches = (
        num_rows //
        (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    if output_npz:
        start = 0
        stop = num_batches * batch_size
        np.savez_compressed(
            filename,
            binaryInputNCHWPacked=binaryInputNCHWPacked[start:stop],
            globalInputNC=globalInputNC[start:stop],
            policyTargetsNCMove=policyTargetsNCMove[start:stop],
            globalTargetsNC=globalTargetsNC[start:stop],
            scoreDistrN=scoreDistrN[start:stop],
            valueTargetsNCHW=valueTargetsNCHW[start:stop])
    else:
        for i in range(num_batches):
            start = i * batch_size
            stop = (i + 1) * batch_size

            example = tfrecordio.make_tf_record_example(
                binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
            record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    if record_writer is not None:
        record_writer.close()
    return num_batches * batch_size
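A sketch of calling this version in .npz mode, again assuming the module-level helpers (keys, joint_shuffle_take_first_n, tfrecordio) are in scope; the argument values are illustrative only:

rows_written = merge_shards(
    filename="shuffled/val.npz",   # target for np.savez_compressed (hypothetical path)
    num_shards_to_merge=8,         # expects 0.npz .. 7.npz in out_tmp_dir
    out_tmp_dir="shuffled/tmp",
    batch_size=256,
    ensure_batch_multiple=1,       # no rounding beyond whole batches
    output_npz=True,               # write one compressed .npz instead of a TFRecord
)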