def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
                        mc_labels, lm_labels, mc_token_ids):
    """Instantiate ``OpenAIGPTModel`` from *config* and run one call on the inputs.

    The freshly built model is switched to evaluation mode with ``eval()``
    before it is invoked.

    Args:
        config: Passed unchanged to the ``OpenAIGPTModel`` constructor.
        input_ids: First positional argument of the model call.
        token_type_ids: Third positional argument of the model call.
        position_ids: Second positional argument of the model call.
        mc_labels: Unused in this body.
        lm_labels: Unused in this body.
        mc_token_ids: Unused in this body.

    Returns:
        A dict with the single key ``"hidden_states"`` mapped to whatever the
        model call returned.
    """
    gpt = OpenAIGPTModel(config)
    gpt.eval()
    # NOTE(review): position_ids is deliberately passed before
    # token_type_ids; this relies on the model's call signature taking them
    # in that order -- confirm against OpenAIGPTModel's definition.
    forward_result = gpt(input_ids, position_ids, token_type_ids)
    return {"hidden_states": forward_result}
for layer in self_attention_layers ]) if len(output.shape) == 2: output = output.reshape(output.shape[0], -1, output.shape[1]) output = np.swapaxes(output, 0, 1) list_output.append(output) # ====== Construct Cache ====== # temp_cache = {} for i, sent in enumerate(mini_batch): hask_key = hashlib.sha256(sent.encode()).hexdigest() temp_cache[hask_key] = output[i] self.cache.update(temp_cache) idx += mini_batch_size self.count += mini_batch_size output = np.concatenate(list_output, 0) te = time.time() embedding = self.get_multi_head_embedding(output, heads, head_size) return embedding if __name__ == '__main__': model = OpenAIGPTModel('bert-base-uncased') model.prepare('Length') model.construct_encoder()