コード例 #1
0
ファイル: memory_network.py プロジェクト: JimStearns206/taxi
    def train(self, req_vars):
        """Build the training stream: batched trip prefixes merged with candidates.

        Prefixes are produced by ``TaxiTimeCutScheme(config.num_cuts)`` over
        the training set and then split by ``TaxiGenerateSplits``.
        Candidates are shuffled training examples.  Both halves go through
        ``TaxiExcludeTrips`` with the trip ids of the validation set
        (presumably so validation trips never leak into training).

        Args:
            req_vars: names of the sources to keep in the merged stream.
                Candidate sources are exposed as ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        # Only the ``trip_id`` source of the validation set is loaded.
        valid_ids_set = TaxiDataset(self.config.valid_set, "valid.hdf5", sources=("trip_id",))
        all_rows = slice(0, valid_ids_set.num_examples)
        excluded_ids = valid_ids_set.get_data(None, all_rows)[0]

        train_set = TaxiDataset("train")

        # Query side: time-cut, split and feature-annotated prefixes.
        prefixes = DataStream(train_set, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
        prefixes = transformers.TaxiExcludeTrips(prefixes, excluded_ids)
        prefixes = transformers.TaxiGenerateSplits(prefixes, max_splits=self.config.max_splits)
        prefixes = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(prefixes), self.config.n_begin_end_pts)
        prefixes = Batch(prefixes, iteration_scheme=ConstantScheme(self.config.batch_size))

        # Memory side: shuffled candidate trips, filtered by TaxiExcludeEmptyTrips.
        candidates = DataStream(train_set, iteration_scheme=ShuffledExampleScheme(train_set.num_examples))
        candidates = transformers.TaxiExcludeEmptyTrips(
            transformers.TaxiExcludeTrips(candidates, excluded_ids))
        candidates = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(candidates), self.config.n_begin_end_pts)
        candidates = Batch(candidates, iteration_scheme=ConstantScheme(self.config.train_candidate_size))

        # Candidate source names get a ``candidate_`` prefix in the merged stream.
        merged_names = prefixes.sources + tuple("candidate_%s" % name for name in candidates.sources)
        merged = Merge((prefixes, candidates), merged_names)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
コード例 #2
0
    def test(self, req_vars):
        """Build the test stream: sequential test prefixes merged with candidates.

        The test dataset is read in order (``SequentialExampleScheme``).
        When ``data.tvt`` is false, ``taxi_remove_test_only_clients`` is also
        applied to the prefixes.  Candidates come from
        ``self.candidate_stream`` with ``config.test_candidate_size``
        examples per batch.

        Args:
            req_vars: names of the sources to keep; candidate sources appear
                as ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        ordered = SequentialExampleScheme(self.test_dataset.num_examples)
        prefixes = transformers.taxi_add_datetime(
            DataStream(self.test_dataset, iteration_scheme=ordered))
        prefixes = transformers.taxi_add_first_last_len(prefixes, self.config.n_begin_end_pts)
        if not data.tvt:
            prefixes = transformers.taxi_remove_test_only_clients(prefixes)
        prefixes = Batch(prefixes, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = self.candidate_stream(self.config.test_candidate_size)

        merged_names = prefixes.sources + tuple('candidate_%s' % src for src in candidates.sources)
        merged = Merge((prefixes, candidates), merged_names)
        selected = transformers.Select(merged, tuple(req_vars))
        return MultiProcessing(selected)
コード例 #3
0
ファイル: mlp.py プロジェクト: JimStearns206/taxi
    def train(self, req_vars):
        """Build the batched training stream for the MLP model.

        Training trips are cut with ``TaxiTimeCutScheme`` when the optional
        ``use_cuts_for_training`` config flag is truthy.  Otherwise they are
        shuffled whole with ``ShuffledExampleScheme``.  The validation trip
        ids are passed to ``TaxiExcludeTrips``, and the result is split by
        ``TaxiGenerateSplits``.  When the config defines
        ``shuffle_batch_size``, examples are buffered in groups of that size,
        reordered with ``SortMapping`` keyed by ``UniformGenerator``
        (presumably a local random shuffle), and unpacked again.

        Args:
            req_vars: names of the sources kept by ``Select`` before batching.

        Returns:
            A ``MultiProcessing``-wrapped stream of ``config.batch_size``
            batches.
        """
        valid = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id',))
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        # Keep the dataset under its own name so ``stream`` always denotes a stream.
        dataset = TaxiDataset('train')

        # ``use_cuts_for_training`` is optional: a missing attribute behaves
        # like False, i.e. whole trips are shuffled.
        if getattr(self.config, 'use_cuts_for_training', False):
            stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(dataset, iteration_scheme=ShuffledExampleScheme(dataset.num_examples))

        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

        if hasattr(self.config, 'shuffle_batch_size'):
            # Buffer, reorder by a generated key, then flatten back to examples.
            stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        return MultiProcessing(stream)
コード例 #4
0
    def train(self, req_vars):
        """Build the training stream: shuffled, split prefixes merged with candidates.

        Prefixes are shuffled examples of ``self.train_dataset``.  When
        ``data.tvt`` is false they are first passed to ``TaxiExcludeTrips``
        with ``self.valid_trips_ids``.  They then go through
        ``TaxiExcludeEmptyTrips`` and ``TaxiGenerateSplits``, get the
        datetime and first/last/length features, and are batched by
        ``config.batch_size``.  Candidates come from
        ``self.candidate_stream(config.train_candidate_size)``.

        Args:
            req_vars: source names to keep; candidate sources appear as
                ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        scheme = ShuffledExampleScheme(self.train_dataset.num_examples)
        prefixes = DataStream(self.train_dataset, iteration_scheme=scheme)

        if not data.tvt:
            prefixes = transformers.TaxiExcludeTrips(prefixes, self.valid_trips_ids)
        prefixes = transformers.TaxiGenerateSplits(
            transformers.TaxiExcludeEmptyTrips(prefixes),
            max_splits=self.config.max_splits)
        prefixes = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(prefixes), self.config.n_begin_end_pts)
        prefixes = Batch(prefixes, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = self.candidate_stream(self.config.train_candidate_size)

        merged_names = prefixes.sources + tuple('candidate_%s' % src for src in candidates.sources)
        merged = Merge((prefixes, candidates), merged_names)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
コード例 #5
0
ファイル: mlp.py プロジェクト: JimStearns206/taxi
    def valid(self, req_vars):
        """Build the MLP validation stream.

        Reads ``config.valid_set`` from ``valid.hdf5`` through ``TaxiStream``,
        applies ``taxi_add_datetime`` and ``taxi_add_first_last_len``, keeps
        only ``req_vars``, and batches 1000 examples at a time.
        """
        raw = TaxiStream(self.config.valid_set, 'valid.hdf5')
        featured = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(raw), self.config.n_begin_end_pts)
        selected = transformers.Select(featured, tuple(req_vars))
        # Fixed validation batch size, independent of config.batch_size.
        return Batch(selected, iteration_scheme=ConstantScheme(1000))
コード例 #6
0
ファイル: mlp.py プロジェクト: zhaojuanjuan511/taxi
    def valid(self, req_vars):
        """Return the MLP validation stream, batched 1000 examples at a time.

        The source is ``TaxiStream(data.valid_set, data.valid_ds)``.  It goes
        through the datetime and first/last/length transformers, and then
        ``Select`` restricts it to ``req_vars``.
        """
        n_pts = self.config.n_begin_end_pts if False else None  # placeholder removed below
        del n_pts
        source = TaxiStream(data.valid_set, data.valid_ds)
        with_date = transformers.taxi_add_datetime(source)
        with_ends = transformers.taxi_add_first_last_len(with_date, self.config.n_begin_end_pts)
        return Batch(transformers.Select(with_ends, tuple(req_vars)),
                     iteration_scheme=ConstantScheme(1000))
コード例 #7
0
ファイル: mlp.py プロジェクト: zhaojuanjuan511/taxi
    def train(self, req_vars):
        """Build the batched training stream for the MLP model.

        Training trips come from ``TaxiDataset('train', data.traintest_ds)``.
        They are cut with ``TaxiTimeCutScheme`` when the optional
        ``use_cuts_for_training`` config flag is truthy, and shuffled whole
        otherwise.  When ``data.tvt`` is false, the validation trip ids
        (from ``data.valid_set`` / ``data.valid_ds``) are passed to
        ``TaxiExcludeTrips``.  The stream is then split by
        ``TaxiGenerateSplits``.  When the config defines
        ``shuffle_batch_size``, examples are buffered in groups of that size,
        reordered with ``SortMapping`` keyed by ``UniformGenerator``
        (presumably a local random shuffle), and unpacked.

        Args:
            req_vars: names of the sources kept by ``Select`` before batching.

        Returns:
            A ``MultiProcessing``-wrapped stream of ``config.batch_size``
            batches.
        """
        # Keep the dataset under its own name so ``stream`` always denotes a stream.
        dataset = TaxiDataset('train', data.traintest_ds)

        # ``use_cuts_for_training`` is optional: a missing attribute behaves
        # like False, i.e. whole trips are shuffled.
        if getattr(self.config, 'use_cuts_for_training', False):
            stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(dataset, iteration_scheme=ShuffledExampleScheme(dataset.num_examples))

        if not data.tvt:
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

        if hasattr(self.config, 'shuffle_batch_size'):
            # Buffer, reorder by a generated key, then flatten back to examples.
            stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        return MultiProcessing(stream)
コード例 #8
0
ファイル: mlp.py プロジェクト: JimStearns206/taxi
    def test(self, req_vars):
        """Build the MLP test stream, one example per batch.

        ``req_vars`` is accepted but not used: no ``Select`` is applied, so
        every source produced by the transformers stays in the stream.
        NOTE(review): presumably this keeps sources such as trip ids
        available to the prediction code; confirm before adding a Select.
        """
        stream = transformers.taxi_add_datetime(TaxiStream('test'))
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        filtered = transformers.taxi_remove_test_only_clients(stream)
        return Batch(filtered, iteration_scheme=ConstantScheme(1))
コード例 #9
0
ファイル: mlp.py プロジェクト: zhaojuanjuan511/taxi
    def test(self, req_vars):
        """Build the MLP test stream, batched by ``config.batch_size``.

        Reads the ``test`` split of ``data.traintest_ds``, applies the
        datetime and first/last/length transformers, then
        ``taxi_remove_test_only_clients``.  ``req_vars`` is accepted but not
        used: no ``Select`` is applied here.
        """
        source = TaxiStream('test', data.traintest_ds)
        annotated = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(source), self.config.n_begin_end_pts)
        cleaned = transformers.taxi_remove_test_only_clients(annotated)
        return Batch(cleaned, iteration_scheme=ConstantScheme(self.config.batch_size))
コード例 #10
0
ファイル: memory_network.py プロジェクト: JimStearns206/taxi
    def train(self, req_vars):
        """Return the memory-network training stream.

        Query prefixes are time cuts (``config.num_cuts``) of the training
        set, split by ``TaxiGenerateSplits`` and batched by
        ``config.batch_size``.  Memory candidates are shuffled training
        examples, batched by ``config.train_candidate_size``.  Both halves
        go through ``TaxiExcludeTrips`` with the validation trip ids.

        Args:
            req_vars: names of the sources to keep after merging; candidate
                sources are named ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        def annotate(stream):
            # Feature transformers shared by the query and memory halves.
            stream = transformers.taxi_add_datetime(stream)
            return transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)

        id_only = TaxiDataset(self.config.valid_set, 'valid.hdf5', sources=('trip_id', ))
        held_out_ids = id_only.get_data(None, slice(0, id_only.num_examples))[0]

        source_data = TaxiDataset('train')

        query = DataStream(source_data, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
        query = transformers.TaxiExcludeTrips(query, held_out_ids)
        query = transformers.TaxiGenerateSplits(query, max_splits=self.config.max_splits)
        query = Batch(annotate(query), iteration_scheme=ConstantScheme(self.config.batch_size))

        memory = DataStream(source_data, iteration_scheme=ShuffledExampleScheme(source_data.num_examples))
        memory = transformers.TaxiExcludeTrips(memory, held_out_ids)
        memory = transformers.TaxiExcludeEmptyTrips(memory)
        memory = Batch(annotate(memory), iteration_scheme=ConstantScheme(self.config.train_candidate_size))

        combined = Merge((query, memory),
                         query.sources + tuple('candidate_%s' % k for k in memory.sources))
        combined = transformers.Select(combined, tuple(req_vars))
        return MultiProcessing(combined)
コード例 #11
0
ファイル: memory_network.py プロジェクト: DragonCircle/taxi
    def candidate_stream(self, n_candidates):
        """Return a batched stream of shuffled training trips used as candidates.

        When ``data.tvt`` is false, two extra steps apply: the stream is
        passed to ``TaxiExcludeTrips`` with ``self.valid_trips_ids`` before
        filtering, and to ``add_destination`` after the feature
        transformers.

        Args:
            n_candidates: number of candidate examples per batch.

        Returns:
            A ``Batch`` stream with ``n_candidates`` examples per batch.
        """
        tvt_disabled = not data.tvt
        shuffled = ShuffledExampleScheme(self.train_dataset.num_examples)
        stream = DataStream(self.train_dataset, iteration_scheme=shuffled)
        if tvt_disabled:
            stream = transformers.TaxiExcludeTrips(stream, self.valid_trips_ids)
        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(stream), self.config.n_begin_end_pts)
        if tvt_disabled:
            stream = transformers.add_destination(stream)
        return Batch(stream, iteration_scheme=ConstantScheme(n_candidates))
コード例 #12
0
ファイル: memory_network.py プロジェクト: JimStearns206/taxi
    def valid(self, req_vars):
        """Build the memory-network validation stream.

        Validation trips are read sequentially as prefixes and batched by
        ``config.batch_size``.  Candidates are shuffled training trips,
        passed to ``TaxiExcludeTrips`` with the validation trip ids and to
        ``TaxiExcludeEmptyTrips``, and batched by
        ``config.valid_candidate_size``.

        Args:
            req_vars: source names to keep; candidate sources appear as
                ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        valid_set = TaxiDataset(self.config.valid_set, "valid.hdf5")
        train_set = TaxiDataset("train")
        columns = valid_set.get_data(None, slice(0, valid_set.num_examples))
        held_out_ids = columns[valid_set.sources.index("trip_id")]

        n_pts = self.config.n_begin_end_pts

        prefixes = DataStream(valid_set, iteration_scheme=SequentialExampleScheme(valid_set.num_examples))
        prefixes = transformers.taxi_add_first_last_len(transformers.taxi_add_datetime(prefixes), n_pts)
        prefixes = Batch(prefixes, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = DataStream(train_set, iteration_scheme=ShuffledExampleScheme(train_set.num_examples))
        candidates = transformers.TaxiExcludeEmptyTrips(
            transformers.TaxiExcludeTrips(candidates, held_out_ids))
        candidates = transformers.taxi_add_first_last_len(transformers.taxi_add_datetime(candidates), n_pts)
        candidates = Batch(candidates, iteration_scheme=ConstantScheme(self.config.valid_candidate_size))

        renamed = tuple("candidate_%s" % name for name in candidates.sources)
        merged = Merge((prefixes, candidates), prefixes.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
コード例 #13
0
ファイル: memory_network.py プロジェクト: JimStearns206/taxi
    def valid(self, req_vars):
        """Return the memory-network validation stream.

        Prefixes are the validation trips in order.  Candidates are shuffled
        training trips with the validation trip ids passed to
        ``TaxiExcludeTrips`` and then to ``TaxiExcludeEmptyTrips``.

        Args:
            req_vars: names of the sources to keep after merging.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        def with_features(stream):
            # Feature transformers shared by prefixes and candidates.
            stream = transformers.taxi_add_datetime(stream)
            return transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)

        valid_dataset = TaxiDataset(self.config.valid_set, 'valid.hdf5')
        train_dataset = TaxiDataset('train')
        everything = valid_dataset.get_data(None, slice(0, valid_dataset.num_examples))
        valid_trips_ids = everything[valid_dataset.sources.index('trip_id')]

        prefix_scheme = SequentialExampleScheme(valid_dataset.num_examples)
        prefix_stream = with_features(DataStream(valid_dataset, iteration_scheme=prefix_scheme))
        prefix_stream = Batch(prefix_stream,
                              iteration_scheme=ConstantScheme(self.config.batch_size))

        candidate_scheme = ShuffledExampleScheme(train_dataset.num_examples)
        candidate_stream = DataStream(train_dataset, iteration_scheme=candidate_scheme)
        candidate_stream = transformers.TaxiExcludeTrips(candidate_stream, valid_trips_ids)
        candidate_stream = with_features(transformers.TaxiExcludeEmptyTrips(candidate_stream))
        candidate_stream = Batch(candidate_stream,
                                 iteration_scheme=ConstantScheme(self.config.valid_candidate_size))

        names = prefix_stream.sources + tuple('candidate_%s' % k for k in candidate_stream.sources)
        selected = transformers.Select(Merge((prefix_stream, candidate_stream), names),
                                       tuple(req_vars))
        return MultiProcessing(selected)
コード例 #14
0
    def candidate_stream(self, n_candidates):
        """Build the candidate stream: shuffled training trips in batches.

        Args:
            n_candidates: number of candidate examples per batch.

        Returns:
            A ``Batch`` stream over ``self.train_dataset``.  When ``data.tvt``
            is false, ``TaxiExcludeTrips`` (with ``self.valid_trips_ids``) and
            ``add_destination`` are also applied.
        """
        stream = DataStream(
            self.train_dataset,
            iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
        stream = (stream if data.tvt
                  else transformers.TaxiExcludeTrips(stream, self.valid_trips_ids))
        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = stream if data.tvt else transformers.add_destination(stream)
        return Batch(stream, iteration_scheme=ConstantScheme(n_candidates))
コード例 #15
0
ファイル: memory_network.py プロジェクト: DragonCircle/taxi
    def valid(self, req_vars):
        """Build the validation stream: sequential validation prefixes plus candidates.

        Prefixes are ``self.valid_dataset`` read in order and batched by
        ``config.batch_size``.  Candidates come from
        ``self.candidate_stream(config.valid_candidate_size)``.

        Args:
            req_vars: source names to keep; candidate sources are exposed as
                ``candidate_<source>``.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        in_order = SequentialExampleScheme(self.valid_dataset.num_examples)
        prefixes = DataStream(self.valid_dataset, iteration_scheme=in_order)
        prefixes = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(prefixes), self.config.n_begin_end_pts)
        prefixes = Batch(prefixes, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = self.candidate_stream(self.config.valid_candidate_size)

        renamed = tuple('candidate_%s' % name for name in candidates.sources)
        merged = Merge((prefixes, candidates), prefixes.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
コード例 #16
0
ファイル: memory_network.py プロジェクト: DragonCircle/taxi
    def train(self, req_vars):
        """Return the training stream of shuffled, split prefixes plus candidates.

        When ``data.tvt`` is false, prefixes are first passed to
        ``TaxiExcludeTrips`` with ``self.valid_trips_ids``.

        Args:
            req_vars: names of the sources to keep after merging.

        Returns:
            The merged, selected stream wrapped in ``MultiProcessing``.
        """
        query = DataStream(
            self.train_dataset,
            iteration_scheme=ShuffledExampleScheme(self.train_dataset.num_examples))
        query = (query if data.tvt
                 else transformers.TaxiExcludeTrips(query, self.valid_trips_ids))
        query = transformers.TaxiExcludeEmptyTrips(query)
        query = transformers.TaxiGenerateSplits(query, max_splits=self.config.max_splits)
        query = transformers.taxi_add_datetime(query)
        query = transformers.taxi_add_first_last_len(query, self.config.n_begin_end_pts)
        query = Batch(query, iteration_scheme=ConstantScheme(self.config.batch_size))

        memory = self.candidate_stream(self.config.train_candidate_size)

        combined = Merge((query, memory),
                         query.sources + tuple('candidate_%s' % k for k in memory.sources))
        combined = transformers.Select(combined, tuple(req_vars))
        return MultiProcessing(combined)