Example #1
import os

from google.protobuf import text_format

import deepnet_pb2


def ReadData(proto_file):
    """Load a deepnet Dataset proto from a text (.pbtxt) or binary file."""
    _, ext = os.path.splitext(proto_file)
    proto = deepnet_pb2.Dataset()
    if ext == '.pbtxt':
        # Text-format proto: merge the human-readable description.
        with open(proto_file, 'r') as f:
            text_format.Merge(f.read(), proto)
    else:
        # Binary proto: parse the serialized bytes.
        with open(proto_file, 'rb') as f:
            proto.ParseFromString(f.read())
    return proto
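
A minimal usage sketch; the file names are placeholders, and any Dataset proto in text or binary format works:

data = ReadData('spam_data.pbtxt')  # text-format proto (hypothetical path)
print(data.name)
data = ReadData('spam_data.pb')     # binary proto (hypothetical path)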
Example #2
def CopyDataset(data):
    """Return an independent deep copy of a Dataset proto."""
    copy = deepnet_pb2.Dataset()
    copy.CopyFrom(data)
    return copy
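
A short sketch of the copy semantics, assuming only the name field that Example #3 also uses:

original = deepnet_pb2.Dataset()
original.name = 'spam_split_0'
clone = CopyDataset(original)
clone.name = 'modified'  # mutating the copy...
print(original.name)     # ...leaves the original untouched: prints 'spam_split_0'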
Example #3
import os
import sys

import numpy as np
from google.protobuf import text_format

import deepnet_pb2


def main():
    # Usage: <script> data.pbtxt output_dir prefix split_id gpu_mem main_mem
    # (MakeDict, DumpDataSplit and DumpLabelSplit are defined elsewhere in
    # the same source file.)
    data_pbtxt = sys.argv[1]  # dataset description proto (text format)
    output_dir = sys.argv[2]  # destination for the split data and data.pbtxt
    prefix = sys.argv[3]      # directory with labels.npy and saved index files
    r = int(sys.argv[4])      # split id; names the *_indices_<r>.npy files
    gpu_mem = sys.argv[5]     # memory budgets recorded in the output proto
    main_mem = sys.argv[6]
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Map each data representation name to its array and its stats file.
    rep_dict, stats_files = MakeDict(data_pbtxt)
    reps = rep_dict.keys()

    splits_folder = ''  # set to 'splits' to keep index files in a subfolder
    indices_file = os.path.join(prefix, splits_folder,
                                'train_indices_%d.npy' % r)

    if os.path.exists(indices_file):
        # Reuse a previously saved train/valid/test split for run r.
        train = np.load(indices_file)
        valid = np.load(
            os.path.join(prefix, splits_folder, 'valid_indices_%d.npy' % r))
        test = np.load(
            os.path.join(prefix, splits_folder, 'test_indices_%d.npy' % r))
    else:
        # Create a fresh random 60/20/20 train/valid/test split. The total
        # number of examples (3044) is hard-coded for this dataset.
        total = 3044
        indices = np.arange(total)
        np.random.shuffle(indices)
        ntrain = int(0.6 * total)
        ntest = int(0.2 * total)
        nvalid = total - ntrain - ntest  # remainder absorbs rounding
        print('ntrain %d nvalid %d ntest %d' % (ntrain, nvalid, ntest))
        train = indices[:ntrain]
        valid = indices[ntrain:ntrain + nvalid]
        test = indices[ntrain + nvalid:]
        np.save(
            os.path.join(prefix, splits_folder, 'train_indices_%d.npy' % r),
            train)
        np.save(
            os.path.join(prefix, splits_folder, 'valid_indices_%d.npy' % r),
            valid)
        np.save(os.path.join(prefix, splits_folder, 'test_indices_%d.npy' % r),
                test)

    print('Splitting data')
    # Metadata for the output data.pbtxt written at the end of main().
    dataset_pb = deepnet_pb2.Dataset()
    dataset_pb.name = 'spam_split_%d' % r
    dataset_pb.gpu_memory = gpu_mem
    dataset_pb.main_memory = main_mem
    for rep in reps:
        data = rep_dict[rep]
        stats_file = stats_files[rep]
        DumpDataSplit(data[train], output_dir, 'train_%s' % rep, dataset_pb,
                      stats_file)
        DumpDataSplit(data[valid], output_dir, 'valid_%s' % rep, dataset_pb,
                      stats_file)
        DumpDataSplit(data[test], output_dir, 'test_%s' % rep, dataset_pb,
                      stats_file)

    print('Splitting labels')
    labels = np.load(os.path.join(prefix, 'labels.npy')).astype('float32')
    DumpLabelSplit(labels[train], output_dir, 'train_labels', dataset_pb)
    DumpLabelSplit(labels[valid], output_dir, 'valid_labels', dataset_pb)
    DumpLabelSplit(labels[test], output_dir, 'test_labels', dataset_pb)

    #d = 'indices'
    #np.save(os.path.join(output_dir, 'train_%s.npy' % d), train)
    #np.save(os.path.join(output_dir, 'valid_%s.npy' % d), valid)
    #np.save(os.path.join(output_dir, 'test_%s.npy' % d), test)

    with open(os.path.join(output_dir, 'data.pbtxt'), 'w') as f:
        text_format.PrintMessage(dataset_pb, f)

    print('Output written in directory %s' % output_dir)
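
A hypothetical invocation; the script name, paths, and memory sizes are placeholders:

python split_spam_data.py spam_data.pbtxt ./split_0 ./spam_features 0 2G 30G

The 60/20/20 split in main() partitions every index exactly once, with the validation set absorbing the rounding remainder. A standalone sketch of the same arithmetic:

import numpy as np

total = 3044                     # hard-coded example count from main()
indices = np.arange(total)
np.random.shuffle(indices)
ntrain = int(0.6 * total)        # 1826
ntest = int(0.2 * total)         # 608
nvalid = total - ntrain - ntest  # 610: remainder absorbs rounding
train, valid, test = np.split(indices, [ntrain, ntrain + nvalid])
assert len(train) + len(valid) + len(test) == total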