def save_tfrecords(save_dir, train_list, eval_list, test_list, idx):
    """Write pre-serialized examples to train/test/eval TFRecord files."""
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_train_.tfrecords")) as writer:
        for e in train_list:
            writer.write(e)
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_test_.tfrecords")) as writer:
        for e in test_list:
            writer.write(e)
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_eval_.tfrecords")) as writer:
        for e in eval_list:
            writer.write(e)
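# Example usage of save_tfrecords (a minimal sketch; `make_example` is a
# hypothetical helper, not part of the original code, that wraps one float
# value in a serialized tf.train.Example):
def make_example(value):
    return tf.train.Example(features=tf.train.Features(feature={
        'value': tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    })).SerializeToString()

# save_tfrecords('out', [make_example(v) for v in train_vals],
#                [make_example(v) for v in eval_vals],
#                [make_example(v) for v in test_vals], idx=0)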
def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
    """Opens all TFRecord shards for writing and adds them to an exit stack.

    Note: Copied directly from
    https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/tf_record_creation_util.py

    Parameters
    ----------
    exit_stack: contextlib.ExitStack
        An exit stack used to automatically close the TFRecords opened in
        this function.
    base_path: str
        The base path for all shards.
    num_shards: int
        The number of shards.

    Returns
    -------
    tfrecords: list(TFRecordWriter)
        The list of opened TFRecord writers. Position k in the list
        corresponds to shard k.
    """
    tf_record_output_filenames = [
        '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
        for idx in range(num_shards)
    ]
    tfrecords = [
        exit_stack.enter_context(TFRecordWriter(file_name))
        for file_name in tf_record_output_filenames
    ]
    return tfrecords
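# Example usage (a minimal sketch; assumes `examples` is an iterable of
# already-built tf.train.Example protos):
import contextlib

with contextlib.ExitStack() as stack:
    writers = open_sharded_output_tfrecords(stack, 'out/data.record', 10)
    for index, example in enumerate(examples):
        writers[index % 10].write(example.SerializeToString())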
def build_with_feature(output_file, feature_generator):
    """Build a TFRecord dataset from a feature generator.

    output_file: path where the output TFRecord file should be written
    feature_generator: a generator that yields dicts mapping string keys
        to tf.train.Feature values
    """
    cnt = 0
    with TFRecordWriter(output_file) as writer:
        for feature in feature_generator:
            sample = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(sample.SerializeToString())
            cnt += 1
    return cnt  # number of examples written
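# A hypothetical feature generator for build_with_feature; the field names
# ('image', 'label') are illustrative, not from the original code:
def image_feature_generator(images, labels):
    for image, label in zip(images, labels):
        yield {
            'image': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image.tobytes()])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[int(label)])),
        }

# n_written = build_with_feature('data.tfrecords',
#                                image_feature_generator(images, labels))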
def _process_images(self, name, images, labels, id_file, n_files):
    """Processes and saves a list of images as one TFRecord shard in 1 thread.

    Args:
        name: string, unique identifier specifying the data set
        images: array of images
        labels: array of labels
        id_file: int, index of this shard
        n_files: int, total number of shards
    """
    output_filename = '{}-{:05d}-of-{:05d}'.format(name, id_file, n_files)
    output_file = os.path.join(FLAGS.output_dir, self.name, output_filename)
    with TFRecordWriter(output_file) as writer:
        for image, label in zip(images, labels):
            example = self._convert_to_example(image.tobytes(), label)
            writer.write(example.SerializeToString())
    print('{}: Wrote {} images to {}'.format(
        datetime.now(), len(images), output_file), flush=True)
def write_tfrecords(path, tf_examples, num_shards=TF_RECORD_SHARDS):
    if NP_RANDOM_SEED is not None:
        np.random.seed(NP_RANDOM_SEED)
    np.random.shuffle(tf_examples)
    tf_examples = [example.SerializeToString() for example in tf_examples]
    num_per_shard = int(math.ceil(len(tf_examples) / float(num_shards)))
    total_shards = int(math.ceil(len(tf_examples) / float(num_per_shard)))
    makedirs(path, exist_ok=True)
    for shard_no in range(total_shards):
        start = shard_no * num_per_shard
        end = min((shard_no + 1) * num_per_shard, len(tf_examples))
        file_name = "{0}_of_{1}.tfrecord".format(shard_no + 1, total_shards)
        file_path = join(path, file_name)
        with TFRecordWriter(file_path) as tf_writer:
            for serialized_example in tf_examples[start:end]:
                tf_writer.write(serialized_example)
        if end == len(tf_examples):
            break
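# Reading the shards back (a minimal sketch; assumes TensorFlow 2.x and the
# file naming produced by write_tfrecords above):
import glob

def read_tfrecords(path):
    files = sorted(glob.glob(join(path, "*.tfrecord")))
    return tf.data.TFRecordDataset(files)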
def main(_):
    writer = TFRecordWriter(FLAGS.output_path)
    path = os.path.join(os.getcwd(), FLAGS.img_path)
    print(FLAGS.csv_input)
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))
def write_to_tfrecords(adj, feature, label_data, label_mask, tfrname):
    """Writes graph-related data to disk."""
    adj_row, adj_col = np.nonzero(adj)
    adj_values = adj[adj_row, adj_col]
    adj_elem_len = len(adj_row)
    feature = np.array(feature)
    feature_row, feature_col = np.nonzero(feature)
    feature_values = feature[feature_row, feature_col]
    feature_elem_len = len(feature_row)
    features = Features(feature={
        'label': Feature(int64_list=Int64List(value=label_data)),
        'mask_label': Feature(int64_list=Int64List(value=label_mask)),
        'adj_row': Feature(int64_list=Int64List(value=list(adj_row))),
        'adj_column': Feature(int64_list=Int64List(value=list(adj_col))),
        'adj_values': Feature(float_list=FloatList(value=list(adj_values))),
        'adj_elem_len': Feature(int64_list=Int64List(value=[adj_elem_len])),
        'feature_row': Feature(int64_list=Int64List(value=list(feature_row))),
        'feature_column': Feature(int64_list=Int64List(value=list(feature_col))),
        'feature_values': Feature(float_list=FloatList(value=list(feature_values))),
        'feature_elem_len': Feature(int64_list=Int64List(value=[feature_elem_len])),
        'size': Feature(int64_list=Int64List(value=list(feature.shape))),
    })
    ex = Example(features=features)
    with TFRecordWriter(tfrname) as single_writer:
        single_writer.write(ex.SerializeToString())
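# A hedged sketch of the inverse operation: rebuilding the dense adjacency
# and feature matrices from one record (assumes TensorFlow 2.x eager mode
# and the exact schema written above; the helper name is hypothetical):
def read_graph_tfrecord(tfrname):
    raw = next(iter(tf.data.TFRecordDataset(tfrname)))
    ex = tf.train.Example.FromString(raw.numpy())
    f = ex.features.feature
    size = list(f['size'].int64_list.value)  # [num_atoms, feature_dim]
    feature = np.zeros(size)
    feature[list(f['feature_row'].int64_list.value),
            list(f['feature_column'].int64_list.value)] = \
        list(f['feature_values'].float_list.value)
    adj = np.zeros((size[0], size[0]))  # adjacency is square in num_atoms
    adj[list(f['adj_row'].int64_list.value),
        list(f['adj_column'].int64_list.value)] = \
        list(f['adj_values'].float_list.value)
    return adj, feature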
def write_examples_as_tfrecord(examples, output_filebase, example_encoder,
                               num_shards=1):
    """Serialize examples as a TFRecord dataset.

    Note: Adapted from
    https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/using_your_own_dataset.md

    Parameters
    ----------
    examples: list(dict-like)
        A list of key/value maps; each map contains the relevant info for a
        single data example.
    output_filebase: str
        The base path for all shards.
    example_encoder: func
        A function that encodes an input example as a tf.Example.
    num_shards: int
        The number of shards to divide the examples among. If > 1, multiple
        tfrecord files will be created with names appended with a shard index.
    """
    if num_shards == 1:
        writer = TFRecordWriter(output_filebase)
        for example in tqdm(examples):
            tf_example = example_encoder(example)
            writer.write(tf_example.SerializeToString())
        writer.close()
    else:
        with contextlib.ExitStack() as tf_record_close_stack:
            output_tfrecords = open_sharded_output_tfrecords(
                tf_record_close_stack, output_filebase, num_shards)
            for index, example in tqdm(enumerate(examples), total=len(examples)):
                tf_example = example_encoder(example)
                output_shard_index = index % num_shards
                output_tfrecords[output_shard_index].write(
                    tf_example.SerializeToString())
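# A hypothetical example_encoder for write_examples_as_tfrecord; the keys
# ('text', 'label') are illustrative, not from the original code:
def encode_example(example):
    return tf.train.Example(features=tf.train.Features(feature={
        'text': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[example['text'].encode('utf-8')])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(
            value=[example['label']])),
    }))

# write_examples_as_tfrecord(examples, 'data/out.record', encode_example,
#                            num_shards=4)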
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple):
    tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
    record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []
    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert set(npz.keys()) == set(keys)
            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)
        binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
        globalInputNCs.append(globalInputNC)
        policyTargetsNCMoves.append(policyTargetsNCMove)
        globalTargetsNCs.append(globalTargetsNC)
        scoreDistrNs.append(scoreDistrN)
        valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    # WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)
    joint_shuffle((binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                   globalTargetsNC, scoreDistrN, valueTargetsNCHW))

    num_rows = binaryInputNCHWPacked.shape[0]
    # Just truncate and lose the batch at the end, it's fine
    num_batches = (num_rows //
                   (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    for i in range(num_batches):
        start = i * batch_size
        stop = (i + 1) * batch_size
        example = tfrecordio.make_tf_record_example(
            binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
            globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
        record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)
    record_writer.close()
    return num_batches * batch_size
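# Reading the ZLIB-compressed records back (a minimal sketch; in TF 2.x the
# compression type is passed as a string rather than via TFRecordOptions):
def read_merged_records(filename):
    return tf.data.TFRecordDataset(filename, compression_type="ZLIB")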
def random_n_fold_crossval_tfrec_format(n_folds, output_location,
                                        file_list=['BIO_format.txt'],
                                        base_name='data',
                                        downsample=True,
                                        downsampling_rate=0.5):
    """Create TensorFlow Records from BIO format.

    For an n-fold cross validation, n tfrecs are created. The creation is
    based on counting the number of samples in the file and randomly sorting
    them into n buckets. A training and a test version are created for each
    fold. Downsampling is applied as a very simple batch control on the
    training version.

    Arguments:
        n_folds {int} -- number of folds for cross validation
        output_location {str} -- path to data

    Keyword Arguments:
        file_list {list} -- BIO files (default: {['BIO_format.txt']})
        base_name {str} -- base name for reading and writing (default: {'data'})
        downsample {bool} -- apply downsampling to corpus (default: {True})
        downsampling_rate {float} -- rate for downsampling empty data (default: {0.5})
    """
    random.seed(1)
    sample_count, word_count = count_samples(file_list, output_location)
    print("Creating TFRecord for {}-fold random cross validation.".format(n_folds))
    print("Found {} input samples (sentences) and {} words.".format(
        sample_count, word_count))
    subdir = (str(n_folds) + '_cval_random_d_' + str(downsample) + "_" +
              str(downsampling_rate))
    train_path = join(output_location, subdir, 'records_train')
    train_path_labels = join(output_location, subdir, 'labels_train')
    test_path = join(output_location, subdir, 'records_test')
    test_path_labels = join(output_location, subdir, 'labels_test')
    make_or_clean_if_exists(train_path, train_path_labels, test_path,
                            test_path_labels)
    with Path(join(output_location, (base_name + '.words.txt'))).open() as f:
        word_to_idx = {line.strip(): idx for idx, line in enumerate(f)}
    with Path(join(output_location, (base_name + '.tag.txt'))).open() as f:
        label_mapping = {line.strip(): idx for idx, line in enumerate(f)}
    with Path(join(output_location, (base_name + '.chars.txt'))).open() as f:
        char_mapping = {line.strip(): idx for idx, line in enumerate(f)}
    tfwriter_set_train = []
    correct_labels_train = []
    sentences_to_labels_train = []
    words_to_labels_train = []
    tfwriter_set_test = []
    correct_labels_test = []
    sentences_to_labels_test = []
    words_to_labels_test = []
    for idx in range(n_folds):
        record_name = base_name + "_" + str(idx + 1) + ".tfrec"
        tfwriter_set_train.append(TFRecordWriter(join(train_path, record_name)))
        correct_labels_train.append([])
        sentences_to_labels_train.append([])
        words_to_labels_train.append([])
        tfwriter_set_test.append(TFRecordWriter(join(test_path, record_name)))
        correct_labels_test.append([])
        sentences_to_labels_test.append([])
        words_to_labels_test.append([])
    plain_words = []
    for f in file_list:
        print("Working on file: " + join(output_location, f))
        with open(join(output_location, f), 'r') as datafile:
            indices = []
            labels = []
            words = []
            high_low = []
            length = []
            for line in datafile:
                if line in ['\n', '\r\n']:
                    # End of sentence: build a SequenceExample and assign it
                    # to a random fold.
                    index_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[ind]))
                        for ind in indices
                    ])
                    label_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[lab]))
                        for lab in labels
                    ])
                    word_list = train.FeatureList(feature=[
                        train.Feature(bytes_list=train.BytesList(
                            value=[wor.encode()])) for wor in words
                    ])
                    upper_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[hl]))
                        for hl in high_low
                    ])
                    len_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[le]))
                        for le in length
                    ])
                    sentences = train.FeatureLists(feature_list={
                        'indices': index_list,
                        'labels': label_list,
                        'plain': word_list,
                        'upper': upper_list,
                        'length': len_list
                    })
                    example = train.SequenceExample(feature_lists=sentences)
                    sentence_reconstructed = '_'.join(plain_words)
                    labels_in_sent = list(set(labels))
                    set_to_put = random.randint(0, n_folds - 1)
                    tfwriter_set_test[set_to_put].write(
                        example.SerializeToString())
                    correct_labels_test[set_to_put].extend(labels)
                    for i in range(len(words)):
                        sentences_to_labels_test[set_to_put].append(
                            sentence_reconstructed)
                    words_to_labels_test[set_to_put].extend(plain_words)
                    if (downsample and len(labels_in_sent) == 1
                            and labels_in_sent[0] == 5
                            and random.random() < downsampling_rate):
                        pass
                    else:
                        tfwriter_set_train[set_to_put].write(
                            example.SerializeToString())
                        correct_labels_train[set_to_put].extend(labels)
                        for i in range(len(words)):
                            sentences_to_labels_train[set_to_put].append(
                                sentence_reconstructed)
                        words_to_labels_train[set_to_put].extend(plain_words)
                    indices = []
                    labels = []
                    words = []
                    plain_words = []
                    high_low = []
                    length = []
                elif not line.startswith("-DOCSTART-") and '---' not in line:
                    info = line.split()
                    word = info[0]
                    characters = [char_mapping[x] for x in word]
                    words.append(" ".join(str(x) for x in characters))
                    labels.append(label_mapping[info[1]])
                    if info[0] not in word_to_idx:
                        indices.append(len(word_to_idx))
                        print(info[0])
                        raise RuntimeError(
                            "Tried to process a word that is not in the "
                            "prebuilt dictionary.")
                    else:
                        indices.append(word_to_idx[info[0]])
                    high_low.append(word.isupper())
                    length.append(len(word))
                    plain_words.append(word)
    for idx in range(n_folds):
        label_name = base_name + "_" + str(idx + 1) + ".tfrec" + '_labels'
        with open(join(train_path_labels, label_name), 'w') as cor_labels:
            print("In train " + str(idx + 1) + " are " +
                  str(len(correct_labels_train[idx])) + " samples (words).")
            if (len(correct_labels_train[idx]) != len(sentences_to_labels_train[idx])
                    or len(correct_labels_train[idx]) != len(words_to_labels_train[idx])):
                raise RuntimeError("Sentence length does not match labels.")
            for l, s, w in zip(correct_labels_train[idx],
                               sentences_to_labels_train[idx],
                               words_to_labels_train[idx]):
                cor_labels.write(str(l) + " " + s + " " + w + '\n')
        tfwriter_set_train[idx].close()
        with open(join(test_path_labels, label_name), 'w') as cor_labels:
            print("In test " + str(idx + 1) + " are " +
                  str(len(correct_labels_test[idx])) + " samples (words).")
            if (len(correct_labels_test[idx]) != len(sentences_to_labels_test[idx])
                    or len(correct_labels_test[idx]) != len(words_to_labels_test[idx])):
                raise RuntimeError("Sentence length does not match labels.")
            for l, s, w in zip(correct_labels_test[idx],
                               sentences_to_labels_test[idx],
                               words_to_labels_test[idx]):
                cor_labels.write(str(l) + " " + s + " " + w + '\n')
        tfwriter_set_test[idx].close()
    print("Done building the TFRecords.")
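# A hedged sketch of parsing one of these SequenceExamples back (assumes
# TensorFlow 2.x; the feature keys match the schema written above):
sequence_features = {
    'indices': tf.io.FixedLenSequenceFeature([], tf.int64),
    'labels': tf.io.FixedLenSequenceFeature([], tf.int64),
    'plain': tf.io.FixedLenSequenceFeature([], tf.string),
    'upper': tf.io.FixedLenSequenceFeature([], tf.int64),
    'length': tf.io.FixedLenSequenceFeature([], tf.int64),
}

def parse_sentence(serialized):
    _, parsed = tf.io.parse_single_sequence_example(
        serialized, sequence_features=sequence_features)
    return parsed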
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size):
    #print("Merging shards for output file: %s (%d shards to merge)" % (filename,num_shards_to_merge))
    tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
    record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    selfBonusScoreNs = []
    valueTargetsNCHWs = []

    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        #print("Merge loading shard: %d (mem usage %dMB)" % (input_idx,memusage_mb()))
        npz = np.load(shard_filename)
        assert set(npz.keys()) == set(keys)
        binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
        globalInputNC = npz["globalInputNC"]
        policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
        globalTargetsNC = npz["globalTargetsNC"]
        scoreDistrN = npz["scoreDistrN"].astype(np.float32)
        selfBonusScoreN = npz["selfBonusScoreN"].astype(np.float32)
        valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)
        binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
        globalInputNCs.append(globalInputNC)
        policyTargetsNCMoves.append(policyTargetsNCMove)
        globalTargetsNCs.append(globalTargetsNC)
        scoreDistrNs.append(scoreDistrN)
        selfBonusScoreNs.append(selfBonusScoreN)
        valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    # WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    #print("Merge concatenating... (mem usage %dMB)" % memusage_mb())
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    selfBonusScoreN = np.concatenate(selfBonusScoreNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    #print("Merge shuffling... (mem usage %dMB)" % memusage_mb())
    joint_shuffle((binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                   globalTargetsNC, scoreDistrN, selfBonusScoreN,
                   valueTargetsNCHW))

    #print("Merge writing in batches...")
    num_rows = binaryInputNCHWPacked.shape[0]
    # Just truncate and lose the batch at the end, it's fine
    num_batches = num_rows // batch_size
    for i in range(num_batches):
        start = i * batch_size
        stop = (i + 1) * batch_size
        example = tfrecordio.make_tf_record_example(
            binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
            globalTargetsNC, scoreDistrN, selfBonusScoreN, valueTargetsNCHW,
            start, stop)
        record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    #print("Merge done %s (%d rows)" % (filename, num_batches * batch_size))
    record_writer.close()
    return num_batches * batch_size
# Missing imports added for a runnable snippet; the Example construction,
# write, and error handling at the end complete the truncated original and
# are an assumption.
import glob
import os

import h5py
import tensorflow as tf
from IPython import embed
from tqdm import tqdm


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


input_roots = '/data/dataTrain/val_*/'
output_name = '/data/dataTrain/val.tfrecords'
writer = TFRecordWriter(output_name)
h5files = glob.glob(os.path.join(input_roots, '*.h5'))
for h5file in tqdm(h5files):
    try:
        data = h5py.File(h5file, 'r')
        for i in range(200):
            img = data['CameraRGB'][i]
            target = data['targets'][i]
            feature_dict = {
                'image': _bytes_feature(img.tostring()),
                'targets': _float_feature(target)
            }
            # The original snippet is truncated here; presumably each
            # feature_dict is wrapped in an Example and written:
            example = tf.train.Example(
                features=tf.train.Features(feature=feature_dict))
            writer.write(example.SerializeToString())
    except (IOError, KeyError):
        # Assumption: the bare `try` in the truncated original skipped
        # unreadable or incomplete files.
        continue
writer.close()
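# Reading the records back (a hedged sketch; assumes TF 2.x; the image shape
# is not stored in the record, so decoding it back to an array would need
# the shape used at write time):
feature_spec = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'targets': tf.io.VarLenFeature(tf.float32),
}

def parse_record(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)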
def main():
    args = get_parser()
    if args.use_deepchem_feature:
        args.degree_dim = 11
        args.use_sybyl = False
        args.use_electronegativity = False
        args.use_gasteiger = False
    adj_list = []
    feature_list = []
    label_data_list = []
    label_mask_list = []
    atom_num_list = []
    mol_name_list = []
    seq_symbol_list = None
    dragon_data_list = None
    task_name_list = None
    seq_list = None
    seq = None
    dragon_data = None
    profeat = None
    mol_list = []
    if args.solubility:
        args.sdf_label = "SOL_classification"
        args.sdf_label_active = "high"
        args.sdf_label_inactive = "low"
    if args.assay_dir is not None:
        (mol_obj_list, label_data, label_mask, dragon_data, task_name_list,
         seq, seq_symbol, profeat, publication_years) = extract_mol_info(args)
    else:
        (mol_obj_list, label_data, label_mask, _, task_name_list,
         _, _, _, publication_years) = extract_mol_info(args)
    if args.vector_modal is not None:
        dragon_data = build_vector_modal(args)
    # Automatically set atom_num_limit if it was not specified
    if args.atom_num_limit is None:
        args.atom_num_limit = 0
        for index, mol in enumerate(mol_obj_list):
            if mol is None:
                continue
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
            if args.atom_num_limit < mol.GetNumAtoms():
                args.atom_num_limit = mol.GetNumAtoms()
    if args.use_electronegativity:
        ELECTRONEGATIVITIES = [
            element(i).electronegativity('pauling') for i in range(1, 100)
        ]
        ELECTRONEGATIVITIES = [
            e if e is not None else 0 for e in ELECTRONEGATIVITIES
        ]
    for index, mol in enumerate(mol_obj_list):
        if mol is None:
            continue
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
        # Skip compounds whose total number of atoms exceeds atom_num_limit
        if args.atom_num_limit is not None and mol.GetNumAtoms() > args.atom_num_limit:
            continue
        # Get the molecule name
        try:
            name = mol.GetProp("_Name")
        except KeyError:
            name = "index_" + str(index)
        mol_list.append(mol)
        mol_name_list.append(name)
        adj = create_adjancy_matrix(mol)
        feature = (create_feature_matrix(mol, args)
                   if not args.use_electronegativity else
                   create_feature_matrix(mol, args, en_list=ELECTRONEGATIVITIES))
        if args.tfrecords:
            tfrname = os.path.join(args.output, name + '_.tfrecords')
            if args.csv_reaxys:
                if publication_years[index] < 2015:
                    name += "_train"
                else:
                    name += random.choice(["_test", "_eval"])
                tfrname = os.path.join(args.output,
                                       str(publication_years[index]),
                                       name + '_.tfrecords')
            pathlib.Path(os.path.dirname(tfrname)).mkdir(parents=True,
                                                         exist_ok=True)
            ex = convert_to_example(adj, feature, label_data[index],
                                    label_mask[index])
            with TFRecordWriter(tfrname) as single_writer:
                single_writer.write(ex.SerializeToString())
            continue
        atom_num_list.append(mol.GetNumAtoms())
        adj_list.append(dense_to_sparse(adj))
        feature_list.append(feature)
        # Create labels
        if args.sdf_label:
            line = mol.GetProp(args.sdf_label)
            if line.find(args.sdf_label_active) != -1:
                label_data_list.append([0, 1])
                label_mask_list.append([1, 1])
            elif line.find(args.sdf_label_inactive) != -1:
                label_data_list.append([1, 0])
                label_mask_list.append([1, 1])
            else:  # missing label
                print("[WARN] unknown label:", line)
                label_data_list.append([0, 0])
                label_mask_list.append([0, 0])
        else:
            label_data_list.append(label_data[index])
            label_mask_list.append(label_mask[index])
        if dragon_data is not None:
            if dragon_data_list is None:
                dragon_data_list = []
            dragon_data_list.append(dragon_data[index])
        if args.multimodal:
            if seq is not None:
                if seq_list is None:
                    seq_list, seq_symbol_list = [], []
                seq_list.append(seq[index])
                seq_symbol_list.append(seq[index])
    if args.tfrecords:
        with open(os.path.join(args.output, "tasks.txt"), "w") as f:
            f.write("\n".join(task_name_list))
        sys.exit(0)
    # joblib output
    obj = {"feature": np.asarray(feature_list), "adj": np.asarray(adj_list)}
    if not args.sparse_label:
        obj["label"] = np.asarray(label_data_list)
        obj["mask_label"] = np.asarray(label_mask_list)
    else:
        from scipy.sparse import csr_matrix
        label_data = np.asarray(label_data_list)
        label_mask = np.asarray(label_mask_list)
        if args.label_dim is None:
            obj['label_dim'] = label_data.shape[1]
        else:
            obj['label_dim'] = args.label_dim
        obj['label_sparse'] = csr_matrix(label_data.astype(float))
        obj['mask_label_sparse'] = csr_matrix(label_mask.astype(float))
    if task_name_list is not None:
        obj["task_names"] = np.asarray(task_name_list)
    if dragon_data_list is not None:
        obj["dragon"] = np.asarray(dragon_data_list)
    if profeat is not None:
        obj["profeat"] = np.asarray(profeat)
    obj["max_node_num"] = args.atom_num_limit
    mol_info = {"obj_list": mol_list, "name_list": mol_name_list}
    obj["mol_info"] = mol_info
    if not args.regression:
        label_int = np.argmax(label_data_list, axis=1)
        cw = class_weight.compute_class_weight("balanced",
                                               np.unique(label_int), label_int)
        obj["class_weight"] = cw
    if args.generate_mfp:
        from rdkit.Chem import AllChem
        mfps = []
        for mol in mol_list:
            FastFindRings(mol)
            mfp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            mfp_vec = np.array([mfp.GetBit(i) for i in range(2048)], np.int32)
            mfps.append(mfp_vec)
        obj["mfp"] = np.array(mfps)
    if args.multimodal:
        if seq is not None:
            if args.max_len_seq is not None:
                max_len_seq = args.max_len_seq
            else:
                max_len_seq = max(map(len, seq_list))
            print("max_len_seq:", max_len_seq)
            seq_mat = np.zeros((len(seq_list), max_len_seq), np.int32)
            for i, s in enumerate(seq_list):
                seq_mat[i, 0:len(s)] = s
            obj["sequence"] = seq_mat
            obj["sequence_symbol"] = seq_symbol_list
            obj["sequence_length"] = list(map(len, seq_list))
            obj["sequence_symbol_num"] = int(np.max(seq_mat) + 1)
    filename = args.output
    print("[SAVE] " + filename)
    joblib.dump(obj, filename, compress=3)
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple, output_npz):
    np.random.seed(
        [int.from_bytes(os.urandom(4), byteorder='little') for i in range(5)])

    if output_npz:
        record_writer = None
    else:
        tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
        record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []
    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert set(npz.keys()) == set(keys)
            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)
        binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
        globalInputNCs.append(globalInputNC)
        policyTargetsNCMoves.append(policyTargetsNCMove)
        globalTargetsNCs.append(globalTargetsNC)
        scoreDistrNs.append(scoreDistrN)
        valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    # WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    num_rows = binaryInputNCHWPacked.shape[0]
    assert globalInputNC.shape[0] == num_rows
    assert policyTargetsNCMove.shape[0] == num_rows
    assert globalTargetsNC.shape[0] == num_rows
    assert scoreDistrN.shape[0] == num_rows
    assert valueTargetsNCHW.shape[0] == num_rows

    [binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
     globalTargetsNC, scoreDistrN, valueTargetsNCHW] = joint_shuffle_take_first_n(
         num_rows,
         [binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
          globalTargetsNC, scoreDistrN, valueTargetsNCHW])

    assert binaryInputNCHWPacked.shape[0] == num_rows
    assert globalInputNC.shape[0] == num_rows
    assert policyTargetsNCMove.shape[0] == num_rows
    assert globalTargetsNC.shape[0] == num_rows
    assert scoreDistrN.shape[0] == num_rows
    assert valueTargetsNCHW.shape[0] == num_rows

    # Just truncate and lose the batch at the end, it's fine
    num_batches = (num_rows //
                   (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    if output_npz:
        start = 0
        stop = num_batches * batch_size
        np.savez_compressed(
            filename,
            binaryInputNCHWPacked=binaryInputNCHWPacked[start:stop],
            globalInputNC=globalInputNC[start:stop],
            policyTargetsNCMove=policyTargetsNCMove[start:stop],
            globalTargetsNC=globalTargetsNC[start:stop],
            scoreDistrN=scoreDistrN[start:stop],
            valueTargetsNCHW=valueTargetsNCHW[start:stop])
    else:
        for i in range(num_batches):
            start = i * batch_size
            stop = (i + 1) * batch_size
            example = tfrecordio.make_tf_record_example(
                binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
            record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)
    if record_writer is not None:
        record_writer.close()
    return num_batches * batch_size
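# The .json sidecar written above can size training epochs without scanning
# the records (a minimal sketch; helper name is hypothetical):
def read_row_count(filename):
    with open(os.path.splitext(filename)[0] + ".json") as f:
        meta = json.load(f)
    return meta["num_rows"], meta["num_batches"]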