def launch(task_type: TaskType, profiler: TextProfiler,
           dataset: Dataset[Union[LabeledTextInstance, ScoredTextInstance]]):
  if task_type == TaskType.CLASSIFICATION:
    if not isinstance(dataset.instances[0], LabeledTextInstance):
      raise ValueError("Inconsistent type between Instance and TaskType")
    examples, labels = _setup_classification_dataset(dataset)
    models = {"text_classifier": LITModelForTextClassifier(profiler, labels)}
    lit_datasets = {
        "classification_dataset": TextClassificationLITDataset(examples, labels)
    }
    lit_demo = dev_server.Server(models, lit_datasets,
                                 **server_flags.get_flags())
    lit_demo.serve()
  elif task_type == TaskType.REGRESSION:
    if not isinstance(dataset.instances[0], ScoredTextInstance):
      raise ValueError("Inconsistent type between Instance and TaskType")
    examples = _setup_regression_dataset(dataset)
    models = {"text_regressor": LITModelForTextRegressor(profiler)}
    lit_datasets = {"regression_dataset": TextRegressionLITDataset(examples)}
    lit_demo = dev_server.Server(models, lit_datasets,
                                 **server_flags.get_flags())
    lit_demo.serve()
  else:
    raise ValueError(f"Unsupported task({task_type}) for launching LIT server")
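# A minimal usage sketch for launch() above. The TextProfiler and Dataset
# constructors shown here are assumptions about the surrounding codebase,
# not confirmed API:
#
#   instances = [LabeledTextInstance(text="great movie!", label="positive")]
#   launch(TaskType.CLASSIFICATION,
#          profiler=TextProfiler(),
#          dataset=Dataset(instances=instances))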
def main(_):
  demo_layout = lit_dtypes.LitComponentLayout(
      components={
          'Main': [
              'data-table-module',
              'datapoint-editor-module',
              'lit-slice-module',
              'color-module',
          ],
          'Predictions': ['classification-module', 'scalar-module'],
          'Explanations': ['classification-module', 'salience-map-module'],
      },
      description='Basic layout for image demo',
  )
  datasets = {'imagenette': imagenette.ImagenetteDataset()}
  models = {'mobilenet': mobilenet.MobileNet()}
  interpreters = {
      'Grad': image_gradient_maps.VanillaGradients(),
      'Integrated Gradients': image_gradient_maps.IntegratedGradients(),
      'Blur IG': image_gradient_maps.BlurIG(),
      'Guided IG': image_gradient_maps.GuidedIG(),
      'XRAI': image_gradient_maps.XRAI(),
      'XRAI GIG': image_gradient_maps.XRAIGIG(),
  }
  lit_demo = dev_server.Server(
      models,
      datasets,
      interpreters=interpreters,
      generators={},
      layouts={'demo_layout': demo_layout},
      **server_flags.get_flags())
  return lit_demo.serve()
def main():
  model = SentimentClassifierModel()
  models = {"sst": model}
  datasets = {"sst": SSTData(labels=model.labels)}
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  models = {"gector": GectorBertModel('bert_0_gector.th')}
  datasets = {"test_data": GeceProdigyData('test_sample.jsonl')}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def run_server(load_path: str):
  """Run a LIT server with the trained coreference model."""
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  if load_path.endswith(".tar.gz"):
    load_path = transformers.file_utils.cached_path(
        load_path, extract_compressed_file=True)

  # Load model from disk.
  full_model = model.FrozenEncoderCoref.from_saved(
      load_path,
      encoder_cls=encoders.BertEncoderWithOffsets,
      classifier_cls=edge_predictor.SingleEdgePredictor)

  # Set up the LIT server.
  models = {"model": full_model}
  datasets = {"winogender": winogender.WinogenderDataset()}
  if FLAGS.ontonotes_edgeprobe_path:
    datasets["ontonotes_dev"] = ontonotes.OntonotesCorefDataset(
        os.path.join(FLAGS.ontonotes_edgeprobe_path, "development.json"))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, layouts=CUSTOM_LAYOUTS,
                               **server_flags.get_flags())
  return lit_demo.serve()
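# Note: transformers.file_utils.cached_path(..., extract_compressed_file=True)
# downloads the archive (if the path is remote), unpacks it into the
# transformers cache, and returns the extracted directory, so both of these
# hypothetical calls resolve to a model directory:
#
#   run_server("/checkpoints/coref_model/")
#   run_server("https://example.com/coref_model.tar.gz")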
def main(_):
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  model_path = FLAGS.model_path
  if model_path.endswith(".tar.gz"):
    model_path = transformers.file_utils.cached_path(
        model_path, extract_compressed_file=True)

  models = {"nli": glue_models.MNLIModel(model_path, inference_batch_size=16)}
  datasets = {
      "xnli": classification.XNLIData("validation", FLAGS.languages),
      "mnli_dev": glue.MNLIData("validation_matched"),
      "mnli_dev_mm": glue.MNLIData("validation_mismatched"),
  }

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  models = {}
  datasets = {}

  if "sst2" in FLAGS.tasks:
    models["sst2"] = glue_models.SST2Model(
        os.path.join(FLAGS.models_path, "sst2"))
    datasets["sst_dev"] = glue.SST2Data("validation")
    logging.info("Loaded models and data for SST-2 task.")

  if "stsb" in FLAGS.tasks:
    models["stsb"] = glue_models.STSBModel(
        os.path.join(FLAGS.models_path, "stsb"))
    datasets["stsb_dev"] = glue.STSBData("validation")
    logging.info("Loaded models and data for STS-B task.")

  if "mnli" in FLAGS.tasks:
    models["mnli"] = glue_models.MNLIModel(
        os.path.join(FLAGS.models_path, "mnli"))
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
    logging.info("Loaded models and data for MultiNLI task.")

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def __init__(self, *args, height=1000, render=False, proxy_url=None, **kw):
  """Start LIT server and optionally render the UI immediately.

  Args:
    *args: Positional arguments for the LitApp.
    height: Height to display the LIT UI in pixels. Defaults to 1000.
    render: Whether to render the UI when this object is constructed.
      Defaults to False.
    proxy_url: Optional proxy URL, if using in a notebook with a server
      proxy. Defaults to None.
    **kw: Keyword arguments for the LitApp.
  """
  app_flags = server_flags.get_flags()
  app_flags['server_type'] = 'notebook'
  app_flags['host'] = 'localhost'
  app_flags['port'] = None
  app_flags.update(kw)

  lit_demo = dev_server.Server(*args, **app_flags)
  self._server = typing.cast(wsgi_serving.NotebookWsgiServer,
                             lit_demo.serve())
  self._height = height
  self._proxy_url = proxy_url

  if render:
    self.render()
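# Typical notebook usage of the constructor above (this matches the widget
# LIT exposes as lit_nlp.notebook.LitWidget; the models/datasets dicts are
# placeholders):
#
#   widget = notebook.LitWidget(models, datasets, height=800)
#   widget.render()  # or construct with render=True to render immediately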
def main(_):
  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(FLAGS.model_path)}

  # Load SST-2 validation set from TFDS.
  datasets = {"sst_dev": glue.SST2Data("validation")}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main() -> None:
  parser = argparse.ArgumentParser(
      description="Analyzes an NMT model using LIT")
  parser.add_argument("experiment", help="Experiment name")
  parser.add_argument("--memory-growth", default=False, action="store_true",
                      help="Enable memory growth")
  parser.add_argument("--eager-execution", default=False, action="store_true",
                      help="Enable TensorFlow eager execution.")
  parser.add_argument("--checkpoint", type=str, help="Analyze checkpoint")
  parser.add_argument("--last", default=False, action="store_true",
                      help="Analyze last checkpoint")
  parser.add_argument("--best", default=False, action="store_true",
                      help="Analyze best evaluated checkpoint")
  parser.add_argument("--avg", default=False, action="store_true",
                      help="Analyze averaged checkpoint")
  args = parser.parse_args()

  get_git_revision_hash()
  set_tf_log_level()
  if args.eager_execution:
    tf.config.run_functions_eagerly(True)
  if args.memory_growth:
    enable_memory_growth()

  checkpoints: Set[str] = set()
  if args.avg:
    checkpoints.add("avg")
  if args.checkpoint is not None:
    checkpoints.add(args.checkpoint)
  if args.last:
    checkpoints.add("last")
  if args.best:
    checkpoints.add("best")
  if not checkpoints:
    checkpoints.add("last")

  server_args, server_kw = create_lit_args(args.experiment, checkpoints)
  lit_server = dev_server.Server(*server_args, **server_kw,
                                 **server_flags.get_flags())
  lit_server.serve()
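# Example invocations (the script name is hypothetical):
#
#   python analyze_nmt.py my_experiment --best
#   python analyze_nmt.py my_experiment --checkpoint 500000 --memory-growth
#
# With no checkpoint flags given, the last checkpoint is analyzed by default.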
def main(_):
  # Load the model we defined above.
  models = {
      "NCBI BERT Finetuned": NerModel(FLAGS.model_path,
                                      labels_file=FLAGS.labels)
  }
  datasets = {"I2b2 2014": I2b2Dataset(data_dir=FLAGS.test_data_dir)}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  # Quick-start mode.
  if FLAGS.quickstart:
    FLAGS.models = QUICK_START_MODELS  # smaller, faster models
    if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
      FLAGS.max_examples = 1000  # truncate larger eval sets
    logging.info("Quick-start mode; overriding --models and --max_examples.")

  models = {}
  datasets = {}

  tasks_to_load = set()
  for model_string in FLAGS.models:
    # Only split on the first two ':', because the path may be a URL
    # containing 'https://'.
    name, task, path = model_string.split(":", 2)
    logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                 path)
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    if path.endswith(".tar.gz"):
      path = transformers.file_utils.cached_path(
          path, extract_compressed_file=True)
    # Load the model from disk.
    models[name] = MODELS_BY_TASK[task](path)
    tasks_to_load.add(task)

  ##
  # Load datasets for each task that we have a model for.
  if "sst2" in tasks_to_load:
    logging.info("Loading data for SST-2 task.")
    datasets["sst_dev"] = glue.SST2Data("validation")
  if "stsb" in tasks_to_load:
    logging.info("Loading data for STS-B task.")
    datasets["stsb_dev"] = glue.STSBData("validation")
  if "mnli" in tasks_to_load:
    logging.info("Loading data for MultiNLI task.")
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()
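# Entries in --models follow the name:task:path convention parsed above;
# split(":", 2) leaves any "https://" inside the path intact. Illustrative
# values (the paths are placeholders):
#
#   --models=sst2-base:sst2:/path/to/sst2_model
#   --models=mnli-base:mnli:https://example.com/models/mnli.tar.gz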
def main(_):
  model_path = FLAGS.model_path or tempfile.mkdtemp()
  logging.info("Working directory: %s", model_path)
  run_finetuning(model_path)

  # Load our trained model.
  models = {"sst": glue_models.SST2Model(model_path)}
  datasets = {"sst_dev": glue.SST2Data("validation")}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  model_path = FLAGS.model_path
  models = {'species classifier': penguin_model.PenguinModel(model_path)}
  datasets = {'penguins': penguin_data.PenguinDataset()}
  generators = {
      'Minimal Targeted Counterfactuals':
          minimal_targeted_counterfactuals.TabularMTC()
  }
  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      **server_flags.get_flags())
  return lit_demo.serve()
def main(_):
  ##
  # Load models, according to the --models flag.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    if model_name.startswith("bert-"):
      models[model_name] = pretrained_lms.BertMLM(
          model_name_or_path, top_k=FLAGS.top_k)
    elif model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
      models[model_name] = pretrained_lms.GPT2LanguageModel(
          model_name_or_path, top_k=FLAGS.top_k)
    else:
      raise ValueError(
          f"Unsupported model name '{model_name}' from path "
          f"'{model_name_or_path}'")

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_test": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  for name in datasets:
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
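# remap() above renames dataset fields so that different corpora share the
# "text" key the language models expect. A rough sketch of the effect, not
# LIT's actual implementation:
#
#   remapped = [{("text" if k == "sentence" else k): v for k, v in ex.items()}
#               for ex in dataset.examples]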
def main(_):
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  model_path = FLAGS.model_path
  if model_path.endswith(".tar.gz"):
    model_path = transformers.file_utils.cached_path(
        model_path, extract_compressed_file=True)

  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(model_path)}

  # Load SST-2 validation set from TFDS.
  datasets = {"sst_dev": glue.SST2Data("validation")}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  model_path = FLAGS.model_path
  logging.info("Working directory: %s", model_path)

  # Load our trained model.
  models = {"toxicity": glue_models.ToxicityModel(model_path)}
  datasets = {"toxicity_test": classification.ToxicityData("test")}

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()
def main(_):
  models = {
      "nli": glue_models.MNLIModel(FLAGS.model_path, inference_batch_size=16)
  }
  datasets = {
      "xnli": classification.XNLIData("validation", FLAGS.languages),
      "mnli_dev": glue.MNLIData("validation_matched"),
      "mnli_dev_mm": glue.MNLIData("validation_mismatched"),
  }

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(FLAGS.model_path)}

  # Alternative datasets, kept commented out for reference:
  # SST-2 validation set from TFDS:
  # datasets = {"sst_dev": glue.SST2Data("validation")}
  # Amazon reviews, rating (5 classes):
  # path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/aws_data_raw/dataset_en_test.json"
  # datasets = {"amz_rev": AMAZONes_5cat(path)}
  # Amazon reviews, category (31 classes):
  # path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/aws_data_final/test/amazonEN_test.csv"
  # datasets = {"amz_rev": AMAZONes_31cat(path)}

  # Chatbot data, category (30 classes).
  path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/chatbot_empresa/dato_chatbot_procesado/chatbot_testEN.csv"
  datasets = {"chatbot": CHATBOTes_30cat(path)}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def __init__(self,
             *args,
             height=1000,
             render=False,
             proxy_url=None,
             layouts: Optional[dtypes.LitComponentLayouts] = None,
             **kw):
  """Start LIT server and optionally render the UI immediately.

  Args:
    *args: Positional arguments for the LitApp.
    height: Height to display the LIT UI in pixels. Defaults to 1000.
    render: Whether to render the UI when this object is constructed.
      Defaults to False.
    proxy_url: Optional proxy URL, if using in a notebook with a server
      proxy. Defaults to None.
    layouts: Optional custom UI layouts. TODO(lit-dev): support simple
      module lists here as well.
    **kw: Keyword arguments for the LitApp.
  """
  app_flags = server_flags.get_flags()
  app_flags['server_type'] = 'notebook'
  app_flags['host'] = 'localhost'
  app_flags['port'] = None
  app_flags['warm_start'] = 1

  layouts = dict(layouts or {})
  if 'notebook' not in layouts:
    layouts['notebook'] = LIT_NOTEBOOK_LAYOUT
  # This will be 'notebook' unless custom layouts are also given in Python.
  app_flags['default_layout'] = list(layouts.keys())[0]
  app_flags.update(kw)

  lit_demo = dev_server.Server(*args, layouts=layouts, **app_flags)
  self._server = cast(wsgi_serving.NotebookWsgiServer, lit_demo.serve())
  self._height = height
  self._proxy_url = proxy_url

  if render:
    self.render()
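# Sketch: passing a custom layout to this constructor; the module names are
# illustrative. Because user layouts are inserted before the fallback
# 'notebook' entry, 'custom' becomes the default layout in this example:
#
#   my_layout = dtypes.LitComponentLayout(
#       components={'Main': ['data-table-module', 'datapoint-editor-module']})
#   widget = notebook.LitWidget(models, datasets, layouts={'custom': my_layout})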
def main(_):
  ##
  # Load models. You can specify several here, if you want to compare
  # different models side-by-side, and can also include models of different
  # types that use different datasets.
  base_models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    if model_name_or_path.startswith("SavedModel"):
      saved_model_path = model_name_or_path.split(":", 1)[1]
      base_models[model_name] = t5.T5SavedModel(saved_model_path)
    else:
      # TODO(lit-dev): attention is temporarily disabled, because O(n^2)
      # between tokens in a long document can get very, very large. Re-enable
      # once we can send this to the frontend more efficiently.
      base_models[model_name] = t5.T5HFModel(
          model_name=model_name_or_path,
          num_to_generate=FLAGS.num_to_generate,
          token_top_k=FLAGS.token_top_k,
          output_attention=False)

  ##
  # Load eval sets and model wrappers for each task.
  # Model wrappers share the same in-memory T5 model, but add task-specific
  # pre- and post-processing code.
  models = {}
  datasets = {}

  if "summarization" in FLAGS.tasks:
    for k, m in base_models.items():
      models[k + "_summarization"] = t5.SummarizationWrapper(m)
    datasets["CNNDM"] = summarization.CNNDMData(
        split="validation", max_examples=FLAGS.max_examples)

  if "mt" in FLAGS.tasks:
    for k, m in base_models.items():
      models[k + "_translation"] = t5.TranslationWrapper(m)
    datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
    datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  ##
  # We can also add custom components. Generators are used to create new
  # examples by perturbing or modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible".
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    indexer = build_indexer(models)
    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Actually start the LIT server, using the models, datasets, and other
  # components constructed above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  return lit_demo.serve()
def main(_):
  ##
  # Load models. You can specify several here, if you want to compare
  # different models side-by-side, and can also include models of different
  # types that use different datasets.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    # TODO(lit-dev): attention is temporarily disabled, because O(n^2)
    # between tokens in a long document can get very, very large. Re-enable
    # once we can send this to the frontend more efficiently.
    models[model_name] = t5.T5GenerationModel(
        model_name=model_name_or_path,
        input_prefix="summarize: ",
        top_k=FLAGS.top_k,
        output_attention=False)

  ##
  # Load datasets. Typically you'll have the constructor actually read the
  # examples and do any pre-processing, so that they're in memory and ready
  # to send to the frontend when you open the web UI.
  datasets = {
      "CNNDM": summarization.CNNDMData(
          split="validation", max_examples=FLAGS.max_examples),
  }
  for name, ds in datasets.items():
    logging.info("Dataset: '%s' with %d examples", name, len(ds))

  ##
  # We can also add custom components. Generators are used to create new
  # examples by perturbing or modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible".
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    assert FLAGS.data_dir, "--data_dir must be set to use the indexer."
    # Datasets for the indexer - this one loads the training corpus instead
    # of the validation split.
    index_datasets = {
        "CNNDM": summarization.CNNDMData(
            split="train", max_examples=FLAGS.max_index_examples),
    }
    # Set up the Indexer, building the index if necessary (this may be slow).
    indexer = index.Indexer(
        datasets=index_datasets,
        models=models,
        data_dir=FLAGS.data_dir,
        initialize_new_indices=FLAGS.initialize_index)
    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Actually start the LIT server, using the models, datasets, and other
  # components constructed above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  lit_demo.serve()