def train_BMN(data_loader, model, optimizer, epoch, bm_mask):
    """Run one BMN training epoch and print the epoch-mean losses.

    Args:
        data_loader: iterable of (input_data, label_confidence, label_start,
            label_end) batches. Must be non-empty (the summary print uses the
            last loop index).
        model: BMN network; called as model(input_data) and expected to return
            (confidence_map, start, end).
        optimizer: torch optimizer stepped once per batch.
        epoch: integer epoch index, used only for logging.
        bm_mask: boundary-matching mask tensor passed to the loss.

    Note: loss ordering assumed from the indexing below —
    loss[0]=total, loss[1]=tem, loss[2]=pem reg, loss[3]=pem cls;
    defined by bmn_loss_func elsewhere — TODO confirm.
    """
    model.train()
    # Hoist the loop-invariant device transfer: the original called
    # bm_mask.cuda() on every iteration.
    bm_mask = bm_mask.cuda()
    epoch_pemreg_loss = 0
    epoch_pemclr_loss = 0
    epoch_tem_loss = 0
    epoch_loss = 0
    for n_iter, (input_data, label_confidence, label_start,
                 label_end) in enumerate(data_loader):
        input_data = input_data.cuda()
        label_start = label_start.cuda()
        label_end = label_end.cuda()
        label_confidence = label_confidence.cuda()

        confidence_map, start, end = model(input_data)
        loss = bmn_loss_func(confidence_map, start, end, label_confidence,
                             label_start, label_end, bm_mask)

        # Standard step: only the total loss (loss[0]) is backpropagated.
        optimizer.zero_grad()
        loss[0].backward()
        optimizer.step()

        # Accumulate detached scalars for the epoch summary.
        epoch_pemreg_loss += loss[2].cpu().detach().numpy()
        epoch_pemclr_loss += loss[3].cpu().detach().numpy()
        epoch_tem_loss += loss[1].cpu().detach().numpy()
        epoch_loss += loss[0].cpu().detach().numpy()

    print(
        "BMN training loss(epoch %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f"
        % (epoch, epoch_tem_loss / (n_iter + 1), epoch_pemclr_loss /
           (n_iter + 1), epoch_pemreg_loss / (n_iter + 1), epoch_loss /
           (n_iter + 1)))
def test_BMN(data_loader, model, epoch, bm_mask):
    """Evaluate BMN for one epoch, print mean losses, and checkpoint.

    Always saves the latest state to BMN_checkpoint.pth.tar; additionally
    saves BMN_best.pth.tar when this epoch's total loss improves on the
    best seen so far across calls.

    Args:
        data_loader: iterable of (input_data, label_confidence, label_start,
            label_end) batches. Must be non-empty.
        model: BMN network returning (confidence_map, start, end).
        epoch: integer epoch index (stored as epoch + 1 in the checkpoint).
        bm_mask: boundary-matching mask tensor passed to the loss.
    """
    model.eval()
    # BUG FIX: best_loss was a local reset to 1e10 on every call, so the
    # "best" checkpoint was overwritten unconditionally each epoch.
    # Persist it across calls on the function object instead.
    best_loss = getattr(test_BMN, "best_loss", 1e10)
    bm_mask = bm_mask.cuda()  # hoisted: was transferred every iteration
    epoch_pemreg_loss = 0
    epoch_pemclr_loss = 0
    epoch_tem_loss = 0
    epoch_loss = 0
    # Evaluation only: disable autograd to avoid building graphs and
    # holding activation memory.
    with torch.no_grad():
        for n_iter, (input_data, label_confidence, label_start,
                     label_end) in enumerate(data_loader):
            input_data = input_data.cuda()
            label_start = label_start.cuda()
            label_end = label_end.cuda()
            label_confidence = label_confidence.cuda()

            confidence_map, start, end = model(input_data)
            loss = bmn_loss_func(confidence_map, start, end, label_confidence,
                                 label_start, label_end, bm_mask)

            epoch_pemreg_loss += loss[2].cpu().detach().numpy()
            epoch_pemclr_loss += loss[3].cpu().detach().numpy()
            epoch_tem_loss += loss[1].cpu().detach().numpy()
            epoch_loss += loss[0].cpu().detach().numpy()

    # BUG FIX: message said "training loss" (copy-paste from train_BMN);
    # this function evaluates, so report it as validation loss.
    print(
        "BMN validation loss(epoch %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f"
        % (epoch, epoch_tem_loss / (n_iter + 1), epoch_pemclr_loss /
           (n_iter + 1), epoch_pemreg_loss / (n_iter + 1), epoch_loss /
           (n_iter + 1)))

    state = {'epoch': epoch + 1, 'state_dict': model.state_dict()}
    torch.save(state, opt["checkpoint_path"] + "/BMN_checkpoint.pth.tar")
    if epoch_loss < best_loss:
        test_BMN.best_loss = epoch_loss
        torch.save(state, opt["checkpoint_path"] + "/BMN_best.pth.tar")
# --- Example no. 3 (snippet separator left over from source aggregation) ---
    def train_epoch(self, data_loader, bm_mask, epoch, writer):
        """Train one epoch with gradient accumulation and TensorBoard logging.

        Gradients are accumulated over cfg.TRAIN.STEP_PERIOD iterations
        before each optimizer step; the final period may be shorter when
        len(data_loader) is not a multiple of STEP_PERIOD.

        Args:
            data_loader: yields (env_features, agent_features, agent_masks,
                label_confidence, label_start, label_end). Must be non-empty.
            bm_mask: boundary-matching mask tensor passed to the loss.
            epoch: integer epoch index used for logging steps.
            writer: TensorBoard-style writer with add_scalar(tag, value, step).

        Note: loss ordering assumed from loss_names below —
        losses[0]=total, [1]=tem, [2]=pem reg, [3]=pem cls; defined by
        bmn_loss_func elsewhere — TODO confirm.
        """
        cfg = self.cfg
        self.model.train()
        self.optimizer.zero_grad()
        loss_names = [
            'Loss', 'TemLoss', 'PemLoss Regression', 'PemLoss Classification'
        ]
        epoch_losses = [0] * 4
        period_losses = [0] * 4
        # Size and starting index of the (possibly shorter) final period.
        last_period_size = len(data_loader) % cfg.TRAIN.STEP_PERIOD
        last_period_start = cfg.TRAIN.STEP_PERIOD * (len(data_loader) //
                                                     cfg.TRAIN.STEP_PERIOD)

        for n_iter, (env_features, agent_features, agent_masks,
                     label_confidence, label_start,
                     label_end) in enumerate(tqdm(data_loader)):
            env_features = env_features.cuda() if cfg.USE_ENV else None
            agent_features = agent_features.cuda() if cfg.USE_AGENT else None
            agent_masks = agent_masks.cuda() if cfg.USE_AGENT else None

            label_start = label_start.cuda()
            label_end = label_end.cuda()
            label_confidence = label_confidence.cuda()

            confidence_map, start, end = self.model(env_features,
                                                    agent_features,
                                                    agent_masks)

            losses = bmn_loss_func(confidence_map, start, end,
                                   label_confidence, label_start, label_end,
                                   bm_mask)
            period_size = cfg.TRAIN.STEP_PERIOD if n_iter < last_period_start else last_period_size
            # Scale so accumulated gradients equal the period-mean gradient.
            total_loss = losses[0] / period_size
            total_loss.backward()

            # BUG FIX: normalize the logged losses by the actual period size
            # (matching the backward scaling above) instead of always
            # STEP_PERIOD, which under-reported the shorter final period.
            losses = [
                l.cpu().detach().numpy() / period_size
                for l in losses
            ]
            period_losses = [l + pl for l, pl in zip(losses, period_losses)]

            # Step only at period boundaries (or on the very last batch).
            if (n_iter + 1) % cfg.TRAIN.STEP_PERIOD != 0 and n_iter != (
                    len(data_loader) - 1):
                continue

            self.optimizer.step()
            self.optimizer.zero_grad()

            epoch_losses = [
                el + pl for el, pl in zip(epoch_losses, period_losses)
            ]

            # Log the period-mean losses at a globally increasing step.
            write_step = epoch * len(data_loader) + n_iter
            for i, loss_name in enumerate(loss_names):
                writer.add_scalar(loss_name, period_losses[i], write_step)
            period_losses = [0] * 4

        print(
            "BMN training loss(epoch %d): tem_loss: %.03f, pem reg_loss: %.03f, pem cls_loss: %.03f, total_loss: %.03f"
            % (epoch, epoch_losses[1] / (n_iter + 1), epoch_losses[2] /
               (n_iter + 1), epoch_losses[3] / (n_iter + 1), epoch_losses[0] /
               (n_iter + 1)))