def fit(self, xs, epochs=1, batch_size=None, max_steps=10**6):
    """Fits to sequences given as [N x length] token array.

    Args:
      xs: Training data. May be a TF Dataset (has `as_numpy_iterator`),
        a dataset iterator (has `element_spec`), or a raw [N x length]
        token array which is wrapped into a shuffled, batched, repeating
        dataset.
      epochs: Number of passes over the data. Must be 1 for iterator
        input, since an iterator cannot be replayed.
      batch_size: Batch size for raw-array input; defaults to
        `self._batch_size`.
      max_steps: Upper bound on the number of training steps. For
        dataset/iterator inputs this is the step count directly.

    Returns:
      Dict with `last` (metrics of the final batch) and `average`
      (metrics averaged over all steps, plus `runtime` seconds and
      `rate` steps/sec).

    Raises:
      ValueError: If `epochs != 1` with iterator input, or if the
        resolved number of training steps is zero/None.
    """
    if batch_size is None:
        batch_size = self._batch_size
    if hasattr(xs, 'as_numpy_iterator'):  # TF Dataset
        ds = xs.repeat(epochs)
        num_train_steps = max_steps
    elif hasattr(xs, 'element_spec'):  # Dataset iterator.
        if epochs != 1:
            raise ValueError('Epochs must == 1 when using iterator input.')
        ds = xs
        num_train_steps = max_steps
    else:  # Raw sequences which we turn into a dataset.
        ds = data.dataset_from_tensors(xs)
        ds = ds.shuffle(buffer_size=1024).repeat().batch(batch_size)
        num_train_steps = math.ceil((len(xs) * epochs) / float(batch_size))
        if max_steps:
            num_train_steps = min(num_train_steps, max_steps)
    if not num_train_steps:
        raise ValueError('Must set max_steps to nonzero value.')

    metrics = []
    start = time.time()
    # NOTE: the dead reassignment `max_steps = max_steps or 10**6` that used
    # to sit here was removed — `max_steps` is never read after
    # `num_train_steps` is resolved above.
    # `range(...)` is the first zip operand so the dataset is never pulled
    # for an extra, discarded element once the step budget is exhausted.
    for _, batch in zip(range(num_train_steps), ds):
        metrics.append(self.fit_batch(batch))
    finish = time.time()
    average = evaluation.combine_metrics(metrics)
    average['runtime'] = finish - start
    average['rate'] = len(metrics) / (finish - start)
    if self._store_metrics:
        # Store a NumPy-materialized copy of the per-epoch averages.
        average = tree.map_structure(onp.array, average)
        self._epoch_train.append(average)
    return dict(last=evaluation.combine_metrics([metrics[-1]]),
                average=average)
def fit(self, xs, ys=None, weights=None, epochs=1, batch_size=None,
        shuffle=True, max_steps=None, verbose=False):
    """Fits to sequences given as [N x length] token array."""
    # TODO(ddohan): Use other kwargs.
    # These are accepted only for signature compatibility with callers.
    del shuffle
    del weights
    del verbose
    del ys

    if batch_size is None:
        batch_size = self._batch_size

    if hasattr(xs, 'as_numpy_iterator'):
        # TF Dataset: replay it for the requested number of epochs.
        dataset = xs.repeat(epochs)
        steps = max_steps
    elif hasattr(xs, 'element_spec'):
        # Dataset iterator: cannot be replayed, so epochs must be 1.
        if epochs != 1:
            raise ValueError('Epochs must == 1 when using iterator input.')
        dataset = xs
        steps = max_steps
    else:
        # Raw token sequences: wrap into a shuffled, batched, repeating
        # dataset and derive the step count from the data size.
        dataset = data.dataset_from_tensors(xs)
        dataset = dataset.shuffle(buffer_size=1024).repeat().batch(batch_size)
        steps = math.ceil((len(xs) * epochs) / float(batch_size))
        if max_steps:
            steps = min(steps, max_steps)

    if not steps:
        raise ValueError('Must set max_steps to nonzero value.')

    # Keep `range(...)` as the first zip operand so the dataset is not
    # pulled for one extra, discarded element when the budget runs out.
    for _, batch in zip(range(steps), dataset):
        batch = batch._numpy()  # pylint: disable=protected-access
        self.fit_batch(batch)