import torch
import torch.nn.functional as F


def training_step_hybrid(sp, model, batch, mask_id, pad_id, vocab_start_idx, vocab_end_idx, use_cuda):
    """Joint MoCo + MLM step: contrastive loss on the (key, query) views plus a
    masked-language-modeling loss on the masked query sequence."""
    imgs, _lengths, _ = batch
    # TODO: implement LSTM for hybrid model and pass lengths to model call
    imgs_k, imgs_q = imgs[:, 0, :], imgs[:, 1, :]
    imgs_q, mlm_targets = mask_mlm(imgs_q, pad_id, mask_id, vocab_start_idx, vocab_end_idx)
    if use_cuda:
        imgs_k = imgs_k.cuda(non_blocking=True)
        imgs_q = imgs_q.cuda(non_blocking=True)
        mlm_targets = mlm_targets.cuda(non_blocking=True)
    predicted_masked_tokens, moco_logits, moco_targets = model(imgs_k, imgs_q)
    moco_loss = F.cross_entropy(moco_logits, moco_targets)
    moco_acc1, moco_acc5 = accuracy(moco_logits, moco_targets, topk=(1, 5))
    mlm_loss = F.cross_entropy(predicted_masked_tokens.flatten(end_dim=1), mlm_targets.flatten(), ignore_index=pad_id)
    mlm_acc1, mlm_acc5 = accuracy(predicted_masked_tokens[mlm_targets != pad_id], mlm_targets[mlm_targets != pad_id], topk=(1, 5))
    loss = 4 * moco_loss + mlm_loss
    logs = {
        "pretrain/moco/loss": moco_loss.item(),
        "pretrain/moco/acc@1": moco_acc1[0].item(),
        "pretrain/moco/acc@5": moco_acc5[0].item(),
        "pretrain/moco/queue_ptr": model.module.queue_ptr.item(),
        "pretrain/mlm/loss": mlm_loss.item(),
        "pretrain/mlm/acc@1": mlm_acc1[0].item(),
        "pretrain/mlm/acc@5": mlm_acc5[0].item(),
        "pretrain/hybrid_loss": loss.item(),  # .item() keeps the log a plain float, consistent with the other entries
    }
    return {"loss": loss, "log": logs}
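
# `mask_mlm` is called above and in `training_step_mlm` below but is not defined
# in this file. A minimal sketch, assuming standard BERT-style 80/10/10 masking
# over non-pad tokens; the `mask_prob` default and the exact masking rule are
# assumptions, not the original implementation.
def mask_mlm(seq, pad_id, mask_id, vocab_start_idx, vocab_end_idx, mask_prob=0.15):
    """Select ~mask_prob of non-pad tokens as prediction targets; of those,
    80% become the mask token, 10% a random in-vocab token, 10% unchanged.
    Non-target positions are set to pad_id so cross_entropy ignores them."""
    targets = seq.clone()
    candidates = seq != pad_id  # only real tokens may be masked
    probs = torch.full(seq.shape, mask_prob, device=seq.device)
    probs.masked_fill_(~candidates, 0.0)
    selected = torch.bernoulli(probs).bool()
    targets[~selected] = pad_id  # loss (ignore_index=pad_id) skips unselected positions

    seq_masked = seq.clone()
    # 80% of selected positions -> mask token
    replace_mask = torch.bernoulli(torch.full(seq.shape, 0.8, device=seq.device)).bool() & selected
    seq_masked[replace_mask] = mask_id
    # Half of the remaining 20% -> random in-vocab token; the rest stay unchanged
    replace_rand = torch.bernoulli(torch.full(seq.shape, 0.5, device=seq.device)).bool() & selected & ~replace_mask
    random_tokens = torch.randint(vocab_start_idx, vocab_end_idx, seq.shape, device=seq.device)
    seq_masked[replace_rand] = random_tokens[replace_rand]
    return seq_masked, targets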
def training_step(model, batch, use_cuda=False):
    """MoCo contrastive step: encode the query and key views and score the
    query against the key and the negatives queue."""
    imgs, lengths, _ = batch
    if use_cuda:
        imgs = imgs.cuda(non_blocking=True)
    imgs_k, imgs_q = imgs[:, 0, :], imgs[:, 1, :]
    lengths_k, lengths_q = lengths[:, 0], lengths[:, 1]
    output, target = model(imgs_q, imgs_k, lengths_k, lengths_q)
    loss = F.cross_entropy(output, target)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    logs = {
        "pretrain/loss": loss.item(),
        "pretrain/acc@1": acc1[0].item(),
        "pretrain/acc@5": acc5[0].item(),
        "pretrain/queue_ptr": model.module.queue_ptr.item(),
    }
    return {"loss": loss, "log": logs}
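
# `accuracy` is also assumed rather than defined here. The acc1[0].item()
# indexing used throughout matches the top-k helper from the standard PyTorch
# ImageNet example, which returns one percentage tensor per requested k; a
# sketch following that recipe:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (as a percentage) for each k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()  # maxk x B
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res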
def training_step_mlm(sp, model, batch, mask_id: int, pad_id: int, vocab_start_idx: int, vocab_end_idx: int, use_cuda=True):
    """Masked-language-modeling step: mask the input sequence, predict the
    masked tokens, and compute loss/accuracy only on the masked positions."""
    seq, lengths, _ = batch  # B x L
    if use_cuda:
        seq = seq.cuda()
    B, L = seq.shape
    seq_masked, targets = mask_mlm(seq, pad_id, mask_id, vocab_start_idx, vocab_end_idx)
    # logger.debug(f"Example transform:\t{sp.DecodeIds(seq_masked[0].cpu().numpy().tolist())}")
    output = model(seq_masked, lengths)  # B x L x Vocab
    assert targets.shape == (B, L), f"{targets.shape} versus {B}x{L}"
    assert output.shape == (B, L, output.shape[-1]), output.shape
    loss = F.cross_entropy(output.flatten(end_dim=1), targets.flatten(), ignore_index=pad_id)
    acc1, acc5 = accuracy(output[targets != pad_id], targets[targets != pad_id], topk=(1, 5))
    return {
        "loss": loss,
        "log": {"pretrain/loss": loss.item(), "pretrain/acc@1": acc1[0].item(), "pretrain/acc@5": acc5[0].item()},
    }
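
# A minimal driver sketch showing how these step functions compose with an
# optimizer. `run_mlm_epoch`, `loader`, `optimizer`, and `log_every` are
# illustrative names, not part of the original codebase; the loader is assumed
# to yield (seq, lengths, meta) batches as unpacked in training_step_mlm.
def run_mlm_epoch(sp, model, loader, optimizer, mask_id, pad_id,
                  vocab_start_idx, vocab_end_idx, log_every=100):
    model.train()
    for step, batch in enumerate(loader):
        out = training_step_mlm(sp, model, batch, mask_id, pad_id,
                                vocab_start_idx, vocab_end_idx,
                                use_cuda=torch.cuda.is_available())
        optimizer.zero_grad()
        out["loss"].backward()
        optimizer.step()
        if step % log_every == 0:
            print(step, {k: round(v, 4) for k, v in out["log"].items()})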