Example #1
    def train(self, req_vars):
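        # Training stream: time-cut prefixes of training trips (validation
        # trips excluded, random prefix splits generated) merged with a
        # shuffled stream of complete candidate trips, then batched and run
        # through a MultiProcessing wrapper.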
        valid = TaxiDataset(self.config.valid_set, "valid.hdf5", sources=("trip_id",))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        dataset = TaxiDataset("train")

        prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
        prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids)
        prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits)
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_stream = DataStream(dataset, iteration_scheme=ShuffledExampleScheme(dataset.num_examples))
        candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)
        candidate_stream = transformers.taxi_add_first_last_len(candidate_stream, self.config.n_begin_end_pts)
        candidate_stream = Batch(candidate_stream, iteration_scheme=ConstantScheme(self.config.train_candidate_size))

        sources = prefix_stream.sources + tuple("candidate_%s" % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #2
    def train(self, req_vars):
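        # Training stream for variable-length trips: shuffled examples with
        # validation and empty trips removed, datetime features and the
        # destination target added, grouped into length-balanced batches
        # (keyed on 'latitude') and padded with masks.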
        valid = TaxiDataset(self.config.valid_set,
                            'valid.hdf5',
                            sources=('trip_id', ))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        stream = TaxiDataset('train')
        stream = DataStream(stream,
                            iteration_scheme=ShuffledExampleScheme(
                                stream.num_examples))
        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.add_destination(stream)
        stream = transformers.Select(
            stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(
            stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        return stream
Example #3
    def train(self, req_vars):
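        # Training stream: time-cut or shuffled iteration depending on
        # config.use_cuts_for_training, with validation trips excluded,
        # random prefix splits generated and datetime features added,
        # then batching and a MultiProcessing wrapper.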
        valid = TaxiDataset(self.config.valid_set,
                            'valid.hdf5',
                            sources=('trip_id', ))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        stream = TaxiDataset('train')

        if hasattr(
                self.config,
                'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream,
                                iteration_scheme=ShuffledExampleScheme(
                                    stream.num_examples))

        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiGenerateSplits(
            stream, max_splits=self.config.max_splits)

        stream = transformers.taxi_add_datetime(stream)
        # stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream,
                       iteration_scheme=ConstantScheme(self.config.batch_size))

        stream = MultiProcessing(stream)

        return stream
Example #4
    def train(self, req_vars):
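        # Training stream with an optional partial re-shuffle: when
        # config.shuffle_batch_size is set, examples are grouped, re-ordered
        # with a random sort key and unpacked again before the features are
        # added and the final batches are formed.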
        valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        stream = TaxiDataset('train')

        if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

        if hasattr(self.config, 'shuffle_batch_size'):
            stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))
        
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        stream = MultiProcessing(stream)

        return stream
Example #5
    def train(self, req_vars):
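        # Training stream over data.traintest_ds: validation trips are
        # excluded only when data.tvt is not set, prefix splits or the
        # destination target are added depending on the config, overly long
        # trips are optionally filtered out, and batches are length-balanced
        # and padded with masks.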
        stream = TaxiDataset('train', data.traintest_ds)

        if hasattr(
                self.config,
                'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream,
                                iteration_scheme=ShuffledExampleScheme(
                                    stream.num_examples))

        if not data.tvt:
            valid = TaxiDataset(data.valid_set,
                                data.valid_ds,
                                sources=('trip_id', ))
            valid_trips_ids = valid.get_data(None,
                                             slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(
                stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(self.config, 'train_max_len'):
            idx = stream.sources.index('latitude')

            def max_len_filter(x):
                return len(x[idx]) <= self.config.train_max_len

            stream = Filter(stream, max_len_filter)

        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(
            stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(
            stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        stream = MultiProcessing(stream)

        return stream
Example #6
    def train(self, req_vars):
        valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        stream = TaxiDataset('train')
        stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.add_destination(stream)
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        return stream
Example #7
    def train(self, req_vars):
        valid = TaxiDataset(self.config.valid_set,
                            'valid.hdf5',
                            sources=('trip_id', ))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        dataset = TaxiDataset('train')

        prefix_stream = DataStream(dataset,
                                   iteration_scheme=TaxiTimeCutScheme(
                                       self.config.num_cuts))
        prefix_stream = transformers.TaxiExcludeTrips(prefix_stream,
                                                      valid_trips_ids)
        prefix_stream = transformers.TaxiGenerateSplits(
            prefix_stream, max_splits=self.config.max_splits)
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

        candidate_stream = DataStream(dataset,
                                      iteration_scheme=ShuffledExampleScheme(
                                          dataset.num_examples))
        candidate_stream = transformers.TaxiExcludeTrips(
            candidate_stream, valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)
        candidate_stream = transformers.taxi_add_first_last_len(
            candidate_stream, self.config.n_begin_end_pts)
        candidate_stream = Batch(candidate_stream,
                                 iteration_scheme=ConstantScheme(
                                     self.config.train_candidate_size))

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #8
    def valid(self, req_vars):
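        # Validation stream: a sequential pass over the validation prefixes
        # merged with a shuffled stream of complete candidate trips drawn
        # from the training set, excluding trips that belong to the
        # validation set itself.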
        valid_dataset = TaxiDataset(self.config.valid_set, 'valid.hdf5')
        train_dataset = TaxiDataset('train')
        valid_trips_ids = valid_dataset.get_data(
            None, slice(0, valid_dataset.num_examples))[
                valid_dataset.sources.index('trip_id')]

        prefix_stream = DataStream(valid_dataset,
                                   iteration_scheme=SequentialExampleScheme(
                                       valid_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(
            prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(
                                  self.config.batch_size))

        candidate_stream = DataStream(train_dataset,
                                      iteration_scheme=ShuffledExampleScheme(
                                          train_dataset.num_examples))
        candidate_stream = transformers.TaxiExcludeTrips(
            candidate_stream, valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)
        candidate_stream = transformers.taxi_add_first_last_len(
            candidate_stream, self.config.n_begin_end_pts)
        candidate_stream = Batch(candidate_stream,
                                 iteration_scheme=ConstantScheme(
                                     self.config.valid_candidate_size))

        sources = prefix_stream.sources + tuple(
            'candidate_%s' % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #9
    def train(self, req_vars):
        stream = TaxiDataset('train', data.traintest_ds)

        if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        if not data.tvt:
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(self.config, 'train_max_len'):
            idx = stream.sources.index('latitude')
            def max_len_filter(x):
                return len(x[idx]) <= self.config.train_max_len
            stream = Filter(stream, max_len_filter)

        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(stream, key='latitude',
                                             batch_size=self.config.batch_size,
                                             batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        stream = MultiProcessing(stream)

        return stream
Example #10
    def valid(self, req_vars):
        valid_dataset = TaxiDataset(self.config.valid_set, "valid.hdf5")
        train_dataset = TaxiDataset("train")
        valid_trips_ids = valid_dataset.get_data(None, slice(0, valid_dataset.num_examples))[
            valid_dataset.sources.index("trip_id")
        ]

        prefix_stream = DataStream(valid_dataset, iteration_scheme=SequentialExampleScheme(valid_dataset.num_examples))
        prefix_stream = transformers.taxi_add_datetime(prefix_stream)
        prefix_stream = transformers.taxi_add_first_last_len(prefix_stream, self.config.n_begin_end_pts)
        prefix_stream = Batch(prefix_stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_stream = DataStream(train_dataset, iteration_scheme=ShuffledExampleScheme(train_dataset.num_examples))
        candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, valid_trips_ids)
        candidate_stream = transformers.TaxiExcludeEmptyTrips(candidate_stream)
        candidate_stream = transformers.taxi_add_datetime(candidate_stream)
        candidate_stream = transformers.taxi_add_first_last_len(candidate_stream, self.config.n_begin_end_pts)
        candidate_stream = Batch(candidate_stream, iteration_scheme=ConstantScheme(self.config.valid_candidate_size))

        sources = prefix_stream.sources + tuple("candidate_%s" % k for k in candidate_stream.sources)
        stream = Merge((prefix_stream, candidate_stream), sources)
        stream = transformers.Select(stream, tuple(req_vars))
        stream = MultiProcessing(stream)
        return stream
Example #11
                    else:
                        content.append(
                            Path(train_data.extract(i), '%d<br>' % i))
            elif len(r) > 2:
                self.send_error(404, 'File not found')
                return None
        content.write(f)
        length = f.tell()
        f.seek(0)
        self.send_response(200)
        encoding = sys.getfilesystemencoding()
        self.send_header("Content-type", "text/html; charset=%s" % encoding)
        self.send_header("Content-Length", str(length))
        self.end_headers()
        return f


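# Load the train and test datasets once, then serve the visualizer over HTTP
# on the port given as the only command-line argument.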
if __name__ == '__main__':
    if len(sys.argv) != 2:
        print >> sys.stderr, 'Usage: %s port' % sys.argv[0]
        sys.exit(1)

    print >> sys.stderr, 'Loading dataset...',
    path = os.path.join(data.path, 'data.hdf5')
    train_data = TaxiDataset('train')
    test_data = TaxiDataset('test')
    print >> sys.stderr, 'done'

    httpd = SocketServer.TCPServer(('', int(sys.argv[1])),
                                   VisualizerHTTPRequestHandler)
    httpd.serve_forever()
Example #12
import csv
import os

from fuel.iterator import DataIterator
from fuel.schemes import SequentialExampleScheme
from fuel.streams import DataStream

from data.hdf5 import TaxiDataset
import data

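# Write the ground-truth destination of every trip in the 'test' split of
# tvt.hdf5 to test_answer.csv (submission format: TRIP_ID, LATITUDE, LONGITUDE).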
dest_outfile = open(os.path.join(data.path, 'test_answer.csv'), 'w')
dest_outcsv = csv.writer(dest_outfile)
dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])

dataset = TaxiDataset('test',
                      'tvt.hdf5',
                      sources=('trip_id', 'longitude', 'latitude',
                               'destination_longitude',
                               'destination_latitude'))
it = DataIterator(DataStream(dataset),
                  iter(xrange(dataset.num_examples)),
                  as_dict=True)

for v in it:
    # print v
    dest_outcsv.writerow(
        [v['trip_id'], v['destination_latitude'], v['destination_longitude']])

dest_outfile.close()
Example #13
def make_tvt(test_cuts_name, valid_cuts_name, outpath):
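    # Build a train/valid/test HDF5 file: trips are sorted by timestamp, test
    # and valid prefixes are carved out at the given cut times, destination,
    # travel time, path length and arrival-cluster fields are computed, and a
    # split attribute describing the three contiguous splits is attached.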
    trainset = TaxiDataset("train")
    traindata = trainset.get_data(None, slice(0, trainset.num_examples))
    idsort = traindata[trainset.sources.index("timestamp")].argsort()

    traindata = dict(zip(trainset.sources, (t[idsort] for t in traindata)))

    print >> sys.stderr, "test cut begin"
    test_cuts = importlib.import_module(".%s" % test_cuts_name, "data.cuts").cuts
    test = cut_me_baby(traindata, test_cuts)

    print >> sys.stderr, "valid cut begin"
    valid_cuts = importlib.import_module(".%s" % valid_cuts_name, "data.cuts").cuts
    valid = cut_me_baby(traindata, valid_cuts, test)

    test_size = len(test)
    valid_size = len(valid)
    train_size = data.train_size - test_size - valid_size

    print " set   | size    | ratio"
    print " ----- | ------- | -----"
    print " train | {:>7d} | {:>5.3f}".format(train_size, float(train_size) / data.train_size)
    print " valid | {:>7d} | {:>5.3f}".format(valid_size, float(valid_size) / data.train_size)
    print " test  | {:>7d} | {:>5.3f}".format(test_size, float(test_size) / data.train_size)

    with open(os.path.join(data.path, "arrival-clusters.pkl"), "r") as f:
        clusters = cPickle.load(f)

    print >> sys.stderr, "compiling cluster assignment function"
    latitude = theano.tensor.scalar("latitude")
    longitude = theano.tensor.scalar("longitude")
    coords = theano.tensor.stack(latitude, longitude).dimshuffle("x", 0)
    parent = theano.tensor.argmin(hdist(clusters, coords))
    cluster = theano.function([latitude, longitude], parent)

    train_clients = set()

    print >> sys.stderr, "preparing hdf5 data"
    hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()}

    train_i = 0
    valid_i = train_size
    test_i = train_size + valid_size

    print >> sys.stderr, "write: begin"
    for idtraj in xrange(data.train_size):
        if idtraj % 10000 == 0 and idtraj != 0:
            print >> sys.stderr, "write: {:d} done".format(idtraj)
        in_test = idtraj in test
        in_valid = not in_test and idtraj in valid
        in_train = not in_test and not in_valid

        if idtraj in test:
            i = test_i
            test_i += 1
        elif idtraj in valid:
            i = valid_i
            valid_i += 1
        else:
            train_clients.add(traindata["origin_call"][idtraj])
            i = train_i
            train_i += 1

        trajlen = len(traindata["latitude"][idtraj])
        if trajlen == 0:
            hdata["destination_latitude"][i] = data.train_gps_mean[0]
            hdata["destination_longitude"][i] = data.train_gps_mean[1]
        else:
            hdata["destination_latitude"][i] = traindata["latitude"][idtraj][-1]
            hdata["destination_longitude"][i] = traindata["longitude"][idtraj][-1]
        hdata["travel_time"][i] = trajlen

        for field in native_fields:
            val = traindata[field][idtraj]
            if field in ["latitude", "longitude"]:
                if in_test:
                    val = val[: test[idtraj]]
                elif in_valid:
                    val = val[: valid[idtraj]]
            hdata[field][i] = val

        plen = len(hdata["latitude"][i])
        hdata["path_len"][i] = plen
        hdata["cluster"][i] = -1 if plen == 0 else cluster(hdata["latitude"][i][0], hdata["longitude"][i][0])

    print >> sys.stderr, "write: end"

    print >> sys.stderr, "removing useless origin_call"
    for i in xrange(train_size, data.train_size):
        if hdata["origin_call"][i] not in train_clients:
            hdata["origin_call"][i] = 0

    print >> sys.stderr, "preparing split array"

    split_array = numpy.empty(
        len(all_fields) * 3,
        dtype=numpy.dtype(
            [
                ("split", "a", 64),
                ("source", "a", 21),
                ("start", numpy.int64, 1),
                ("stop", numpy.int64, 1),
                ("indices", h5py.special_dtype(ref=h5py.Reference)),
                ("available", numpy.bool, 1),
                ("comment", "a", 1),
            ]
        ),
    )

    flen = len(all_fields)
    for i, field in enumerate(all_fields):
        split_array[i]["split"] = "train".encode("utf8")
        split_array[i + flen]["split"] = "valid".encode("utf8")
        split_array[i + 2 * flen]["split"] = "test".encode("utf8")
        split_array[i]["start"] = 0
        split_array[i]["stop"] = train_size
        split_array[i + flen]["start"] = train_size
        split_array[i + flen]["stop"] = train_size + valid_size
        split_array[i + 2 * flen]["start"] = train_size + valid_size
        split_array[i + 2 * flen]["stop"] = train_size + valid_size + test_size

        for d in [0, flen, 2 * flen]:
            split_array[i + d]["source"] = field.encode("utf8")

    split_array[:]["indices"] = None
    split_array[:]["available"] = True
    split_array[:]["comment"] = ".".encode("utf8")

    print >> sys.stderr, "writing hdf5 file"
    file = h5py.File(outpath, "w")
    for k in all_fields.keys():
        file.create_dataset(k, data=hdata[k], maxshape=(data.train_size,))

    file.attrs["split"] = split_array

    file.flush()
    file.close()
Example #14
    def test_dataset(self):
        return TaxiDataset('test', data.traintest_ds)
Example #15
#!/usr/bin/env python

import os

import data
from data.hdf5 import TaxiDataset
from visualizer import Path


poi = {
    'longest': 1492417
}

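# Extract each trip of interest from the training set and save it with the
# visualizer's Path helper under the 'Train POI' directory.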
if __name__ == '__main__':
    prefix = os.path.join(data.path, 'visualizer', 'Train POI')
    if not os.path.isdir(prefix):
        os.mkdir(prefix)

    d = TaxiDataset('train')
    for (k, v) in poi.items():
        Path(d.extract(v)).save(os.path.join('Train POI', k))
Example #16
    def valid_trips_ids(self):
        valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
        return valid.get_data(None, slice(0, valid.num_examples))[0]
Example #17
#!/usr/bin/env python

import os
import data
from data.hdf5 import TaxiDataset
from visualizer import Path

poi = {'longest': 1492417}

if __name__ == '__main__':
    prefix = os.path.join(data.path, 'visualizer', 'Train POI')
    if not os.path.isdir(prefix):
        os.mkdir(prefix)

    d = TaxiDataset('train')
    for (k, v) in poi.items():
        Path(d.extract(v)).save(os.path.join('Train POI', k))
Example #18
    def valid_dataset(self):
        return TaxiDataset(data.valid_set, data.valid_ds)
Example #19
    def valid_trips_ids(self):
        valid = TaxiDataset(data.valid_set,
                            data.valid_ds,
                            sources=('trip_id', ))
        return valid.get_data(None, slice(0, valid.num_examples))[0]
Example #20
def make_tvt(test_cuts_name, valid_cuts_name, outpath):
    trainset = TaxiDataset('train')
    traindata = trainset.get_data(None, slice(0, trainset.num_examples))
    idsort = traindata[trainset.sources.index('timestamp')].argsort()

    traindata = dict(zip(trainset.sources, (t[idsort] for t in traindata)))

    print >> sys.stderr, 'test cut begin'
    test_cuts = importlib.import_module('.%s' % test_cuts_name,
                                        'data.cuts').cuts
    test = cut_me_baby(traindata, test_cuts)

    print >> sys.stderr, 'valid cut begin'
    valid_cuts = importlib.import_module('.%s' % valid_cuts_name,
                                         'data.cuts').cuts
    valid = cut_me_baby(traindata, valid_cuts, test)

    test_size = len(test)
    valid_size = len(valid)
    train_size = data.train_size - test_size - valid_size

    print ' set   | size    | ratio'
    print ' ----- | ------- | -----'
    print ' train | {:>7d} | {:>5.3f}'.format(
        train_size,
        float(train_size) / data.train_size)
    print ' valid | {:>7d} | {:>5.3f}'.format(
        valid_size,
        float(valid_size) / data.train_size)
    print ' test  | {:>7d} | {:>5.3f}'.format(
        test_size,
        float(test_size) / data.train_size)

    with open(os.path.join(data.path, 'arrival-clusters.pkl'), 'r') as f:
        clusters = cPickle.load(f)

    print >> sys.stderr, 'compiling cluster assignment function'
    latitude = theano.tensor.scalar('latitude')
    longitude = theano.tensor.scalar('longitude')
    coords = theano.tensor.stack(latitude, longitude).dimshuffle('x', 0)
    parent = theano.tensor.argmin(hdist(clusters, coords))
    cluster = theano.function([latitude, longitude], parent)

    train_clients = set()

    print >> sys.stderr, 'preparing hdf5 data'
    hdata = {
        k: numpy.empty(shape=(data.train_size, ), dtype=v)
        for k, v in all_fields.iteritems()
    }

    train_i = 0
    valid_i = train_size
    test_i = train_size + valid_size

    print >> sys.stderr, 'write: begin'
    for idtraj in xrange(data.train_size):
        if idtraj % 10000 == 0 and idtraj != 0:
            print >> sys.stderr, 'write: {:d} done'.format(idtraj)
        in_test = idtraj in test
        in_valid = not in_test and idtraj in valid
        in_train = not in_test and not in_valid

        if idtraj in test:
            i = test_i
            test_i += 1
        elif idtraj in valid:
            i = valid_i
            valid_i += 1
        else:
            train_clients.add(traindata['origin_call'][idtraj])
            i = train_i
            train_i += 1

        trajlen = len(traindata['latitude'][idtraj])
        if trajlen == 0:
            hdata['destination_latitude'][i] = data.train_gps_mean[0]
            hdata['destination_longitude'][i] = data.train_gps_mean[1]
        else:
            hdata['destination_latitude'][i] = traindata['latitude'][idtraj][
                -1]
            hdata['destination_longitude'][i] = traindata['longitude'][idtraj][
                -1]
        hdata['travel_time'][i] = trajlen

        for field in native_fields:
            val = traindata[field][idtraj]
            if field in ['latitude', 'longitude']:
                if in_test:
                    val = val[:test[idtraj]]
                elif in_valid:
                    val = val[:valid[idtraj]]
            hdata[field][i] = val

        plen = len(hdata['latitude'][i])
        hdata['path_len'][i] = plen
        hdata['cluster'][i] = -1 if plen == 0 else cluster(
            hdata['latitude'][i][0], hdata['longitude'][i][0])

    print >> sys.stderr, 'write: end'

    print >> sys.stderr, 'removing useless origin_call'
    for i in xrange(train_size, data.train_size):
        if hdata['origin_call'][i] not in train_clients:
            hdata['origin_call'][i] = 0

    print >> sys.stderr, 'preparing split array'

    split_array = numpy.empty(len(all_fields) * 3,
                              dtype=numpy.dtype([
                                  ('split', 'a', 64), ('source', 'a', 21),
                                  ('start', numpy.int64, 1),
                                  ('stop', numpy.int64, 1),
                                  ('indices',
                                   h5py.special_dtype(ref=h5py.Reference)),
                                  ('available', numpy.bool, 1),
                                  ('comment', 'a', 1)
                              ]))

    flen = len(all_fields)
    for i, field in enumerate(all_fields):
        split_array[i]['split'] = 'train'.encode('utf8')
        split_array[i + flen]['split'] = 'valid'.encode('utf8')
        split_array[i + 2 * flen]['split'] = 'test'.encode('utf8')
        split_array[i]['start'] = 0
        split_array[i]['stop'] = train_size
        split_array[i + flen]['start'] = train_size
        split_array[i + flen]['stop'] = train_size + valid_size
        split_array[i + 2 * flen]['start'] = train_size + valid_size
        split_array[i + 2 * flen]['stop'] = train_size + valid_size + test_size

        for d in [0, flen, 2 * flen]:
            split_array[i + d]['source'] = field.encode('utf8')

    split_array[:]['indices'] = None
    split_array[:]['available'] = True
    split_array[:]['comment'] = '.'.encode('utf8')

    print >> sys.stderr, 'writing hdf5 file'
    file = h5py.File(outpath, 'w')
    for k in all_fields.keys():
        file.create_dataset(k, data=hdata[k], maxshape=(data.train_size, ))

    file.attrs['split'] = split_array

    file.flush()
    file.close()