import os
import sys

import numpy as np
from google.protobuf import text_format

import deepnet_pb2  # generated protobuf module providing the Dataset message

# MakeDict, DumpDataSplit and DumpLabelSplit are expected to be defined
# elsewhere in this module.


def ReadData(proto_file):
  """Reads a Dataset proto from a text-format (.pbtxt) or binary proto file."""
  protoname, ext = os.path.splitext(proto_file)
  proto = deepnet_pb2.Dataset()
  if ext == '.pbtxt':
    with open(proto_file, 'r') as proto_pbtxt:
      text_format.Merge(proto_pbtxt.read(), proto)
  else:
    with open(proto_file, 'rb') as f:
      proto.ParseFromString(f.read())
  return proto
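
# Usage sketch (hypothetical file names): ReadData accepts either a
# text-format .pbtxt file or a serialized binary Dataset proto.
#   data_proto = ReadData('spam_data.pbtxt')  # parsed with text_format.Merge
#   data_proto = ReadData('spam_data.pb')     # parsed with ParseFromString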


def CopyDataset(data):
  copy = deepnet_pb2.Dataset()
  copy.CopyFrom(data)
  return copy
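
# Usage sketch: dataset_copy = CopyDataset(dataset_pb) returns an independent
# Dataset proto that can be modified without touching the original.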


def main():
  data_pbtxt = sys.argv[1]
  output_dir = sys.argv[2]
  prefix = sys.argv[3]
  r = int(sys.argv[4])   # split id, used to name the saved index files
  gpu_mem = sys.argv[5]
  main_mem = sys.argv[6]
  if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

  rep_dict, stats_files = MakeDict(data_pbtxt)
  reps = rep_dict.keys()

  # Reuse a previously saved train/valid/test split for this split id if it
  # exists; otherwise create a fresh 60/20/20 split and save the indices.
  splits_folder = ''  #'splits'
  indices_file = os.path.join(prefix, splits_folder, 'train_indices_%d.npy' % r)
  if os.path.exists(indices_file):
    train = np.load(indices_file)
    valid = np.load(os.path.join(prefix, splits_folder, 'valid_indices_%d.npy' % r))
    test = np.load(os.path.join(prefix, splits_folder, 'test_indices_%d.npy' % r))
  else:
    total = 3044  # total number of examples in the dataset
    indices = np.arange(total)
    np.random.shuffle(indices)
    ntrain = int(0.6 * total)
    ntest = int(0.2 * total)
    nvalid = total - ntrain - ntest
    print('ntrain', ntrain, 'ntest', ntest, 'nvalid', nvalid)
    train = indices[:ntrain]
    valid = indices[ntrain:ntrain + nvalid]
    test = indices[ntrain + nvalid:]
    np.save(os.path.join(prefix, splits_folder, 'train_indices_%d.npy' % r), train)
    np.save(os.path.join(prefix, splits_folder, 'valid_indices_%d.npy' % r), valid)
    np.save(os.path.join(prefix, splits_folder, 'test_indices_%d.npy' % r), test)

  print('Splitting data')
  dataset_pb = deepnet_pb2.Dataset()
  dataset_pb.name = 'spam_split_%d' % r
  dataset_pb.gpu_memory = gpu_mem
  dataset_pb.main_memory = main_mem
  # Dump each data representation, split into train/valid/test.
  for rep in reps:
    data = rep_dict[rep]
    stats_file = stats_files[rep]
    DumpDataSplit(data[train], output_dir, 'train_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[valid], output_dir, 'valid_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[test], output_dir, 'test_%s' % rep, dataset_pb, stats_file)

  print('Splitting labels')
  labels = np.load(os.path.join(prefix, 'labels.npy')).astype('float32')
  DumpLabelSplit(labels[train], output_dir, 'train_labels', dataset_pb)
  DumpLabelSplit(labels[valid], output_dir, 'valid_labels', dataset_pb)
  DumpLabelSplit(labels[test], output_dir, 'test_labels', dataset_pb)

  #d = 'indices'
  #np.save(os.path.join(output_dir, 'train_%s.npy' % d), train)
  #np.save(os.path.join(output_dir, 'valid_%s.npy' % d), valid)
  #np.save(os.path.join(output_dir, 'test_%s.npy' % d), test)

  # Write the assembled Dataset proto describing the splits in text format.
  with open(os.path.join(output_dir, 'data.pbtxt'), 'w') as f:
    text_format.PrintMessage(dataset_pb, f)
  print('Output written in directory %s' % output_dir)
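
# Assumed entry point (not shown in the original section). Invocation sketch,
# with an illustrative script name and argument values:
#   python split_spam_data.py data.pbtxt <output_dir> <prefix> 0 2G 30G
if __name__ == '__main__':
  main()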