 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = self.build_samples(data)
     dataloader = self.build_dataloader(samples,
                                        device=self.device,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     overwrite=True))
     outputs = []
     orders = []
     for idx, batch in enumerate(dataloader):
         out, mask = self.feed_batch(batch)
         prediction = self.decode_output(out, mask, batch, span_probs=None)
         # prediction = [x[0] for x in prediction]
         outputs.extend(prediction)
         orders.extend(batch[IDX])
     outputs = reorder(outputs, orders)
     if flat:
         return outputs[0]
     return outputs
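The dataloader above may batch samples out of their input order (for instance, sorted by length for efficient padding), so each sample carries an IDX and reorder restores the original ordering at the end. A minimal sketch of the semantics assumed here, as a hypothetical reimplementation rather than HanLP's actual helper:

    def reorder(outputs, orders):
        # orders[i] is the original position of outputs[i]; invert the
        # permutation so results line up with the inputs again.
        restored = [None] * len(outputs)
        for output, original_idx in zip(outputs, orders):
            restored[original_idx] = output
        return restored

    assert reorder(['b', 'c', 'a'], [1, 2, 0]) == ['a', 'b', 'c']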
Example #2
 def predict(self, data: Any, batch_size=None, batch_max_tokens=None, output_format='conllx', **kwargs):
     if not data:
         return []
     use_pos = self.use_pos
     flat = self.input_is_flat(data, use_pos)
     if flat:
         data = [data]
     samples = self.build_samples(data, use_pos)
     if not batch_max_tokens:
         batch_max_tokens = self.config.batch_max_tokens
     if not batch_size:
         batch_size = self.config.batch_size
     dataloader = self.build_dataloader(samples,
                                        device=self.devices[0], shuffle=False,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     batch_max_tokens=batch_max_tokens,
                                                     overwrite=True,
                                                     **kwargs))
     predictions, build_data, data, order = self.before_outputs(data)
     for batch in dataloader:
         arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
         self.collect_outputs(arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
                              build_data)
     outputs = self.post_outputs(predictions, data, order, use_pos, build_data)
     if flat:
         return outputs[0]
     return outputs
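Both predict variants funnel call-time overrides through merge_dict before building the dataloader. A minimal sketch of the semantics these snippets rely on (a hypothetical reimplementation for illustration, not HanLP's actual utility): keyword arguments are merged into a copy of the stored config, and overwrite=True lets the call-time values win over stored ones.

    def merge_dict(base: dict, overwrite=False, inplace=False, **kwargs) -> dict:
        # Merge kwargs into `base`; overwrite=True gives kwargs precedence,
        # inplace=True mutates `base` instead of copying it.
        merged = base if inplace else dict(base)
        for key, value in kwargs.items():
            if overwrite or key not in merged:
                merged[key] = value
        return merged

    config = {'batch_size': 32, 'device': 'cpu'}
    assert merge_dict(config, overwrite=True, batch_size=8)['batch_size'] == 8
    assert merge_dict(config, batch_size=8)['batch_size'] == 32  # no overwrite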
Example #3
 def build_model(self, training=True, **kwargs) -> torch.nn.Module:
     transformer = self.config.encoder.module()
     model = GraphAbstractMeaningRepresentationModel(
         self.vocabs,
         **merge_dict(self.config, overwrite=True, encoder=transformer),
         tokenizer=self.config.encoder.transform())
     return model
Example #4
 def __call__(self, data, batch_size=None, **kwargs):
     return super().__call__(
         data,
         **merge_dict(self.config,
                      overwrite=True,
                      batch_size=batch_size
                      or self.config.get('batch_size', None),
                      **kwargs))
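The `batch_size or self.config.get('batch_size', None)` idiom above falls back to the stored config whenever the argument is falsy. A quick demonstration of that behavior (names are illustrative):

    def pick_batch_size(batch_size, config):
        # `or` treats every falsy value as "unset", so both None and an
        # explicit 0 fall back to the stored config value.
        return batch_size or config.get('batch_size', None)

    assert pick_batch_size(None, {'batch_size': 32}) == 32
    assert pick_batch_size(0, {'batch_size': 32}) == 32  # 0 also falls back
    assert pick_batch_size(16, {'batch_size': 32}) == 16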
Example #5
 def load(self, save_dir: str, devices=None, **kwargs):
     save_dir = get_resource(save_dir)
     # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
     if devices is None and self.model:
         devices = self.devices
     self.load_config(save_dir, **kwargs)
     self.load_vocabs(save_dir)
     flash('Building model [blink][yellow]...[/yellow][/blink]')
     self.model = self.build_model(**merge_dict(self.config,
                                                training=False,
                                                **kwargs,
                                                overwrite=True,
                                                inplace=True))
     flash('')
     self.load_weights(save_dir, **kwargs)
     self.to(devices)
     self.model.eval()
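The closing model.eval() matters because layers such as dropout behave differently at inference time. A quick self-contained check with plain PyTorch:

    import torch

    layer = torch.nn.Dropout(p=0.5)
    layer.eval()  # inference mode: dropout becomes a no-op
    x = torch.ones(4)
    assert torch.equal(layer(x), x)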
Example #6
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion,
                              optimizer,
                              metric,
                              save_dir,
                              logger,
                              patience,
                              **kwargs):
        max_e, max_metric = 0, -1

        criterion = self.build_criterion()
        timer = CountdownTimer(epochs)
        ratio_width = len(f'{len(trn)}/{len(trn)}')
        scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
        if not patience:
            patience = epochs
        for epoch in range(1, epochs + 1):
            logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
            self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
            loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
            if scheduler:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(dev_metric.score)
                else:
                    scheduler.step(epoch)
            report_patience = f'Patience: {epoch - max_e}/{patience}'
            # save the model if it is the best so far
            if dev_metric > max_metric:
                self.save_weights(save_dir)
                max_e, max_metric = epoch, dev_metric
                report_patience = '[red]Saved[/red] '
            stop = epoch - max_e >= patience
            if stop:
                timer.stop()
            timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}',
                      ratio_percentage=False, newline=True, ratio=False)
            if stop:
                break
        timer.stop()
        if max_e != epoch:
            self.load_weights(save_dir)
        logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
        logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")
Example #7
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
      flat = isinstance(data, (str, tuple))
     if flat:
         data = [data]
     samples = []
     for idx, d in enumerate(data):
         sample = {IDX: idx}
         if self.config.text_b_key:
             sample[self.config.text_a_key] = d[0]
             sample[self.config.text_b_key] = d[1]
         else:
             sample[self.config.text_a_key] = d
         samples.append(sample)
     dataloader = self.build_dataloader(samples,
                                        sorting=False,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     device=self.device,
                                                     overwrite=True))
     labels = [None] * len(data)
     vocab = self.vocabs.label
     for batch in dataloader:
         logits = self.feed_batch(batch)
         pred = logits.argmax(-1)
         pred = pred.tolist()
         for idx, tag in zip(batch[IDX], pred):
             labels[idx] = vocab.idx_to_token[tag]
     if flat:
         return labels[0]
     return labels
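Decoding here is a plain argmax over the label dimension followed by an id-to-string lookup, with IDX mapping each prediction back to its input slot. With a toy vocabulary standing in for self.vocabs.label:

    import torch

    idx_to_token = ['negative', 'positive']  # toy label vocabulary
    logits = torch.tensor([[0.2, 0.8], [1.5, -0.3]])  # one row per sample
    pred = logits.argmax(-1).tolist()
    assert [idx_to_token[i] for i in pred] == ['positive', 'negative']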
Example #8
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = self.build_samples(data)
     dataloader = self.build_dataloader(samples,
                                        device=self.device,
                                        **merge_dict(self.config,
                                                     overwrite=True,
                                                     batch_size=batch_size))
     pp = PostProcessor(self.vocabs['rel'])
     results = list(parse_data_(self.model, pp, dataloader))
     for i, each in enumerate(results):
         amr_graph = AMRGraph(each)
         self.sense_restore.restore_graph(amr_graph)
         results[i] = amr_graph
     if flat:
         return results[0]
     return results
Example #9
    def evaluate(self,
                 tst_data,
                 save_dir=None,
                 logger: logging.Logger = None,
                 batch_size=None,
                 output=False,
                 **kwargs):
        if not self.model:
            raise RuntimeError('Call fit or load before evaluate.')
        if isinstance(tst_data, str):
            tst_data = get_resource(tst_data)
            filename = os.path.basename(tst_data)
        else:
            filename = None
        if output is True:
            output = self.generate_prediction_filename(
                tst_data if isinstance(tst_data, str) else 'test.txt',
                save_dir)
        if logger is None:
            _logger_name = basename_no_ext(filename) if filename else None
            logger = self.build_logger(_logger_name, save_dir)
        if not batch_size:
            batch_size = self.config.get('batch_size', 32)
        data = self.build_dataloader(**merge_dict(self.config,
                                                  data=tst_data,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  device=self.devices[0],
                                                  logger=logger,
                                                  overwrite=True))
        dataset = data
        while dataset and hasattr(dataset, 'dataset'):
            dataset = dataset.dataset
        num_samples = len(dataset) if dataset else None
        if output and isinstance(dataset, TransformDataset):

            def add_idx(samples):
                for idx, sample in enumerate(samples):
                    if sample:
                        sample[IDX] = idx

            add_idx(dataset.data)
            if dataset.cache:
                add_idx(dataset.cache)

        criterion = self.build_criterion(**self.config)
        metric = self.build_metric(**self.config)
        start = time.time()
        outputs = self.evaluate_dataloader(data,
                                           criterion=criterion,
                                           filename=filename,
                                           output=output,
                                           input=tst_data,
                                           save_dir=save_dir,
                                           test=True,
                                           num_samples=num_samples,
                                           **merge_dict(self.config,
                                                        batch_size=batch_size,
                                                        metric=metric,
                                                        logger=logger,
                                                        **kwargs))
        elapsed = time.time() - start
        if logger:
            if num_samples:
                logger.info(
                    f'speed: {num_samples / elapsed:.0f} samples/second')
            else:
                logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
        return metric, outputs
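The `while ... hasattr(dataset, 'dataset')` walk above unwraps nested dataset wrappers until it reaches the innermost dataset, whose length gives the sample count. A toy illustration of the same traversal:

    class Wrapper:
        def __init__(self, dataset):
            self.dataset = dataset

    innermost = [1, 2, 3]
    dataset = Wrapper(Wrapper(innermost))
    while dataset and hasattr(dataset, 'dataset'):
        dataset = dataset.dataset  # peel one wrapper per iteration
    assert dataset is innermost and len(dataset) == 3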
Example #10
 def fit(self,
         trn_data,
         dev_data,
         save_dir,
         batch_size,
         epochs,
         devices=None,
         logger=None,
         seed=None,
         finetune=False,
         eval_trn=True,
         _device_placeholder=False,
         **kwargs):
     # Common initialization steps
     config = self._capture_config(locals())
     if not logger:
         logger = self.build_logger('train', save_dir)
     if not seed:
         self.config.seed = 233 if isdebugging() else int(time.time())
     set_seed(self.config.seed)
     logger.info(self._savable_config.to_json(sort=True))
     if isinstance(devices, list) or devices is None or isinstance(
             devices, float):
         flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
         devices = -1 if isdebugging() else cuda_devices(devices)
         flash('')
     # flash(f'Available GPUs: {devices}')
     if isinstance(devices, list):
         first_device = (devices[0] if devices else -1)
     elif isinstance(devices, dict):
         first_device = next(iter(devices.values()))
     elif isinstance(devices, int):
         first_device = devices
     else:
         first_device = -1
     if _device_placeholder and first_device >= 0:
         _dummy_placeholder = self._create_dummy_placeholder_on(
             first_device)
     if finetune:
         if isinstance(finetune, str):
             self.load(finetune, devices=devices)
         else:
             self.load(save_dir, devices=devices)
         logger.info(
             f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
     self.on_config_ready(**self.config)
     trn = self.build_dataloader(**merge_dict(config,
                                              data=trn_data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              training=True,
                                              device=first_device,
                                              logger=logger,
                                              vocabs=self.vocabs,
                                              overwrite=True))
     dev = self.build_dataloader(
         **merge_dict(config,
                      data=dev_data,
                      batch_size=batch_size,
                      shuffle=False,
                      training=None,
                      device=first_device,
                      logger=logger,
                      vocabs=self.vocabs,
                      overwrite=True)) if dev_data else None
     if not finetune:
         flash('[yellow]Building model [blink]...[/blink][/yellow]')
         self.model = self.build_model(**merge_dict(config, training=True))
         flash('')
         logger.info(
             f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
             f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.'
         )
         assert self.model, 'build_model is not properly implemented.'
     _description = repr(self.model)
     if len(_description.split('\n')) < 10:
         logger.info(_description)
     self.save_config(save_dir)
     self.save_vocabs(save_dir)
     self.to(devices, logger)
     if _device_placeholder and first_device >= 0:
         del _dummy_placeholder
     criterion = self.build_criterion(**merge_dict(config, trn=trn))
     optimizer = self.build_optimizer(
         **merge_dict(config, trn=trn, criterion=criterion))
     metric = self.build_metric(**self.config)
     if hasattr(trn.dataset, '__len__') and dev and hasattr(
             dev.dataset, '__len__'):
         logger.info(
             f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.'
         )
         trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
         ratio_width = len(f'{trn_size}/{trn_size}')
     else:
         ratio_width = None
     return self.execute_training_loop(**merge_dict(config,
                                                    trn=trn,
                                                    dev=dev,
                                                    epochs=epochs,
                                                    criterion=criterion,
                                                    optimizer=optimizer,
                                                    metric=metric,
                                                    logger=logger,
                                                    save_dir=save_dir,
                                                    devices=devices,
                                                    ratio_width=ratio_width,
                                                    trn_data=trn_data,
                                                    dev_data=dev_data,
                                                    eval_trn=eval_trn,
                                                    overwrite=True))
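The trainable/total parameter report logged during fit is a generator expression over requires_grad. Verified on a toy module:

    import torch

    model = torch.nn.Linear(4, 2)  # weight: 8 params, bias: 2 params
    model.bias.requires_grad_(False)  # freeze the bias
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    assert (trainable, total) == (8, 10)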