Example #1
import os

from tensorflow.io import TFRecordWriter  # tf.io.TFRecordWriter


def write_tfr_batches(data, label, batch_size, num_batches, savepath, dataset_type):
    """Split (data, label) into num_batches TFRecord shards under savepath."""
    for batch in range(num_batches):
        # Rows [start, next_start) go into this shard; the last shard also
        # takes whatever remainder is left at the end of the data.
        start = batch * batch_size
        if batch != num_batches - 1:
            next_start = (batch + 1) * batch_size
        else:
            next_start = len(data)

        # Zero-pad the shard index so filenames sort correctly past 10 shards.
        filename = '{}_{:02d}.tfrecord'.format(dataset_type, batch)
        filepath = os.path.join(savepath, filename)

        # The context manager flushes and closes the writer for each shard.
        with TFRecordWriter(filepath) as writer:
            for i in range(start, next_start):
                # sequence_to_tfexample is defined elsewhere in the module and
                # builds a tf.train.Example from a single row and its label.
                record = sequence_to_tfexample(sequence=data[i], sentiment=label[i])
                writer.write(record.SerializeToString())
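
A minimal way to exercise write_tfr_batches, assuming a toy sequence_to_tfexample that packs an integer sequence and its sentiment label into a tf.train.Example (the project's real helper defines its own feature layout):

import numpy as np
import tensorflow as tf

def sequence_to_tfexample(sequence, sentiment):
    # Hypothetical stand-in for the project's helper: one int64-list feature
    # for the token sequence and one int64 feature for the label.
    return tf.train.Example(features=tf.train.Features(feature={
        'sequence': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(x) for x in sequence])),
        'sentiment': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(sentiment)])),
    }))

data = [np.random.randint(0, 1000, size=20) for _ in range(100)]
label = np.random.randint(0, 2, size=100)
write_tfr_batches(data, label, batch_size=32, num_batches=4, savepath='.', dataset_type='train')

This writes train_00.tfrecord through train_03.tfrecord into the current directory, with the last shard holding the remaining four rows.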
Example #2
import json
import os

import numpy as np
from tensorflow.io import TFRecordOptions, TFRecordWriter

# `keys`, `joint_shuffle_take_first_n`, and the `tfrecordio` helper module are
# assumed to be provided elsewhere in the surrounding training package.

def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size, ensure_batch_multiple):
  # Seed numpy's RNG from OS entropy so each merge uses a fresh shuffle order.
  np.random.seed([int.from_bytes(os.urandom(4), byteorder='little') for i in range(5)])

  # All merged rows are written to a single ZLIB-compressed TFRecord file.
  tfoptions = TFRecordOptions(compression_type='ZLIB')
  record_writer = TFRecordWriter(filename, tfoptions)

  # Per-array accumulators, one entry appended per input shard.
  binaryInputNCHWPackeds = []
  globalInputNCs = []
  policyTargetsNCMoves = []
  globalTargetsNCs = []
  scoreDistrNs = []
  valueTargetsNCHWs = []

  for input_idx in range(num_shards_to_merge):
    shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
    with np.load(shard_filename) as npz:
      # Every shard must contain exactly the arrays named in the module-level `keys` list.
      assert(set(npz.keys()) == set(keys))

      binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
      globalInputNC = npz["globalInputNC"]
      policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
      globalTargetsNC = npz["globalTargetsNC"]
      scoreDistrN = npz["scoreDistrN"].astype(np.float32)
      valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

      binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
      globalInputNCs.append(globalInputNC)
      policyTargetsNCMoves.append(policyTargetsNCMove)
      globalTargetsNCs.append(globalTargetsNC)
      scoreDistrNs.append(scoreDistrN)
      valueTargetsNCHWs.append(valueTargetsNCHW)

  ###
  #WARNING - if adding anything here, also add it to joint_shuffle below!
  ###
  binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
  globalInputNC = np.concatenate(globalInputNCs)
  policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
  globalTargetsNC = np.concatenate(globalTargetsNCs)
  scoreDistrN = np.concatenate(scoreDistrNs)
  valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

  # All arrays must agree on the number of rows.
  num_rows = binaryInputNCHWPacked.shape[0]
  assert(globalInputNC.shape[0] == num_rows)
  assert(policyTargetsNCMove.shape[0] == num_rows)
  assert(globalTargetsNC.shape[0] == num_rows)
  assert(scoreDistrN.shape[0] == num_rows)
  assert(valueTargetsNCHW.shape[0] == num_rows)

  # Shuffle all six arrays with one shared permutation so rows stay aligned,
  # keeping all num_rows rows.
  [binaryInputNCHWPacked,globalInputNC,policyTargetsNCMove,globalTargetsNC,scoreDistrN,valueTargetsNCHW] = (
    joint_shuffle_take_first_n(num_rows,[binaryInputNCHWPacked,globalInputNC,policyTargetsNCMove,globalTargetsNC,scoreDistrN,valueTargetsNCHW])
  )

  assert(binaryInputNCHWPacked.shape[0] == num_rows)
  assert(globalInputNC.shape[0] == num_rows)
  assert(policyTargetsNCMove.shape[0] == num_rows)
  assert(globalTargetsNC.shape[0] == num_rows)
  assert(scoreDistrN.shape[0] == num_rows)
  assert(valueTargetsNCHW.shape[0] == num_rows)

  # Round the batch count down to a multiple of ensure_batch_multiple; any
  # leftover rows at the end are simply truncated, which is fine here.
  num_batches = (num_rows // (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
  for i in range(num_batches):
    start = i*batch_size
    stop = (i+1)*batch_size

    # Pack rows [start, stop) of every array into one serialized tf.train.Example.
    example = tfrecordio.make_tf_record_example(
      binaryInputNCHWPacked,
      globalInputNC,
      policyTargetsNCMove,
      globalTargetsNC,
      scoreDistrN,
      valueTargetsNCHW,
      start,
      stop
    )
    record_writer.write(example.SerializeToString())

  # Write sidecar metadata recording how many rows and batches this file holds.
  jsonfilename = os.path.splitext(filename)[0] + ".json"
  with open(jsonfilename, "w") as f:
    json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

  record_writer.close()
  # Number of rows actually written, after truncation to whole batches.
  return num_batches * batch_size
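
The joint_shuffle_take_first_n helper is not shown on this page. A minimal sketch of what such a joint shuffle can look like, assuming it applies one shared permutation to every array so rows stay aligned and then keeps the first n rows of each (the real implementation may shuffle in place to save memory):

import numpy as np

def joint_shuffle_take_first_n(n, arrays):
    # Draw one permutation and apply it to every parallel array, then keep
    # only the first n rows of each; a shared permutation keeps row i of one
    # array matched with row i of all the others.
    num_rows = arrays[0].shape[0]
    perm = np.random.permutation(num_rows)[:n]
    return [arr[perm] for arr in arrays]

Since the file is written with ZLIB compression, it can be read back with tf.data.TFRecordDataset(filename, compression_type='ZLIB').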