import torch

from fairseq import utils
from fairseq.criterions import FairseqCriterion

# NOTE: pack_replabels is assumed to be the replabels helper shipped with
# fairseq's speech_recognition example data utilities.
from examples.speech_recognition.data.replabels import pack_replabels


class ASGCriterion(FairseqCriterion):
    @staticmethod
    def add_args(parser):
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel", help="maximum # of replabels", type=int, default=2
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
    ):
        from wav2letter.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
        # The transition matrix is learned; start it as a scaled identity.
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Counter tracking how many updates have used LinSeg initialization.
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
        )

    def linseg_step(self):
        """Return True while LinSeg initialization is still active."""
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        elif self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        if tgt[-1] != self.eos:
            return tgt
        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            return tgt[:-1]
        else:
            return tgt[:-1] + [self.silence]

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        target = torch.IntTensor(B, T)
        target_size = torch.IntTensor(B)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            tgt = tgt[:T]

            if using_linseg:
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b][: len(tgt)] = torch.IntTensor(tgt)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            "loss": loss_sum / nsentences,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output
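# ---------------------------------------------------------------------------
# Illustration (not from the original file): a minimal, self-contained sketch
# of the LinSeg target expansion used in ASGCriterion.forward() above. While
# the first `linseg_updates` training updates are running, the packed target
# of length L is stretched to the emission length T so that every frame gets
# a label, giving ASG a uniform "linear segmentation" alignment to start
# from. The helper name and the label values below are made up for the
# example; the list comprehension is the one used in forward().


def _linseg_expand_example(tgt, T):
    # Repeat each of the len(tgt) labels roughly T / len(tgt) times.
    return [tgt[t * len(tgt) // T] for t in range(T)]


assert _linseg_expand_example([7, 3, 5, 9], 10) == [7, 7, 7, 3, 3, 5, 5, 5, 9, 9]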
import argparse

import torch

from wav2letter.criterion import ASGLoss, CriterionScaleMode

# Standalone example of calling ASGLoss directly through the wav2letter
# Python bindings. The import block and the parser construction line are
# reconstructed for runnability; the original snippet starts at the first
# add_argument call and is truncated partway through the input tensor.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cpu", action="store_true", help="use CPU backend, otherwise use CUDA backend"
)
parser.add_argument(
    "--double",
    action="store_true",
    help="store tensors in double, otherwise in float",
)
args = parser.parse_args()

device = torch.device("cpu" if args.cpu else "cuda")
float_type = torch.double if args.double else torch.float

# Create the ASG loss over 6 tokens (one score per token predicted by some
# network at each frame), scaling each loss by the square root of its
# target size.
asg = ASGLoss(6, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT).to(device)

# Input to the loss: per-frame token scores coming from some network,
# with shape [batch, time, ntokens].
input = torch.tensor(
    [
        [
            [-0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707],
            [0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801],
            [0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560],
            [0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526],
            [-0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299],
        ],
        [
            [0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543],
            [-0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739],
            [-0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967],