Code example #1
File: sampling.py Project: ml-lab/wayback
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))
        # ensure segment_length is a multiple of chunk_size
        segment_length -= segment_length % hp.chunk_size

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers,
                                     segment_length,
                                     overlap=hp.chunk_size):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.cond.Extract("final_state.model final_xchunk"))
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        # round length up to a multiple of chunk_size (the excess is trimmed below)
        chunky_length = length + hp.chunk_size - length % hp.chunk_size

        print "sampling..."
        length_left = chunky_length
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xchunk=cond_values.final_xchunk)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xchunk: state.initial_xchunk,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = NS.FlatCall(
                ft.partial(session.run, feed_dict=feed_dict),
                self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatchunk"))
            state.model = sample_values.final_state.model
            state.initial_xchunk = sample_values.final_xhatchunk

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        # truncate from chunky_length to the desired sample length
        xhat = xhat[:length]
        return xhat.T
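
The chunky_length round-up above always overshoots by a full chunk when length is already a multiple of hp.chunk_size; the surplus is trimmed off by xhat[:length] at the end. A standalone illustration of the arithmetic (plain Python, no TensorFlow):

    def round_up_to_chunk(length, chunk_size):
        # same arithmetic as chunky_length above: lands on a multiple of
        # chunk_size, overshooting by a whole chunk when already aligned
        return length + chunk_size - length % chunk_size

    assert round_up_to_chunk(12, 5) == 15
    assert round_up_to_chunk(10, 5) == 15  # aligned input still grows by one chunk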
Code example #2
File: util_test.py Project: zhuhyc/wayback
    def test_segments_regression1(self):
        examples = [[np.arange(9)]]
        segment_length = 4
        overlap = 2
        for segment in util.segments(examples, segment_length,
                                     overlap=overlap):
            self.assertEqual(4, len(segment[0][0]))
Code example #3
File: training.py Project: zhuhyc/wayback
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        tensors = self.tensors.Extract(
            "loss error summaries global_step training_op learning_rate final_state.model"
        )
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             self.segment_length,
                                             overlap=LEFTOVER):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = tfutil.run(session, tensors, feed_dict=feed_dict)
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
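
The hooks.Get("step.before", util.noop) calls dispatch to optional callbacks, falling back to a do-nothing callable, analogous to dict.get with a default. A minimal sketch of the pattern using a plain dict (the real hooks object is an NS, and util.noop is not shown on this page; noop below is a stand-in):

    def noop(*args, **kwargs):
        # stand-in for util.noop: accept anything, do nothing
        pass

    hooks = {"step.before": lambda state: print("about to step:", state)}
    state = {"global_step": 0}
    hooks.get("step.before", noop)(state)        # registered hook fires
    hooks.get("step.after", noop)(state, None)   # unregistered: silently skipped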
Code example #4
File: sample.py Project: zhuhyc/wayback
def preprocess_primers(examples, hp):
    # maybe augment number of examples to ensure batch norm will work
    min_batch_size = 16
    if len(examples) < min_batch_size:
        k = min_batch_size // len(examples)
        # snapshot the list: extending a list while a generator is
        # iterating over it would keep feeding the loop new elements
        examples.extend(
            example.with_offset(offset) for example in list(examples)
            for offset in util.random_choice(len(example), size=[k],
                                             replace=False))

    # maybe augment number of time steps to ensure util.segments doesn't discard
    # anything at the ends of the examples. this is done by left-padding the
    # shorter examples with repetitions.
    max_len = max(map(len, examples))
    examples = [
        example.map(lambda feature: np.pad(
            feature, [(max_len - len(feature), 0)], mode="wrap"))
        for example in examples
    ]

    # time is tight; condition on 3 seconds of the wav files only
    examples_segments = list(util.segments(examples,
                                           3 * hp.sampling_frequency))
    if len(examples_segments) > 2:
        # don't use the first and last segments to avoid silence
        examples_segments = examples_segments[1:-1]
    examples = examples_segments[util.random_choice(len(examples_segments))]

    return examples
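
The mode="wrap" left-padding repeats each example from its own tail, so the padded prefix contains real signal rather than silence. A quick demonstration of the numpy call used above:

    import numpy as np

    feature = np.array([1, 2, 3])
    # left-pad by two steps; "wrap" fills the pad from the end of the array
    padded = np.pad(feature, [(2, 0)], mode="wrap")
    print(padded)  # [2 3 1 2 3]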
Code example #5
File: util_test.py Project: zhuhyc/wayback
    def test_segments1(self):
        length = 100
        xs = np.random.randint(2, 100, size=[length])
        examples = [[xs]]
        segment_length = 10

        ys = []
        for segment in util.segments(examples, segment_length):
            ys.extend(segment[0][0])
        self.assertEqual(list(xs), list(ys))

        k = np.random.randint(1, segment_length - 1)
        for i, segment in enumerate(
                util.segments(examples, segment_length, overlap=k)):
            if i != 0:
                self.assertTrue(np.array_equal(overlap, segment[0][0][:k]))
            overlap = segment[0][0][-k:]
Code example #6
    def run(self,
            session,
            examples,
            max_step_count=None,
            hp=None,
            aggregates=None):
        aggregates = NS(aggregates or {})
        for key in "loss error".split():
            if key not in aggregates:
                aggregates[key] = util.MeanAggregate()

        tensors = self.tensors.Extract(*[key for key in aggregates.Keys()])
        tensors.Update(self.tensors.Extract("final_state.model"))

        state = NS(step=0, model=self.model.initial_state(hp.batch_size))

        try:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(batch,
                                             hp.segment_length,
                                             overlap=hp.chunk_size):
                    if max_step_count is not None and state.step >= max_step_count:
                        raise StopIteration()

                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict), tensors)

                    for key in aggregates.Keys():
                        aggregates[key].add(values[key])

                    sys.stderr.write(".")
                    state.model = values.final_state.model
                    state.step += 1
        except StopIteration:
            pass

        sys.stderr.write("\n")

        values = NS(
            (key, aggregate.value) for key, aggregate in aggregates.Items())

        values.summaries = [
            tf.Summary.Value(tag="%s_valid" % key, simple_value=values[key])
            for key in "loss error".split()
        ]
        print "### evaluation loss %6.5f error %6.5f" % (values.loss,
                                                         values.error)

        if np.isnan(values.loss):
            raise ValueError("loss has become NaN")

        return values
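
util.MeanAggregate is not defined on this page; judging from the .add(value) and .value usage above, a minimal compatible sketch (an inference, not the project's actual code) could be:

    class MeanAggregate(object):
        # keep a running mean over the values passed to add()
        def __init__(self):
            self.total = 0.0
            self.count = 0

        def add(self, value):
            self.total += value
            self.count += 1

        @property
        def value(self):
            # define the mean of zero values as 0.0 for simplicity
            return self.total / max(self.count, 1)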
Code example #7
    def run(self, session, primers, length, temperature, hp=None):
        batch_size = len(primers)
        # process in segments to avoid tensorflow eating all the memory
        max_segment_length = min(10000, hp.segment_length)

        print "conditioning..."
        segment_length = min(max_segment_length,
                             max(len(primer[0]) for primer in primers))

        state = NS(model=self.model.initial_state(batch_size))
        for segment in util.segments(primers, segment_length,
                                     overlap=LEFTOVER):
            x, = util.examples_as_arrays(segment)
            feed_dict = {self.tensors.x: x.T}
            feed_dict.update(self.model.feed_dict(state.model))
            values = tfutil.run(session,
                                tensors=self.tensors.cond.Extract(
                                    "final_state.model final_xelt"),
                                feed_dict=feed_dict)
            state.model = values.final_state.model
            sys.stderr.write(".")
        sys.stderr.write("\n")

        cond_values = values

        print "sampling..."
        length_left = length + LEFTOVER
        xhats = []
        state = NS(model=cond_values.final_state.model,
                   initial_xelt=cond_values.final_xelt)
        while length_left > 0:
            segment_length = min(max_segment_length, length_left)
            length_left -= segment_length

            feed_dict = {
                self.tensors.initial_xelt: state.initial_xelt,
                self.tensors.length: segment_length,
                self.tensors.temperature: temperature
            }
            feed_dict.update(self.model.feed_dict(state.model))
            sample_values = tfutil.run(
                session,
                tensors=self.tensors.sample.Extract(
                    "final_state.model xhat final_xhatelt"),
                feed_dict=feed_dict)
            state.model = sample_values.final_state.model
            state.initial_xelt = sample_values.final_xhatelt

            xhats.append(sample_values.xhat)
            sys.stderr.write(".")
        sys.stderr.write("\n")

        xhat = np.concatenate(xhats, axis=0)
        return xhat.T
Code example #8
File: util_test.py Project: zhuhyc/wayback
    def test_segmented_batches(self):
        length = np.random.randint(2, 100)
        segment_length = np.random.randint(1, length)
        example_count = np.random.randint(2, 100)
        batch_size = np.random.randint(1, example_count)
        feature_shapes = [
            np.random.randint(1, 10, size=np.random.randint(1, 4))
            for _ in range(np.random.randint(1, 4))
        ]
        examples = [[
            np.random.rand(length, *shape) for shape in feature_shapes
        ] for _ in range(example_count)]

        for batch in util.batches(examples, batch_size, augment=False):
            for segment in util.segments(batch, segment_length):
                self.assertEqual(batch_size, len(batch))
                for features in segment:
                    self.assertEqual(len(feature_shapes), len(features))
                    for feature, feature_shape in util.equizip(
                            features, feature_shapes):
                        self.assertLessEqual(len(feature), segment_length)
                        self.assertEqual(tuple(feature_shape),
                                         feature.shape[1:])
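
util.equizip is likewise only visible through its call sites. Judging by the name and by how it pairs features with feature_shapes above, it is presumably a zip that refuses unequal-length arguments; a plausible sketch:

    def equizip(*sequences):
        # like zip, but fail loudly instead of silently dropping tails
        sequences = [list(s) for s in sequences]
        if any(len(s) != len(sequences[0]) for s in sequences):
            raise ValueError("sequences have unequal lengths")
        return zip(*sequences)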
Code example #9
    def run(self, session, examples, max_step_count=None, hooks=None, hp=None):
        state = NS(global_step=tf.train.global_step(session,
                                                    self.tensors.global_step),
                   model=self.model.initial_state(hp.batch_size))
        while True:
            for batch in util.batches(examples, hp.batch_size):
                for segment in util.segments(
                        batch,
                        # the last chunk is not processed, so grab
                        # one more to ensure we backpropagate
                        # through at least one full model cycle.
                        # TODO(cotim): rename segment_length to
                        # backprop_length?
                        hp.segment_length + hp.chunk_size,
                        overlap=hp.chunk_size):
                    if max_step_count is not None and state.global_step >= max_step_count:
                        return

                    hooks.Get("step.before", util.noop)(state)
                    x, = util.examples_as_arrays(segment)
                    feed_dict = {self.tensors.x: x.T}
                    feed_dict.update(self.model.feed_dict(state.model))
                    values = NS.FlatCall(
                        ft.partial(session.run, feed_dict=feed_dict),
                        self.tensors.Extract(
                            "loss error summaries global_step training_op learning_rate final_state.model"
                        ))
                    state.model = values.final_state.model
                    state.global_step = values.global_step
                    hooks.Get("step.after", util.noop)(state, values)

                    print("step #%d loss %f error %f learning rate %e" %
                          (values.global_step, values.loss, values.error,
                           values.learning_rate))

                    if np.isnan(values.loss):
                        raise ValueError("loss has become NaN")
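
util.batches also appears only through its call sites (examples, batch_size, optional augment). One sketch consistent with those calls yields successive fixed-size batches and drops any remainder; what the augment flag actually does is not recoverable from this page:

    def batches(examples, batch_size, augment=True):
        # yield consecutive lists of exactly batch_size examples; the
        # augment flag is accepted only for signature compatibility here
        for start in range(0, len(examples) - batch_size + 1, batch_size):
            yield examples[start:start + batch_size]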
Code example #10
File: util_test.py Project: zhuhyc/wayback
    def test_segments_truncate(self):
        class ComparableNdarray(np.ndarray):
            """A Numpy ndarray that doesn't break equality.

            Numpy ndarray violates the __eq__ contract, which breaks deep
            comparisons. Work around it by wrapping the arrays.
            """
            def __eq__(self, other):
                return np.array_equal(self, other)

        def comparablearray(*args, **kwargs):
            array = np.array(*args, **kwargs)
            return ComparableNdarray(array.shape,
                                     buffer=array,
                                     dtype=array.dtype)

        def to_examples(segmented_examples):
            segments_by_example = [[[
                comparablearray(list(s.strip()), dtype="|S1")
            ] for s in e.split("|")] for e in segmented_examples]
            examples_by_segment = list(
                map(list, util.equizip(*segments_by_example)))
            return examples_by_segment

        examples = to_examples(["abcdefg", "mno", "vwxyz"])[0]

        self.assertEqual(
            to_examples(["ab|cd|ef|g ", "mn|o |  |  ", "vw|xy|z |  "]),
            list(util.segments(examples, 2, truncate=False)))
        self.assertEqual(to_examples(["ab", "mn", "vw"]),
                         list(util.segments(examples, 2, truncate=True)))
        self.assertEqual(
            to_examples(["abc|def|g  ", "mno|   |   ", "vwx|yz |   "]),
            list(util.segments(examples, 3, truncate=False)))
        self.assertEqual(to_examples(["abc", "mno", "vwx"]),
                         list(util.segments(examples, 3, truncate=True)))
        self.assertEqual(to_examples(["abcd|efg ", "mno |    ", "vwxy|z   "]),
                         list(util.segments(examples, 4, truncate=False)))
        self.assertEqual([], list(util.segments(examples, 4, truncate=True)))

        overlap = 1
        self.assertEqual(
            to_examples([
                "ab|bc|cd|de|ef|fg|g ", "mn|no|o |  |  |  |  ",
                "vw|wx|xy|yz|z |  |  "
            ]), list(util.segments(examples, 2, overlap, truncate=False)))
        self.assertEqual(
            to_examples(["ab|bc", "mn|no", "vw|wx"]),
            list(util.segments(examples, 2, overlap, truncate=True)))
        self.assertEqual(
            to_examples(
                ["abc|cde|efg|g  ", "mno|o  |   |   ", "vwx|xyz|z  |   "]),
            list(util.segments(examples, 3, overlap, truncate=False)))
        self.assertEqual(
            to_examples(["abc", "mno", "vwx"]),
            list(util.segments(examples, 3, overlap, truncate=True)))
        self.assertEqual(
            to_examples(["abcd|defg|g   ", "mno |    |    ",
                         "vwxy|yz  |    "]),
            list(util.segments(examples, 4, overlap, truncate=False)))
        self.assertEqual([],
                         list(
                             util.segments(examples, 4, overlap,
                                           truncate=True)))
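
None of the examples include util.segments itself, but the tests pin down its contract: example #2 implies that by default only full-length windows are yielded, example #5 shows consecutive windows sharing overlap steps, and the truncate cases above fix the rest. A sketch consistent with all of these (a reconstruction inferred from the tests, not the project's actual code):

    def segments(examples, segment_length, overlap=0, truncate=True):
        # Each example is a list of feature arrays with time on axis 0.
        # Windows advance by segment_length - overlap steps. With
        # truncate=True (the default inferred from example #2), only
        # windows that fit entirely inside every example are yielded;
        # with truncate=False, iteration runs to the end of the longest
        # example, yielding partial or empty windows for shorter ones.
        assert 0 <= overlap < segment_length
        lengths = [len(features[0]) for features in examples]
        limit = min(lengths) if truncate else max(lengths)
        start = 0
        while (start + segment_length <= limit) if truncate else (start < limit):
            yield [[feature[start:start + segment_length]
                    for feature in features]
                   for features in examples]
            start += segment_length - overlap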