import json
import os

import numpy as np
from tensorflow.io import TFRecordOptions, TFRecordWriter


def write_tfr_batches(data, label, batch_size, num_batches, savepath, dataset_type):
    for batch in range(num_batches):
        start = batch * batch_size
        # The final batch absorbs any leftover examples beyond a full batch.
        if batch != num_batches - 1:
            next_start = (batch + 1) * batch_size
        else:
            next_start = len(data)
        filename = '{}_0{}.tfrecord'.format(dataset_type, batch)
        filepath = os.path.join(savepath, filename)
        # Open the writer directly on the path; the context manager flushes
        # and closes it once the batch has been written.
        with TFRecordWriter(filepath) as writer:
            for i in range(start, next_start):
                record = sequence_to_tfexample(sequence=data[i], sentiment=label[i])
                writer.write(record.SerializeToString())
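# `sequence_to_tfexample` is defined elsewhere in this codebase. A minimal
# sketch, assuming each `sequence` is a list of integer token ids and
# `sentiment` is a scalar integer label (hypothetical, for illustration only):
import tensorflow as tf

def sequence_to_tfexample(sequence, sentiment):
    # Pack one variable-length id sequence and its label into a tf.train.Example.
    return tf.train.Example(features=tf.train.Features(feature={
        'sequence': tf.train.Feature(int64_list=tf.train.Int64List(value=sequence)),
        'sentiment': tf.train.Feature(int64_list=tf.train.Int64List(value=[sentiment])),
    }))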
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size, ensure_batch_multiple):
    # Reseed from the OS so parallel shuffle workers don't share a permutation.
    np.random.seed([int.from_bytes(os.urandom(4), byteorder='little') for i in range(5)])
    tfoptions = TFRecordOptions(compression_type='ZLIB')
    record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []

    # Load every shard and collect its arrays.
    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert(set(npz.keys()) == set(keys))
            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)
        binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
        globalInputNCs.append(globalInputNC)
        policyTargetsNCMoves.append(policyTargetsNCMove)
        globalTargetsNCs.append(globalTargetsNC)
        scoreDistrNs.append(scoreDistrN)
        valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    # WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    num_rows = binaryInputNCHWPacked.shape[0]
    assert(globalInputNC.shape[0] == num_rows)
    assert(policyTargetsNCMove.shape[0] == num_rows)
    assert(globalTargetsNC.shape[0] == num_rows)
    assert(scoreDistrN.shape[0] == num_rows)
    assert(valueTargetsNCHW.shape[0] == num_rows)

    # Shuffle all six arrays with one shared permutation so rows stay aligned.
    [binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
     globalTargetsNC, scoreDistrN, valueTargetsNCHW] = joint_shuffle_take_first_n(
        num_rows,
        [binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
         globalTargetsNC, scoreDistrN, valueTargetsNCHW]
    )

    assert(binaryInputNCHWPacked.shape[0] == num_rows)
    assert(globalInputNC.shape[0] == num_rows)
    assert(policyTargetsNCMove.shape[0] == num_rows)
    assert(globalTargetsNC.shape[0] == num_rows)
    assert(scoreDistrN.shape[0] == num_rows)
    assert(valueTargetsNCHW.shape[0] == num_rows)

    # Just truncate and lose the batch at the end, it's fine. Rounding down to
    # a multiple of ensure_batch_multiple keeps the batch count divisible.
    num_batches = (num_rows // (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    for i in range(num_batches):
        start = i * batch_size
        stop = (i + 1) * batch_size
        example = tfrecordio.make_tf_record_example(
            binaryInputNCHWPacked,
            globalInputNC,
            policyTargetsNCMove,
            globalTargetsNC,
            scoreDistrN,
            valueTargetsNCHW,
            start,
            stop
        )
        record_writer.write(example.SerializeToString())

    # Record the row/batch counts alongside the tfrecord for downstream readers.
    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    record_writer.close()
    return num_batches * batch_size
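# The module-level `keys` list and the `joint_shuffle_take_first_n` and
# `tfrecordio.make_tf_record_example` helpers are defined elsewhere in the
# surrounding project. A minimal sketch of the first two, reconstructed from
# how merge_shards uses them (an assumption, not the project's actual code):

# Expected array names in each .npz shard, inferred from the loads above.
keys = [
    "binaryInputNCHWPacked",
    "globalInputNC",
    "policyTargetsNCMove",
    "globalTargetsNC",
    "scoreDistrN",
    "valueTargetsNCHW",
]

def joint_shuffle_take_first_n(n, arrs):
    # Draw one permutation and apply it to every array so that corresponding
    # rows (inputs and their targets) stay aligned, then keep the first n rows.
    perm = np.random.permutation(arrs[0].shape[0])[:n]
    return [arr[perm] for arr in arrs]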