def update_metrics(self,
                   batch: Dict[str, Any],
                   output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                   prediction: Dict[str, Any],
                   metric: Union[MetricDict, Metric]):
    """Update ``metric`` for one batch by delegating to ``CRFConstituencyParser``.

    Args:
        batch: The batch, forwarded unchanged to the parser.
        output: Accepted for interface compatibility; not used by this method.
        prediction: The predictions, forwarded unchanged to the parser.
        metric: The metric object, forwarded unchanged to the parser.

    Returns:
        Whatever ``CRFConstituencyParser.update_metrics`` returns.
    """
    # The delegate is called with the metric first, so the arguments are
    # reordered instead of being passed in this method's own order.
    delegated = CRFConstituencyParser.update_metrics(self, metric, batch, prediction)
    return delegated
def compute_loss(
        self,
        batch: Dict[str, Any],
        output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
        criterion
) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    """Compute the loss for one batch via ``CRFConstituencyParser.compute_loss``.

    Args:
        batch: The batch; its ``'chart_id'`` entry is passed to the parser.
        output: Mapping with ``'output'`` and ``'mask'`` entries. It is
            mutated: the span probabilities returned by the parser are stored
            under ``'span_probs'``.
        criterion: Passed to the parser as its ``crf_decoder`` argument.

    Returns:
        The loss, i.e. the first element of the pair returned by the parser.
    """
    scores = output['output']
    span_mask = output['mask']
    result = CRFConstituencyParser.compute_loss(
        self, scores, batch['chart_id'], span_mask, crf_decoder=criterion)
    loss, span_probs = result
    # Stash the probabilities on the output mapping so later stages can read
    # them from the same object.
    output['span_probs'] = span_probs
    return loss
def feed_batch(self,
               h: torch.FloatTensor,
               batch: Dict[str, torch.Tensor],
               mask: torch.BoolTensor,
               decoder: torch.nn.Module):
    """Run ``decoder`` on ``h`` and pair its result with a freshly computed mask.

    Args:
        h: Input passed straight to ``decoder``.
        batch: The batch; consulted for a ``'constituency'`` key and for
            ``batch['token']``, and forwarded to the parser's ``compute_mask``.
        mask: Accepted for interface compatibility; not used by this method.
        decoder: Callable applied to ``h``.

    Returns:
        A dict with ``'output'`` (the decoder result) and ``'mask'`` (the
        result of ``CRFConstituencyParser.compute_mask``).
    """
    decoded = decoder(h)
    # The offset is 1 when the batch has a 'constituency' entry or when the
    # first sentence ends with EOS; otherwise it is -1.
    if 'constituency' in batch or batch['token'][0][-1] == EOS:
        offset = 1
    else:
        offset = -1
    span_mask = CRFConstituencyParser.compute_mask(self, batch, offset=offset)
    return {'output': decoded, 'mask': span_mask}
def decode_output(self,
                  output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                  mask: torch.BoolTensor,
                  batch: Dict[str, Any],
                  decoder: torch.nn.Module,
                  **kwargs) -> Union[Dict[str, Any], Any]:
    """Decode one batch via ``CRFConstituencyParser.decode_output``.

    Args:
        output: Mapping with ``'output'`` and ``'mask'`` entries and an
            optional ``'span_probs'`` entry.
        mask: Ignored; the mask stored in ``output['mask']`` is used instead.
        batch: The batch; ``batch['token']`` supplies the sentences.
        decoder: Forwarded to the parser as ``decoder``.
        **kwargs: Accepted but not used.

    Returns:
        Whatever ``CRFConstituencyParser.decode_output`` returns.
    """
    scores = output['output']
    span_mask = output['mask']

    # Drop a leading BOS and a trailing EOS from each sentence. A sentence
    # that carries neither marker is passed on as the very same object.
    tokens = []
    for sentence in batch['token']:
        sentence = sentence[1:] if sentence[0] == BOS else sentence
        sentence = sentence[:-1] if sentence[-1] == EOS else sentence
        tokens.append(sentence)

    span_probs = output.get('span_probs', None)
    return CRFConstituencyParser.decode_output(
        self, scores, span_mask, batch, span_probs, decoder=decoder, tokens=tokens)
def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
    """Build samples from ``inputs`` by delegating to ``CRFConstituencyParser``.

    Args:
        inputs: Forwarded unchanged to ``CRFConstituencyParser.build_samples``.
        cls_is_bos: Accepted so that callers passing this flag do not fail;
            it is not forwarded to the parser.
        sep_is_eos: Accepted so that callers passing this flag do not fail;
            it is not forwarded to the parser.

    Returns:
        Whatever ``CRFConstituencyParser.build_samples`` returns.
    """
    # NOTE(review): another ``build_samples`` is defined later in this file.
    # If both live in the same scope the later one wins and one of them
    # should be removed; the signature here is kept in step with it so that
    # either binding accepts the same calls.
    return CRFConstituencyParser.build_samples(self, inputs)
def input_is_flat(self, data) -> bool:
    """Report whether ``data`` is flat, as judged by ``CRFConstituencyParser``.

    Args:
        data: Forwarded unchanged to ``CRFConstituencyParser.input_is_flat``.

    Returns:
        The parser's answer, returned as-is.
    """
    flat = CRFConstituencyParser.input_is_flat(self, data)
    return flat
def build_metric(self, **kwargs):
    """Build the metric by delegating to ``CRFConstituencyParser.build_metric``.

    Args:
        **kwargs: Accepted for interface compatibility; none of them are
            forwarded to the parser.

    Returns:
        Whatever ``CRFConstituencyParser.build_metric`` returns.
    """
    metric = CRFConstituencyParser.build_metric(self)
    return metric
def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
    """Build samples from ``inputs`` by delegating to ``CRFConstituencyParser``.

    Args:
        inputs: Forwarded unchanged to ``CRFConstituencyParser.build_samples``.
        cls_is_bos: Accepted but not forwarded to the parser.
        sep_is_eos: Accepted but not forwarded to the parser.

    Returns:
        Whatever ``CRFConstituencyParser.build_samples`` returns.
    """
    samples = CRFConstituencyParser.build_samples(self, inputs)
    return samples