예제 #1
0
    def test(self, req_vars):
        """Return the test-time stream: prefix batches merged with candidates.

        Prefixes are read sequentially from ``self.test_dataset`` and get
        datetime and first/last-point features.  When ``data.tvt`` is false,
        test-only clients are removed.  Each prefix batch is merged with a
        batch from ``self.candidate_stream`` whose sources are renamed with
        a ``candidate_`` prefix.

        Args:
            req_vars: iterable of source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream over the merged, selected sources.
        """
        prefix = DataStream(
            self.test_dataset,
            iteration_scheme=SequentialExampleScheme(
                self.test_dataset.num_examples))
        prefix = transformers.taxi_add_datetime(prefix)
        prefix = transformers.taxi_add_first_last_len(
            prefix, self.config.n_begin_end_pts)
        if not data.tvt:
            prefix = transformers.taxi_remove_test_only_clients(prefix)
        prefix = Batch(
            prefix, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = self.candidate_stream(self.config.test_candidate_size)

        # Candidate sources get a 'candidate_' prefix so they are
        # distinguishable from the prefix sources in the merged stream.
        renamed = tuple('candidate_%s' % name for name in candidates.sources)
        merged = Merge((prefix, candidates), prefix.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
예제 #2
0
    def train(self, req_vars):
        """Return the batched training stream of trip splits.

        Trips whose ids appear in ``self.config.valid_set`` are excluded.
        Iteration uses ``TaxiTimeCutScheme`` when
        ``config.use_cuts_for_training`` exists and is truthy, and a
        shuffled scheme otherwise.

        Args:
            req_vars: iterable of source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream of ``config.batch_size`` batches.
        """
        valid_ids_source = TaxiDataset(self.config.valid_set,
                                       'valid.hdf5',
                                       sources=('trip_id', ))
        # 'trip_id' is the only requested source, hence index 0.
        valid_trips_ids = valid_ids_source.get_data(
            None, slice(0, valid_ids_source.num_examples))[0]

        dataset = TaxiDataset('train')

        use_cuts = (hasattr(self.config, 'use_cuts_for_training')
                    and self.config.use_cuts_for_training)
        if use_cuts:
            scheme = TaxiTimeCutScheme()
        else:
            scheme = ShuffledExampleScheme(dataset.num_examples)
        stream = DataStream(dataset, iteration_scheme=scheme)

        stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
        stream = transformers.TaxiGenerateSplits(
            stream, max_splits=self.config.max_splits)
        stream = transformers.taxi_add_datetime(stream)
        # NOTE: taxi_add_first_last_len is not applied in this pipeline.
        stream = transformers.Select(stream, tuple(req_vars))

        batched = Batch(
            stream, iteration_scheme=ConstantScheme(self.config.batch_size))
        return MultiProcessing(batched)
예제 #3
0
    def train(self, req_vars):
        """Return the training stream: shuffled prefix splits plus candidates.

        Prefixes come from ``self.train_dataset`` in shuffled order.  Trips
        listed in ``self.valid_trips_ids`` are dropped when ``data.tvt`` is
        false, and empty trips are always dropped.  Each prefix batch is
        merged with a batch from ``self.candidate_stream`` whose sources are
        renamed with a ``candidate_`` prefix.

        Args:
            req_vars: iterable of source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream over the merged, selected sources.
        """
        prefix = DataStream(
            self.train_dataset,
            iteration_scheme=ShuffledExampleScheme(
                self.train_dataset.num_examples))

        if not data.tvt:
            prefix = transformers.TaxiExcludeTrips(prefix,
                                                   self.valid_trips_ids)
        prefix = transformers.TaxiExcludeEmptyTrips(prefix)
        prefix = transformers.TaxiGenerateSplits(
            prefix, max_splits=self.config.max_splits)
        prefix = transformers.taxi_add_first_last_len(
            transformers.taxi_add_datetime(prefix),
            self.config.n_begin_end_pts)
        prefix = Batch(
            prefix, iteration_scheme=ConstantScheme(self.config.batch_size))

        candidates = self.candidate_stream(self.config.train_candidate_size)

        # Candidate sources get a 'candidate_' prefix so they are
        # distinguishable from the prefix sources in the merged stream.
        renamed = tuple('candidate_%s' % name for name in candidates.sources)
        merged = Merge((prefix, candidates), prefix.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
예제 #4
0
    def train(self, req_vars):
        """Return the padded, length-balanced training stream.

        Pipeline:
          * iterate the 'train' TaxiDataset, time-cut if
            ``config.use_cuts_for_training`` is set and truthy, otherwise
            shuffled;
          * when ``data.tvt`` is false, drop trips whose ids are in the
            validation set;
          * apply ``TaxiGenerateSplits`` if ``config.max_splits`` exists,
            otherwise ``add_destination`` when ``data.tvt`` is false;
          * if ``config.train_max_len`` exists, keep only examples whose
            latitude sequence is at most that long;
          * drop empty trips and add datetime features;
          * balance-batch on latitude, pad latitude/longitude, and select
            ``req_vars``.

        Args:
            req_vars: source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream of padded batches.
        """
        config = self.config
        dataset = TaxiDataset('train', data.traintest_ds)

        use_cuts = (hasattr(config, 'use_cuts_for_training')
                    and config.use_cuts_for_training)
        if use_cuts:
            scheme = TaxiTimeCutScheme()
        else:
            scheme = ShuffledExampleScheme(dataset.num_examples)
        stream = DataStream(dataset, iteration_scheme=scheme)

        if not data.tvt:
            valid = TaxiDataset(data.valid_set,
                                data.valid_ds,
                                sources=('trip_id', ))
            # 'trip_id' is the only requested source, hence index 0.
            valid_trips_ids = valid.get_data(
                None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(
                stream, max_splits=config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(config, 'train_max_len'):
            lat_pos = stream.sources.index('latitude')

            def short_enough(example):
                # True when the latitude sequence fits train_max_len.
                return len(example[lat_pos]) <= self.config.train_max_len

            stream = Filter(stream, short_enough)

        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)

        # Sources ending in '_mask' are left out here and only selected
        # after Padding, which is given latitude/longitude as mask_sources.
        unmasked = tuple(name for name in req_vars
                         if not name.endswith('_mask'))
        stream = transformers.Select(stream, unmasked)

        stream = transformers.balanced_batch(
            stream,
            key='latitude',
            batch_size=config.batch_size,
            batch_sort_size=config.batch_sort_size)
        padded = Padding(stream, mask_sources=['latitude', 'longitude'])
        return MultiProcessing(transformers.Select(padded, req_vars))
예제 #5
0
    def valid(self, req_vars):
        """Return the padded, length-balanced validation stream.

        Reads ``data.valid_set`` / ``data.valid_ds`` through ``TaxiStream``,
        adds datetime features, balance-batches on latitude, pads
        latitude/longitude and keeps only ``req_vars``.

        Args:
            req_vars: source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream of padded batches.
        """
        stream = transformers.taxi_add_datetime(
            TaxiStream(data.valid_set, data.valid_ds))

        # Sources ending in '_mask' are left out here and only selected
        # after Padding, which is given latitude/longitude as mask_sources.
        unmasked = tuple(name for name in req_vars
                         if not name.endswith('_mask'))
        stream = transformers.Select(stream, unmasked)

        stream = transformers.balanced_batch(
            stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)
        padded = Padding(stream, mask_sources=['latitude', 'longitude'])
        return MultiProcessing(transformers.Select(padded, req_vars))
예제 #6
0
    def train(self, req_vars):
        """Return the training stream: time-cut prefix splits plus candidates.

        Both pipelines read the 'train' TaxiDataset and drop trips whose ids
        appear in ``self.config.valid_set``.  Prefixes are iterated with
        ``TaxiTimeCutScheme(config.num_cuts)`` and split by
        ``TaxiGenerateSplits``.  Candidates are iterated in shuffled order
        with empty trips removed, batched by ``config.train_candidate_size``.
        Candidate sources are renamed with a ``candidate_`` prefix before
        merging.

        Args:
            req_vars: iterable of source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream over the merged, selected sources.
        """
        valid = TaxiDataset(self.config.valid_set,
                            'valid.hdf5',
                            sources=('trip_id', ))
        # 'trip_id' is the only requested source, hence index 0.
        valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]

        dataset = TaxiDataset('train')

        def featurize_and_batch(stream, size):
            # Shared tail of both pipelines: datetime and first/last-point
            # features, then constant-size batches of `size` examples.
            stream = transformers.taxi_add_datetime(stream)
            stream = transformers.taxi_add_first_last_len(
                stream, self.config.n_begin_end_pts)
            return Batch(stream, iteration_scheme=ConstantScheme(size))

        prefix = DataStream(
            dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
        prefix = transformers.TaxiExcludeTrips(prefix, valid_trips_ids)
        prefix = transformers.TaxiGenerateSplits(
            prefix, max_splits=self.config.max_splits)
        prefix = featurize_and_batch(prefix, self.config.batch_size)

        candidates = DataStream(
            dataset,
            iteration_scheme=ShuffledExampleScheme(dataset.num_examples))
        candidates = transformers.TaxiExcludeTrips(candidates, valid_trips_ids)
        candidates = transformers.TaxiExcludeEmptyTrips(candidates)
        candidates = featurize_and_batch(candidates,
                                         self.config.train_candidate_size)

        renamed = tuple('candidate_%s' % name for name in candidates.sources)
        merged = Merge((prefix, candidates), prefix.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
예제 #7
0
    def valid(self, req_vars):
        """Return the validation stream: validation prefixes plus candidates.

        Prefixes are read sequentially from the validation TaxiDataset
        (``self.config.valid_set`` / 'valid.hdf5').  Candidates are drawn in
        shuffled order from the 'train' TaxiDataset, excluding validation
        trip ids and empty trips, batched by ``config.valid_candidate_size``.
        Candidate sources are renamed with a ``candidate_`` prefix before
        merging.

        Args:
            req_vars: iterable of source names to keep in the output.

        Returns:
            A ``MultiProcessing`` stream over the merged, selected sources.
        """
        valid_dataset = TaxiDataset(self.config.valid_set, 'valid.hdf5')
        train_dataset = TaxiDataset('train')
        trip_id_pos = valid_dataset.sources.index('trip_id')
        all_valid = valid_dataset.get_data(
            None, slice(0, valid_dataset.num_examples))
        valid_trips_ids = all_valid[trip_id_pos]

        def featurize_and_batch(stream, size):
            # Shared tail of both pipelines: datetime and first/last-point
            # features, then constant-size batches of `size` examples.
            stream = transformers.taxi_add_datetime(stream)
            stream = transformers.taxi_add_first_last_len(
                stream, self.config.n_begin_end_pts)
            return Batch(stream, iteration_scheme=ConstantScheme(size))

        prefix = DataStream(
            valid_dataset,
            iteration_scheme=SequentialExampleScheme(
                valid_dataset.num_examples))
        prefix = featurize_and_batch(prefix, self.config.batch_size)

        candidates = DataStream(
            train_dataset,
            iteration_scheme=ShuffledExampleScheme(train_dataset.num_examples))
        candidates = transformers.TaxiExcludeTrips(candidates, valid_trips_ids)
        candidates = transformers.TaxiExcludeEmptyTrips(candidates)
        candidates = featurize_and_batch(candidates,
                                         self.config.valid_candidate_size)

        renamed = tuple('candidate_%s' % name for name in candidates.sources)
        merged = Merge((prefix, candidates), prefix.sources + renamed)
        return MultiProcessing(transformers.Select(merged, tuple(req_vars)))
예제 #8
0
 def test_value_error_on_request(self):
     """Calling get_data with the request [0, 1] must raise ValueError."""
     get_data = MultiProcessing(self.transformer).get_data
     assert_raises(ValueError, get_data, [0, 1])
예제 #9
0
 def test_multiprocessing(self):
     """MultiProcessing over self.transformer yields (1,), (2,), ... (100,)."""
     background = MultiProcessing(self.transformer)
     produced = list(background.get_epoch_iterator())
     expected = [(i, ) for i in range(1, 101)]
     assert_equal(produced, expected)
예제 #10
0
def test_multiprocessing():
    """Check that MultiProcessing yields every mapped example, in order.

    The whole epoch is compared against the expected list rather than
    zipped with it: ``zip`` stops at the shorter iterable, so a truncated
    (or even empty) background stream would otherwise pass unnoticed.
    """
    stream = IterableDataset(range(100)).get_example_stream()
    plus_one = Mapping(stream, lambda x: (x[0] + 1, ))
    background = MultiProcessing(plus_one)
    expected = [(b, ) for b in range(1, 101)]
    assert list(background.get_epoch_iterator()) == expected
예제 #11
0
def test_multiprocessing():
    """Check that MultiProcessing yields every mapped example, in order.

    The whole epoch is compared against the expected list rather than
    zipped with it: ``zip`` stops at the shorter iterable, so a truncated
    (or even empty) background stream would otherwise pass unnoticed.
    """
    stream = IterableDataset(range(100)).get_example_stream()
    plus_one = Mapping(stream, lambda x: (x[0] + 1,))
    background = MultiProcessing(plus_one)
    expected = [(b,) for b in range(1, 101)]
    assert list(background.get_epoch_iterator()) == expected
예제 #12
0
 def test_axis_labels_passed_on_by_default(self):
     """MultiProcessing exposes the wrapped transformer's axis_labels."""
     labels = {'features': ('batch', 'index')}
     self.transformer.axis_labels = labels
     background = MultiProcessing(self.transformer)
     assert_equal(background.axis_labels, self.transformer.axis_labels)
예제 #13
0
 def test_multiprocessing(self):
     """One epoch of MultiProcessing(self.transformer) is (1,) ... (100,)."""
     epoch = MultiProcessing(self.transformer).get_epoch_iterator()
     assert_equal(list(epoch), [(value, ) for value in range(1, 101)])
예제 #14
0
    # Echo the parsed command-line arguments.
    print(args)

    # Sync the data into args.tmpdir (Rsync is presumably a local-copy
    # helper -- confirm) and repoint args.data_path at the synced copy,
    # i.e. <tmpdir>/<basename of the original data_path>.
    rsync = Rsync(args.tmpdir)
    rsync.sync(args.data_path)
    args.data_path = os.path.join(args.tmpdir,
                                  os.path.basename(args.data_path))

    # Training stream. It is wrapped in MultiProcessing (assumed to be
    # fuel's background-process prefetcher) with max_store=200, which
    # presumably caps the number of buffered batches -- TODO confirm.
    train_ds = fuel_utils.get_datastream(
        path=args.data_path,
        which_set=args.train_dataset,
        batch_size=args.batch_size,
        use_ivectors=args.use_ivectors,
        truncate_ivectors=args.truncate_ivectors,
        ivector_dim=args.ivector_dim)
    train_ds = MultiProcessing(train_ds, max_store=200)
    # Validation stream: same options as training, read from
    # args.valid_dataset, and not wrapped in MultiProcessing.
    valid_ds = fuel_utils.get_datastream(
        path=args.data_path,
        which_set=args.valid_dataset,
        batch_size=args.batch_size,
        use_ivectors=args.use_ivectors,
        truncate_ivectors=args.truncate_ivectors,
        ivector_dim=args.ivector_dim)
    # Test stream: same options, read from args.test_dataset, also not
    # wrapped in MultiProcessing.
    test_ds = fuel_utils.get_datastream(
        path=args.data_path,
        which_set=args.test_dataset,
        batch_size=args.batch_size,
        use_ivectors=args.use_ivectors,
        truncate_ivectors=args.truncate_ivectors,
        ivector_dim=args.ivector_dim)