Example #1
    def save_checkpoint(self, checkpoint_path, distributed=False):
        # Iterate over the individual embedding tables packed into one
        # concatenated table; offsets[j] marks where feature j's rows begin.
        for j in range(len(self.offsets) - 1):
            nrows = self.offsets[j + 1] - self.offsets[j]
            name = self.feature_names[j]

            # Save each table in fixed-size row chunks, one .npy file per
            # chunk, to bound peak host memory.
            chunks = math.ceil(nrows / _embedding_checkpoint_batch)
            for i in range(chunks):
                filename = get_variable_path(checkpoint_path, name, i)
                end = min((i + 1) * _embedding_checkpoint_batch, nrows)

                indices = tf.range(start=self.offsets[j] +
                                   i * _embedding_checkpoint_batch,
                                   limit=self.offsets[j] + end,
                                   dtype=tf.int32)

                arr = tf.gather(params=self.embedding_table,
                                indices=indices,
                                axis=0)
                arr = arr.numpy()

                if distributed:
                    # Each worker holds a slice of the embedding dimension;
                    # gather all slices and stitch the full rows back together.
                    arr = hvd.allgather_object(arr)
                    arr = np.concatenate(arr, axis=1)

                    # Only rank 0 writes the assembled chunk to disk.
                    if hvd.rank() != 0:
                        continue

                np.save(arr=arr, file=filename)
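All of these examples lean on a get_variable_path helper that is not shown. A minimal sketch of what such a helper could look like, assuming it simply flattens the variable name and appends the chunk index; the real helper in the source repository may differ:

    import os

    def get_variable_path(checkpoint_path, name, i=0):
        # Hypothetical sketch: flatten 'a/b/c' into 'a_b_c', append the chunk
        # index, and place the .npy file under the checkpoint directory.
        filename = '_'.join(t for t in name.split('/') if t) + f'_{i}.npy'
        return os.path.join(checkpoint_path, filename)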
Example #2
    def save_checkpoint(self, checkpoint_path, distributed=False):
        # Single-table variant of Example #1: chunk the whole table by rows.
        chunks = math.ceil(self.embedding_table.shape[0] /
                           _embedding_checkpoint_batch)
        for i in range(chunks):
            filename = get_variable_path(checkpoint_path, self.feature_name, i)
            end = min((i + 1) * _embedding_checkpoint_batch,
                      self.embedding_table.shape[0])

            indices = tf.range(start=i * _embedding_checkpoint_batch,
                               limit=end,
                               dtype=tf.int32)

            arr = tf.gather(params=self.embedding_table,
                            indices=indices,
                            axis=0)
            arr = arr.numpy()

            if distributed:
                arr = hvd.allgather_object(arr)
                arr = np.concatenate(arr, axis=1)

                if hvd.rank() != 0:
                    continue

            np.save(arr=arr, file=filename)
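In the distributed branch, hvd.allgather_object collects every worker's array and np.concatenate(..., axis=1) stitches them together, which implies the table is split across workers along the embedding dimension. A standalone numpy sketch of that layout (no Horovod required; the sizes are illustrative):

    import numpy as np

    size = 4                                     # pretend world size
    full = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
    shards = np.split(full, indices_or_sections=size, axis=1)  # one slice per worker
    restored = np.concatenate(shards, axis=1)    # what allgather + concat rebuilds
    assert np.array_equal(full, restored)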
Example #3
    def _save_mlp_checkpoint(checkpoint_path, layers, prefix):
        for i, layer in enumerate(layers):
            for varname in ['kernel', 'bias']:
                # One .npy file per variable, named by layer index and role.
                filename = get_variable_path(
                    checkpoint_path, name=f'{prefix}/layer_{i}/{varname}')
                print(f'saving: {varname} to {filename}')
                variable = layer.__dict__[varname]
                np.save(arr=variable.numpy(), file=filename)
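A small self-contained sketch of the per-variable mechanics this helper relies on, assuming a tiny Keras Dense stack; note that getattr(layer, varname) is the more conventional way to reach kernel/bias on a built layer than layer.__dict__:

    import tensorflow as tf

    layers = [tf.keras.layers.Dense(4), tf.keras.layers.Dense(2)]
    x = tf.zeros([1, 8])
    for layer in layers:
        x = layer(x)  # calling the layer builds it, creating kernel and bias

    for i, layer in enumerate(layers):
        for varname in ['kernel', 'bias']:
            variable = getattr(layer, varname)
            print(f'layer_{i}/{varname}: shape={variable.shape}')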
Example #4
    def _restore_mlp_checkpoint(checkpoint_path, layers, prefix):
        for i, layer in enumerate(layers):
            for varname in ['kernel', 'bias']:
                filename = get_variable_path(
                    checkpoint_path, name=f'{prefix}/layer_{i}/{varname}')
                print(f'loading: {varname} from {filename}')
                variable = layer.__dict__[varname]

                # assign() requires the loaded array to match the built
                # variable's shape.
                numpy_var = np.load(file=filename)
                variable.assign(numpy_var)
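A minimal round trip of the save/restore pair for a single variable (np.save, then np.load into Variable.assign); the path here is illustrative, not the repository's layout:

    import os
    import tempfile
    import numpy as np
    import tensorflow as tf

    var = tf.Variable(tf.random.normal([8, 4]))
    snapshot = var.numpy()

    with tempfile.TemporaryDirectory() as d:
        filename = os.path.join(d, 'mlp_layer_0_kernel.npy')
        np.save(arr=var.numpy(), file=filename)   # what _save_mlp_checkpoint does
        var.assign(tf.zeros([8, 4]))              # clobber, then restore
        var.assign(np.load(file=filename))        # what _restore_mlp_checkpoint does

    assert np.allclose(var.numpy(), snapshot)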
Example #5
    def restore_checkpoint(self, checkpoint_path, distributed=False):
        chunks = math.ceil(self.embedding_table.shape[0] /
                           _embedding_checkpoint_batch)
        for i in range(chunks):
            # Load the chunk of rows that save_checkpoint wrote for this index.
            filename = get_variable_path(checkpoint_path, self.feature_name, i)
            start = i * _embedding_checkpoint_batch
            numpy_arr = np.load(file=filename)

            if distributed:
                # The file holds full rows; keep only this worker's slice of
                # the embedding dimension.
                numpy_arr = np.split(
                    numpy_arr, axis=1,
                    indices_or_sections=hvd.size())[hvd.rank()]

            # Write the chunk back into rows [start, start + chunk_rows).
            indices = tf.range(start=start,
                               limit=start + numpy_arr.shape[0],
                               dtype=tf.int32)
            update = tf.IndexedSlices(values=numpy_arr,
                                      indices=indices,
                                      dense_shape=self.embedding_table.shape)
            self.embedding_table.scatter_update(sparse_delta=update)
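A standalone sketch of the scatter_update pattern used above: writing a contiguous block of rows into a larger variable through tf.IndexedSlices (the sizes are illustrative):

    import numpy as np
    import tensorflow as tf

    table = tf.Variable(tf.zeros([10, 4]))
    chunk = np.ones([3, 4], dtype=np.float32)   # rows loaded from one .npy chunk
    start = 5

    indices = tf.range(start=start, limit=start + chunk.shape[0], dtype=tf.int32)
    update = tf.IndexedSlices(values=chunk,
                              indices=indices,
                              dense_shape=table.shape)
    table.scatter_update(sparse_delta=update)

    print(table.numpy()[5:8])  # rows 5..7 are now ones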