Example #1
def save_tfrecords(save_dir, train_list, eval_list, test_list, idx):
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_train_.tfrecords")) as writer:
        for e in train_list:
            writer.write(e)
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_test_.tfrecords")) as writer:
        for e in test_list:
            writer.write(e)
    with TFRecordWriter(os.path.join(save_dir, f"{idx}_eval_.tfrecords")) as writer:
        for e in eval_list:
            writer.write(e)
Example #2
def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
    """Opens all TFRecord shards for writing and adds them to an exit stack.

    Note: Copied directly from https://github.com/tensorflow/models/blob/master/research/
        object_detection/dataset_tools/tf_record_creation_util.py

    Parameters
    ----------
    exit_stack: context2.ExitStack
        A context2.ExitStack used to automatically close the TFRecords
        opened in this function.
    base_path: str
        The base path for all shards
    num_shards: int
        The number of shards

    Returns
    -------
    tfrecords: list(TFRecordWriter)
        The list of opened TFRecords. Position k in the list corresponds to shard k.
    """
    tf_record_output_filenames = [
        '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards)
        for idx in range(num_shards)
    ]
    tfrecords = [
        exit_stack.enter_context(TFRecordWriter(file_name))
        for file_name in tf_record_output_filenames
    ]
    return tfrecords
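
A minimal usage sketch for the function above, assuming contextlib.ExitStack and a list of already serialized tf.train.Example protos (serialized_examples is a placeholder name):

import contextlib

num_shards = 4
with contextlib.ExitStack() as stack:
    writers = open_sharded_output_tfrecords(stack, "/tmp/data.tfrecord", num_shards)
    # Distribute serialized examples across shards round-robin.
    for index, serialized in enumerate(serialized_examples):
        writers[index % num_shards].write(serialized)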
Example #3
def build_with_feature(output_file, feature_generator):
    """
        build tfrecords dataset, must provide both feature generator and label generator
        output_file: where output tfrecords should be
        feature_generator: a generator function, should yield a dict,
            where key is string and value is tf.train.Feature
    """
    cnt = 0
    with TFRecordWriter(output_file) as writer:
        for feature in feature_generator:
            sample = tf.train.Example(features=tf.train.Features(
                feature=feature))
            writer.write(sample.SerializeToString())
            cnt += 1
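
A short usage sketch for build_with_feature, assuming the standard tf.train.Feature constructors (the generator and file name below are illustrative):

import tensorflow as tf

def toy_feature_generator():
    # Yield one feature dict per sample; keys are strings, values are tf.train.Feature.
    for value, label in [(0.5, 1), (1.5, 0)]:
        yield {
            "value": tf.train.Feature(float_list=tf.train.FloatList(value=[value])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }

build_with_feature("toy.tfrecords", toy_feature_generator())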
Example #4
def _process_images(self, name, images, labels, id_file, n_files):
    """Processes and saves a list of images as one TFRecord shard in 1 thread.

    Args:
      name: string, unique identifier specifying the data set
      images: array of images
      labels: array of labels
      id_file: int, index of this shard
      n_files: int, total number of shards
    """
    output_filename = '{}-{:05d}-of-{:05d}'.format(name, id_file, n_files)
    output_file = os.path.join(FLAGS.output_dir, self.name, output_filename)
    with TFRecordWriter(output_file) as writer:
        for image, label in zip(images, labels):
            example = self._convert_to_example(image.tobytes(), label)
            writer.write(example.SerializeToString())
    print('{}: Wrote {} images to {}'.format(
        datetime.now(), len(images), output_file), flush=True)
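
The helper self._convert_to_example is not shown above. A minimal sketch of what such an encoder might look like, assuming raw image bytes and an integer label (the feature keys 'image_raw' and 'label' are illustrative, not the project's actual schema):

import tensorflow as tf

def _convert_to_example(self, image_bytes, label):
    # Pack raw image bytes and an integer class label into a tf.train.Example.
    return tf.train.Example(features=tf.train.Features(feature={
        'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    }))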
Example #5
File: io.py Project: SijanC147/Msc
def write_tfrecords(path, tf_examples, num_shards=TF_RECORD_SHARDS):
    if NP_RANDOM_SEED is not None:
        np.random.seed(NP_RANDOM_SEED)
    np.random.shuffle(tf_examples)
    tf_examples = [example.SerializeToString() for example in tf_examples]
    num_per_shard = int(math.ceil(len(tf_examples) / float(num_shards)))
    total_shards = int(math.ceil(len(tf_examples) / float(num_per_shard)))
    makedirs(path, exist_ok=True)
    for shard_no in range(total_shards):
        start = shard_no * num_per_shard
        end = min((shard_no + 1) * num_per_shard, len(tf_examples))
        file_name = "{0}_of_{1}.tfrecord".format(shard_no + 1, total_shards)
        file_path = join(path, file_name)
        with TFRecordWriter(file_path) as tf_writer:
            for serialized_example in tf_examples[start:end]:
                tf_writer.write(serialized_example)
        if end == len(tf_examples):
            break
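
Shards written this way can be read back with tf.data, e.g. (a sketch assuming TF 2.x eager execution; path is the same directory passed to write_tfrecords):

import glob
import os
import tensorflow as tf

dataset = tf.data.TFRecordDataset(glob.glob(os.path.join(path, "*.tfrecord")))
for raw_record in dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)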
Example #6
def main(_):
    writer = TFRecordWriter(FLAGS.output_path)
    path = os.path.join(os.getcwd(), FLAGS.img_path)
    print(FLAGS.csv_input)
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())

    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))
Example #7
File: prep.py Project: wibrow/kGCN
def write_to_tfrecords(adj, feature, label_data, label_mask, tfrname):
    """
    Writes graph related data to disk.
    """
    adj_row, adj_col = np.nonzero(adj)
    adj_values = adj[adj_row, adj_col]
    adj_elem_len = len(adj_row)
    feature = np.array(feature)
    feature_row, feature_col = np.nonzero(feature)
    feature_values = feature[feature_row, feature_col]
    feature_elem_len = len(feature_row)
    features = Features(
        feature={
            'label': Feature(int64_list=Int64List(value=label_data)),
            'mask_label': Feature(int64_list=Int64List(value=label_mask)),
            'adj_row': Feature(int64_list=Int64List(value=list(adj_row))),
            'adj_column': Feature(int64_list=Int64List(value=list(adj_col))),
            'adj_values': Feature(float_list=FloatList(value=list(adj_values))),
            'adj_elem_len': Feature(int64_list=Int64List(value=[adj_elem_len])),
            'feature_row': Feature(int64_list=Int64List(value=list(feature_row))),
            'feature_column': Feature(int64_list=Int64List(value=list(feature_col))),
            'feature_values': Feature(float_list=FloatList(value=list(feature_values))),
            'feature_elem_len': Feature(int64_list=Int64List(value=[feature_elem_len])),
            'size': Feature(int64_list=Int64List(value=list(feature.shape)))
        })
    ex = Example(features=features)
    with TFRecordWriter(tfrname) as single_writer:
        single_writer.write(ex.SerializeToString())
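
A sketch of how such a graph record could be parsed back, assuming TF 2.x, variable-length entries for the row/column/value lists, and a 2-D feature matrix (so 'size' has two entries); tfrname is the file written above:

import tensorflow as tf

graph_feature_spec = {
    'label': tf.io.VarLenFeature(tf.int64),
    'mask_label': tf.io.VarLenFeature(tf.int64),
    'adj_row': tf.io.VarLenFeature(tf.int64),
    'adj_column': tf.io.VarLenFeature(tf.int64),
    'adj_values': tf.io.VarLenFeature(tf.float32),
    'adj_elem_len': tf.io.FixedLenFeature([1], tf.int64),
    'feature_row': tf.io.VarLenFeature(tf.int64),
    'feature_column': tf.io.VarLenFeature(tf.int64),
    'feature_values': tf.io.VarLenFeature(tf.float32),
    'feature_elem_len': tf.io.FixedLenFeature([1], tf.int64),
    'size': tf.io.FixedLenFeature([2], tf.int64),
}

def parse_graph_record(serialized):
    return tf.io.parse_single_example(serialized, graph_feature_spec)

dataset = tf.data.TFRecordDataset([tfrname]).map(parse_graph_record)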
Example #8
def write_examples_as_tfrecord(examples,
                               output_filebase,
                               example_encoder,
                               num_shards=1):
    """Serialize examples as a TFRecord dataset.

    Note: Adapted from https://github.com/tensorflow/models/blob/master/research/
        object_detection/g3doc/using_your_own_dataset.md

    Parameters
    ----------
    examples: list(dict-like)
        A list of key/value maps, each map contains relevant info
        for a single data example.
    output_filebase: str
        The base path for all shards
    example_encoder: func
        A function that encodes an input example as a tf.Example.
    num_shards: int
        The number of shards to divide the examples among. If > 1 multiple
        tfrecord files will be created with names appended with a shard index.
    """
    if num_shards == 1:
        writer = TFRecordWriter(output_filebase)
        for example in tqdm(examples):
            tf_example = example_encoder(example)
            writer.write(tf_example.SerializeToString())
        writer.close()
    else:
        with contextlib.ExitStack() as tf_record_close_stack:
            output_tfrecords = open_sharded_output_tfrecords(
                tf_record_close_stack, output_filebase, num_shards)
            for index, example in tqdm(enumerate(examples),
                                       total=len(examples)):
                tf_example = example_encoder(example)
                output_shard_index = index % num_shards
                output_tfrecords[output_shard_index].write(
                    tf_example.SerializeToString())
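
A usage sketch; the encoder below is a stand-in, any function that maps one input example to a tf.train.Example works:

import tensorflow as tf

def encode_example(example):
    # `example` is dict-like; here it just carries an integer under "label".
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['label']])),
    }))

write_examples_as_tfrecord(
    examples=[{'label': 0}, {'label': 1}, {'label': 1}],
    output_filebase='/tmp/demo.tfrecord',
    example_encoder=encode_example,
    num_shards=2)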
Example #9
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple):
    tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
    record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []

    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert (set(npz.keys()) == set(keys))

            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

            binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
            globalInputNCs.append(globalInputNC)
            policyTargetsNCMoves.append(policyTargetsNCMove)
            globalTargetsNCs.append(globalTargetsNC)
            scoreDistrNs.append(scoreDistrN)
            valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    #WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    joint_shuffle((binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                   globalTargetsNC, scoreDistrN, valueTargetsNCHW))

    num_rows = binaryInputNCHWPacked.shape[0]
    #Just truncate and lose the batch at the end, it's fine
    num_batches = (
        num_rows //
        (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    for i in range(num_batches):
        start = i * batch_size
        stop = (i + 1) * batch_size

        example = tfrecordio.make_tf_record_example(
            binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
            globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
        record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    record_writer.close()
    return num_batches * batch_size
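
Because the records are written with ZLIB compression, readers must request the same compression type, e.g. (a sketch assuming TF 2.x; filename is the file written above):

import tensorflow as tf

dataset = tf.data.TFRecordDataset([filename], compression_type="ZLIB")
for serialized in dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(serialized.numpy())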
Example #10
def random_n_fold_crossval_tfrec_format(n_folds,
                                        output_location,
                                        file_list=['BIO_format.txt'],
                                        base_name='data',
                                        downsample=True,
                                        downsampling_rate=0.5):
    """
    Create Tensorflow Records from BIO format. 
    For a n-fold cross validation n tfrecs are created. 
    The creation is based on counting the number of samples 
    in the file and randomly sorting them to n buckets.

    A training and test version are created, respectively 
    for each fold.
    Downsampling is applied as a very simple batch control
    on the training version. 
    
    Arguments:
        n_folds {[int]} -- [num of folds for cross validation]
        output_location {[str]} -- [path to data]
    
    Keyword Arguments:
        file_list {list} -- [BIO files] (default: {['BIO_format.txt']})
        base_name {str} -- [base name for reading and writing] (default: {'data'})
        downsample {bool} -- [apply downsampling to corpus] (default: {False})
        downsampling_rate {float} -- [rate for downsampling empty data] (default: {0.0})
    """
    random.seed(1)
    sample_count, word_count = count_samples(file_list, output_location)
    print("Creating TFRecord for {}-fold random cross validation.".format(
        n_folds))
    print("Found {} input samples (sentences) and {} words.".format(
        sample_count, word_count))

    train_path = join(
        output_location,
        str(n_folds) + '_cval_random_d_' + str(downsample) + "_" +
        str(downsampling_rate), 'records_train')
    train_path_labels = join(
        output_location,
        str(n_folds) + '_cval_random_d_' + str(downsample) + "_" +
        str(downsampling_rate), 'labels_train')
    test_path = join(
        output_location,
        str(n_folds) + '_cval_random_d_' + str(downsample) + "_" +
        str(downsampling_rate), 'records_test')
    test_path_labels = join(
        output_location,
        str(n_folds) + '_cval_random_d_' + str(downsample) + "_" +
        str(downsampling_rate), 'labels_test')
    make_or_clean_if_exists(train_path, train_path_labels, test_path,
                            test_path_labels)

    with Path(join(output_location, (base_name + '.words.txt'))).open() as f:
        word_to_idx = {line.strip(): idx for idx, line in enumerate(f)}
    with Path(join(output_location, (base_name + '.tag.txt'))).open() as f:
        label_mapping = {line.strip(): idx for idx, line in enumerate(f)}
    with Path(join(output_location, (base_name + '.chars.txt'))).open() as f:
        char_mapping = {line.strip(): idx for idx, line in enumerate(f)}

    tfwriter_set_train = []
    correct_labels_train = []
    sentences_to_labels_train = []
    words_to_labels_train = []
    tfwriter_set_test = []
    correct_labels_test = []
    sentences_to_labels_test = []
    words_to_labels_test = []
    for idx in range(n_folds):
        record_name = base_name + "_" + str(idx + 1) + ".tfrec"
        tfwriter_set_train.append(TFRecordWriter(join(train_path,
                                                      record_name)))
        correct_labels_train.append([])
        sentences_to_labels_train.append([])
        words_to_labels_train.append([])
        tfwriter_set_test.append(TFRecordWriter(join(test_path, record_name)))
        correct_labels_test.append([])
        sentences_to_labels_test.append([])
        words_to_labels_test.append([])

    plain_words = []
    for f in file_list:
        print("Working on file: " + join(output_location, f))
        with open(join(output_location, f), 'r') as datafile:
            indices = []
            labels = []
            words = []
            high_low = []
            length = []
            for line in datafile:
                if line in ['\n', '\r\n']:
                    index_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[ind]))
                        for ind in indices
                    ])
                    label_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[lab]))
                        for lab in labels
                    ])
                    word_list = train.FeatureList(feature=[
                        train.Feature(bytes_list=train.BytesList(
                            value=[wor.encode()])) for wor in words
                    ])
                    upper_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[hl]))
                        for hl in high_low
                    ])
                    len_list = train.FeatureList(feature=[
                        train.Feature(int64_list=train.Int64List(value=[le]))
                        for le in length
                    ])
                    sentences = train.FeatureLists(
                        feature_list={
                            'indices': index_list,
                            'labels': label_list,
                            'plain': word_list,
                            'upper': upper_list,
                            'length': len_list
                        })
                    example = train.SequenceExample(feature_lists=sentences)
                    sentence_reconstructed = '_'.join(plain_words)
                    labels_in_sent = list(set(labels))
                    set_to_put = random.randint(0, n_folds - 1)
                    tfwriter_set_test[set_to_put].write(
                        example.SerializeToString())
                    correct_labels_test[set_to_put].extend(labels)
                    for i in range(len(words)):
                        sentences_to_labels_test[set_to_put].append(
                            sentence_reconstructed)
                    words_to_labels_test[set_to_put].extend(plain_words)
                    if downsample and len(
                            labels_in_sent
                    ) == 1 and labels_in_sent[0] == 5 and random.random(
                    ) < downsampling_rate:
                        pass
                    else:
                        tfwriter_set_train[set_to_put].write(
                            example.SerializeToString())
                        correct_labels_train[set_to_put].extend(labels)
                        for i in range(len(words)):
                            sentences_to_labels_train[set_to_put].append(
                                sentence_reconstructed)
                        words_to_labels_train[set_to_put].extend(plain_words)
                    indices = []
                    labels = []
                    words = []
                    plain_words = []
                    high_low = []
                    length = []

                elif not line.startswith("-DOCSTART-") and '---' not in line:
                    info = line.split()
                    word = info[0]
                    characters = [char_mapping[x] for x in word]
                    words.append(" ".join(str(x) for x in characters))
                    labels.append(label_mapping[info[1]])

                    if info[0] not in word_to_idx:
                        indices.append(len(word_to_idx))
                        print(info[0])
                        raise RuntimeError(
                            "Tried to process a word that is not in the prebuilt dictionary."
                        )
                    else:
                        indices.append(word_to_idx[info[0]])

                    high_low.append(word.isupper())
                    length.append(len(word))
                    plain_words.append(word)

    for idx in range(n_folds):
        label_name = base_name + "_" + str(idx + 1) + ".tfrec" + '_labels'
        with open(join(train_path_labels, label_name), 'w') as cor_labels:
            print("In train " + str(idx + 1) + " are " +
                  str(len(correct_labels_train[idx])) + " samples (words).")
            if len(correct_labels_train[idx]) != len(
                    sentences_to_labels_train[idx]) or len(
                        correct_labels_train[idx]) != len(
                            words_to_labels_train[idx]):
                raise (RuntimeError("Sentence length does not match labels."))
            for l, s, w in zip(correct_labels_train[idx],
                               sentences_to_labels_train[idx],
                               words_to_labels_train[idx]):
                cor_labels.write(str(l) + " " + s + " " + w + '\n')
        tfwriter_set_train[idx].close()

        with open(join(test_path_labels, label_name), 'w') as cor_labels:
            print("In test " + str(idx + 1) + " are " +
                  str(len(correct_labels_test[idx])) + " samples (words).")
            if len(correct_labels_test[idx]) != len(
                    sentences_to_labels_test[idx]) or len(
                        correct_labels_test[idx]) != len(
                            words_to_labels_test[idx]):
                raise (
                    RuntimeError("Sentence length does not match labels.\n"))
            for l, s, w in zip(correct_labels_test[idx],
                               sentences_to_labels_test[idx],
                               words_to_labels_test[idx]):
                cor_labels.write(str(l) + " " + s + " " + w + '\n')
        tfwriter_set_test[idx].close()
    print("Done building the TFRecords.")
Example #11
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size):
  #print("Merging shards for output file: %s (%d shards to merge)" % (filename,num_shards_to_merge))
  tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
  record_writer = TFRecordWriter(filename,tfoptions)

  binaryInputNCHWPackeds = []
  globalInputNCs = []
  policyTargetsNCMoves = []
  globalTargetsNCs = []
  scoreDistrNs = []
  selfBonusScoreNs = []
  valueTargetsNCHWs = []

  for input_idx in range(num_shards_to_merge):
    shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
    #print("Merge loading shard: %d (mem usage %dMB)" % (input_idx,memusage_mb()))

    npz = np.load(shard_filename)
    assert(set(npz.keys()) == set(keys))

    binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
    globalInputNC = npz["globalInputNC"]
    policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
    globalTargetsNC = npz["globalTargetsNC"]
    scoreDistrN = npz["scoreDistrN"].astype(np.float32)
    selfBonusScoreN = npz["selfBonusScoreN"].astype(np.float32)
    valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

    binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
    globalInputNCs.append(globalInputNC)
    policyTargetsNCMoves.append(policyTargetsNCMove)
    globalTargetsNCs.append(globalTargetsNC)
    scoreDistrNs.append(scoreDistrN)
    selfBonusScoreNs.append(selfBonusScoreN)
    valueTargetsNCHWs.append(valueTargetsNCHW)

  ###
  #WARNING - if adding anything here, also add it to joint_shuffle below!
  ###
  #print("Merge concatenating... (mem usage %dMB)" % memusage_mb())
  binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
  globalInputNC = np.concatenate(globalInputNCs)
  policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
  globalTargetsNC = np.concatenate(globalTargetsNCs)
  scoreDistrN = np.concatenate(scoreDistrNs)
  selfBonusScoreN = np.concatenate(selfBonusScoreNs)
  valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

  #print("Merge shuffling... (mem usage %dMB)" % memusage_mb())
  joint_shuffle((binaryInputNCHWPacked,globalInputNC,policyTargetsNCMove,globalTargetsNC,scoreDistrN,selfBonusScoreN,valueTargetsNCHW))

  #print("Merge writing in batches...")
  num_rows = binaryInputNCHWPacked.shape[0]
  #Just truncate and lose the batch at the end, it's fine
  num_batches = num_rows // batch_size
  for i in range(num_batches):
    start = i*batch_size
    stop = (i+1)*batch_size

    example = tfrecordio.make_tf_record_example(
      binaryInputNCHWPacked,
      globalInputNC,
      policyTargetsNCMove,
      globalTargetsNC,
      scoreDistrN,
      selfBonusScoreN,
      valueTargetsNCHW,
      start,
      stop
    )
    record_writer.write(example.SerializeToString())

  jsonfilename = os.path.splitext(filename)[0] + ".json"
  with open(jsonfilename,"w") as f:
    json.dump({"num_rows":num_rows,"num_batches":num_batches},f)

  #print("Merge done %s (%d rows)" % (filename, num_batches * batch_size))

  record_writer.close()
  return num_batches * batch_size
Example #12
from IPython import embed


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


input_roots = '/data/dataTrain/val_*/'
output_name = '/data/dataTrain/val.tfrecords'

writer = TFRecordWriter(output_name)

h5files = glob.glob(os.path.join(input_roots, '*.h5'))

for h5file in tqdm(h5files):
    try:
        data = h5py.File(h5file, 'r')
        for i in range(200):
            img = data['CameraRGB'][i]
            target = data['targets'][i]

            feature_dict = {
                'image': _bytes_feature(img.tostring()),
                'targets': _float_feature(target)
            }
            # The original snippet is truncated here; an assumed continuation
            # would serialize the feature dict and write it out:
            example = tf.train.Example(
                features=tf.train.Features(feature=feature_dict))
            writer.write(example.SerializeToString())
    except Exception as e:
        # Assumed: skip files that cannot be read.
        print('Skipping {}: {}'.format(h5file, e))

writer.close()
Example #13
def main():
    args = get_parser()
    if args.use_deepchem_feature:
        args.degree_dim = 11
        args.use_sybyl = False
        args.use_electronegativity = False
        args.use_gasteiger = False

    adj_list = []
    feature_list = []
    label_data_list = []
    label_mask_list = []
    atom_num_list = []
    mol_name_list = []
    seq_symbol_list = None
    dragon_data_list = None
    task_name_list = None
    seq_list = None
    seq = None
    dragon_data = None
    profeat = None
    mol_list = []
    if args.solubility:
        args.sdf_label = "SOL_classification"
        args.sdf_label_active = "high"
        args.sdf_label_inactive = "low"
    if args.assay_dir is not None:
        mol_obj_list, label_data, label_mask, dragon_data, task_name_list, seq, seq_symbol, profeat, publication_years = extract_mol_info(
            args)
    else:
        mol_obj_list, label_data, label_mask, _, task_name_list, _, _, _, publication_years = extract_mol_info(
            args)

    if args.vector_modal is not None:
        dragon_data = build_vector_modal(args)
    ## automatically setting atom_num_limit
    if args.atom_num_limit is None:
        args.atom_num_limit = 0
        for index, mol in enumerate(mol_obj_list):
            if mol is None:
                continue
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
            if args.atom_num_limit < mol.GetNumAtoms():
                args.atom_num_limit = mol.GetNumAtoms()

    if args.use_electronegativity:
        ELECTRONEGATIVITIES = [
            element(i).electronegativity('pauling') for i in range(1, 100)
        ]
        ELECTRONEGATIVITIES = [
            e if e is not None else 0 for e in ELECTRONEGATIVITIES
        ]

    for index, mol in enumerate(mol_obj_list):
        if mol is None:
            continue
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ADJUSTHS)
        # Skip the compound whose total number of atoms is larger than "atom_num_limit"
        if args.atom_num_limit is not None and mol.GetNumAtoms(
        ) > args.atom_num_limit:
            continue
        # Get mol. name
        try:
            name = mol.GetProp("_Name")
        except KeyError:
            name = "index_" + str(index)
        mol_list.append(mol)
        mol_name_list.append(name)
        adj = create_adjancy_matrix(mol)
        feature = create_feature_matrix(
            mol,
            args) if not args.use_electronegativity else create_feature_matrix(
                mol, args, en_list=ELECTRONEGATIVITIES)
        if args.tfrecords:
            tfrname = os.path.join(args.output, name + '_.tfrecords')
            if args.csv_reaxys:
                if publication_years[index] < 2015:
                    name += "_train"
                else:
                    name += random.choice(["_test", "_eval"])
                tfrname = os.path.join(args.output,
                                       str(publication_years[index]),
                                       name + '_.tfrecords')
            pathlib.Path(os.path.dirname(tfrname)).mkdir(parents=True,
                                                         exist_ok=True)
            ex = convert_to_example(adj, feature, label_data[index],
                                    label_mask[index])
            with TFRecordWriter(tfrname) as single_writer:
                single_writer.write(ex.SerializeToString())
            continue

        atom_num_list.append(mol.GetNumAtoms())
        adj_list.append(dense_to_sparse(adj))
        feature_list.append(feature)
        # Create labels
        if args.sdf_label:
            line = mol.GetProp(args.sdf_label)
            if line.find(args.sdf_label_active) != -1:
                label_data_list.append([0, 1])
                label_mask_list.append([1, 1])
            elif line.find(args.sdf_label_inactive) != -1:
                label_data_list.append([1, 0])
                label_mask_list.append([1, 1])
            else:
                # missing
                print("[WARN] unknown label:", line)
                label_data_list.append([0, 0])
                label_mask_list.append([0, 0])
        else:
            label_data_list.append(label_data[index])
            label_mask_list.append(label_mask[index])
            if dragon_data is not None:
                if dragon_data_list is None:
                    dragon_data_list = []
                dragon_data_list.append(dragon_data[index])
        if args.multimodal:
            if seq is not None:
                if seq_list is None:
                    seq_list, seq_symbol_list = [], []
                seq_list.append(seq[index])
                seq_symbol_list.append(seq[index])
    if args.tfrecords:
        with open(os.path.join(args.output, "tasks.txt"), "w") as f:
            f.write("\n".join(task_name_list))
        sys.exit(0)
    # joblib output
    obj = {"feature": np.asarray(feature_list), "adj": np.asarray(adj_list)}
    if not args.sparse_label:
        obj["label"] = np.asarray(label_data_list)
        obj["mask_label"] = np.asarray(label_mask_list)
    else:
        from scipy.sparse import csr_matrix
        label_data = np.asarray(label_data_list)
        label_mask = np.asarray(label_mask_list)
        if args.label_dim is None:
            obj['label_dim'] = label_data.shape[1]
        else:
            obj['label_dim'] = args.label_dim
        obj['label_sparse'] = csr_matrix(label_data.astype(float))
        obj['mask_label_sparse'] = csr_matrix(label_mask.astype(float))
    if task_name_list is not None:
        obj["task_names"] = np.asarray(task_name_list)
    if dragon_data_list is not None:
        obj["dragon"] = np.asarray(dragon_data_list)
    if profeat is not None:
        obj["profeat"] = np.asarray(profeat)
    obj["max_node_num"] = args.atom_num_limit
    mol_info = {"obj_list": mol_list, "name_list": mol_name_list}
    obj["mol_info"] = mol_info
    if not args.regression:
        label_int = np.argmax(label_data_list, axis=1)
        cw = class_weight.compute_class_weight("balanced",
                                               np.unique(label_int), label_int)
        obj["class_weight"] = cw

    if args.generate_mfp:
        from rdkit.Chem import AllChem
        mfps = []
        for mol in mol_list:
            FastFindRings(mol)
            mfp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            mfp_vec = np.array([mfp.GetBit(i) for i in range(2048)], np.int32)
            mfps.append(mfp_vec)
        obj["mfp"] = np.array(mfps)
    ##

    if args.multimodal:
        if seq is not None:
            if args.max_len_seq is not None:
                max_len_seq = args.max_len_seq
            else:
                max_len_seq = max(map(len, seq_list))
            print("max_len_seq:", max_len_seq)
            seq_mat = np.zeros((len(seq_list), max_len_seq), np.int32)
            for i, s in enumerate(seq_list):
                seq_mat[i, 0:len(s)] = s
            obj["sequence"] = seq_mat
            obj["sequence_symbol"] = seq_symbol_list
            obj["sequence_length"] = list(map(len, seq_list))
            obj["sequence_symbol_num"] = int(np.max(seq_mat) + 1)

    filename = args.output
    print("[SAVE] " + filename)
    joblib.dump(obj, filename, compress=3)
Example #14
def merge_shards(filename, num_shards_to_merge, out_tmp_dir, batch_size,
                 ensure_batch_multiple, output_npz):
    np.random.seed(
        [int.from_bytes(os.urandom(4), byteorder='little') for i in range(5)])

    if output_npz:
        record_writer = None
    else:
        tfoptions = TFRecordOptions(TFRecordCompressionType.ZLIB)
        record_writer = TFRecordWriter(filename, tfoptions)

    binaryInputNCHWPackeds = []
    globalInputNCs = []
    policyTargetsNCMoves = []
    globalTargetsNCs = []
    scoreDistrNs = []
    valueTargetsNCHWs = []

    for input_idx in range(num_shards_to_merge):
        shard_filename = os.path.join(out_tmp_dir, str(input_idx) + ".npz")
        with np.load(shard_filename) as npz:
            assert (set(npz.keys()) == set(keys))

            binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
            globalInputNC = npz["globalInputNC"]
            policyTargetsNCMove = npz["policyTargetsNCMove"].astype(np.float32)
            globalTargetsNC = npz["globalTargetsNC"]
            scoreDistrN = npz["scoreDistrN"].astype(np.float32)
            valueTargetsNCHW = npz["valueTargetsNCHW"].astype(np.float32)

            binaryInputNCHWPackeds.append(binaryInputNCHWPacked)
            globalInputNCs.append(globalInputNC)
            policyTargetsNCMoves.append(policyTargetsNCMove)
            globalTargetsNCs.append(globalTargetsNC)
            scoreDistrNs.append(scoreDistrN)
            valueTargetsNCHWs.append(valueTargetsNCHW)

    ###
    #WARNING - if adding anything here, also add it to joint_shuffle below!
    ###
    binaryInputNCHWPacked = np.concatenate(binaryInputNCHWPackeds)
    globalInputNC = np.concatenate(globalInputNCs)
    policyTargetsNCMove = np.concatenate(policyTargetsNCMoves)
    globalTargetsNC = np.concatenate(globalTargetsNCs)
    scoreDistrN = np.concatenate(scoreDistrNs)
    valueTargetsNCHW = np.concatenate(valueTargetsNCHWs)

    num_rows = binaryInputNCHWPacked.shape[0]
    assert (globalInputNC.shape[0] == num_rows)
    assert (policyTargetsNCMove.shape[0] == num_rows)
    assert (globalTargetsNC.shape[0] == num_rows)
    assert (scoreDistrN.shape[0] == num_rows)
    assert (valueTargetsNCHW.shape[0] == num_rows)

    [
        binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
        globalTargetsNC, scoreDistrN, valueTargetsNCHW
    ] = (joint_shuffle_take_first_n(num_rows, [
        binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
        globalTargetsNC, scoreDistrN, valueTargetsNCHW
    ]))

    assert (binaryInputNCHWPacked.shape[0] == num_rows)
    assert (globalInputNC.shape[0] == num_rows)
    assert (policyTargetsNCMove.shape[0] == num_rows)
    assert (globalTargetsNC.shape[0] == num_rows)
    assert (scoreDistrN.shape[0] == num_rows)
    assert (valueTargetsNCHW.shape[0] == num_rows)

    #Just truncate and lose the batch at the end, it's fine
    num_batches = (
        num_rows //
        (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
    if output_npz:
        start = 0
        stop = num_batches * batch_size
        np.savez_compressed(
            filename,
            binaryInputNCHWPacked=binaryInputNCHWPacked[start:stop],
            globalInputNC=globalInputNC[start:stop],
            policyTargetsNCMove=policyTargetsNCMove[start:stop],
            globalTargetsNC=globalTargetsNC[start:stop],
            scoreDistrN=scoreDistrN[start:stop],
            valueTargetsNCHW=valueTargetsNCHW[start:stop])
    else:
        for i in range(num_batches):
            start = i * batch_size
            stop = (i + 1) * batch_size

            example = tfrecordio.make_tf_record_example(
                binaryInputNCHWPacked, globalInputNC, policyTargetsNCMove,
                globalTargetsNC, scoreDistrN, valueTargetsNCHW, start, stop)
            record_writer.write(example.SerializeToString())

    jsonfilename = os.path.splitext(filename)[0] + ".json"
    with open(jsonfilename, "w") as f:
        json.dump({"num_rows": num_rows, "num_batches": num_batches}, f)

    if record_writer is not None:
        record_writer.close()
    return num_batches * batch_size