def train_BMN(data_loader, model, optimizer, epoch, bm_mask):
    """Run one training epoch of BMN and print the averaged losses.

    Args:
        data_loader: iterable yielding (input_data, label_confidence,
            label_start, label_end) CPU tensor batches.
        model: BMN network; called as model(input) -> (confidence_map, start, end).
        optimizer: torch optimizer stepping the model parameters.
        epoch: current epoch index (used only for the log line).
        bm_mask: boundary-matching mask tensor passed to the loss.
    """
    model.train()
    # Hoist the device transfer: the original moved bm_mask to GPU on
    # every iteration; once per epoch is sufficient.
    bm_mask = bm_mask.cuda()
    epoch_pemreg_loss = 0
    epoch_pemclr_loss = 0
    epoch_tem_loss = 0
    epoch_loss = 0
    n_iter = -1  # stays -1 if the loader is empty; guards the final print
    for n_iter, (input_data, label_confidence, label_start,
                 label_end) in enumerate(data_loader):
        input_data = input_data.cuda()
        label_start = label_start.cuda()
        label_end = label_end.cuda()
        label_confidence = label_confidence.cuda()
        confidence_map, start, end = model(input_data)
        loss = bmn_loss_func(confidence_map, start, end, label_confidence,
                             label_start, label_end, bm_mask)
        optimizer.zero_grad()
        # loss layout: [0] total, [1] TEM, [2] PEM regression, [3] PEM classification
        loss[0].backward()
        optimizer.step()
        epoch_pemreg_loss += loss[2].detach().cpu().numpy()
        epoch_pemclr_loss += loss[3].detach().cpu().numpy()
        epoch_tem_loss += loss[1].detach().cpu().numpy()
        epoch_loss += loss[0].detach().cpu().numpy()
    if n_iter < 0:
        return  # empty loader: nothing to average or report
    print(
        "BMN training loss(epoch %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f" % (
            epoch, epoch_tem_loss / (n_iter + 1),
            epoch_pemclr_loss / (n_iter + 1),
            epoch_pemreg_loss / (n_iter + 1),
            epoch_loss / (n_iter + 1)))
def test_BMN(data_loader, model, epoch, bm_mask, best_loss=1e10):
    """Evaluate BMN for one epoch, checkpoint, and track the best model.

    Bug fixes vs. the original:
      * ``best_loss`` used to be a fresh local (1e10) on every call, so the
        "best" checkpoint was unconditionally overwritten each epoch. It is
        now a defaulted parameter; the caller should thread the returned
        value back in on the next call. Omitting it reproduces the old
        (always-save) behavior, so existing callers keep working.
      * Evaluation now runs under ``torch.no_grad()`` — no autograd graph
        is built for pure inference.
      * The log line said "training loss" in the test function.

    Args:
        data_loader: iterable of (input, confidence, start, end) batches.
        model: BMN network in eval mode.
        epoch: epoch index, stored in the checkpoint as ``epoch + 1``.
        bm_mask: boundary-matching mask tensor for the loss.
        best_loss: best (lowest) accumulated epoch loss seen so far.

    Returns:
        The updated best loss (min of ``best_loss`` and this epoch's loss).
    """
    model.eval()
    bm_mask = bm_mask.cuda()  # move once, not per batch
    epoch_pemreg_loss = 0
    epoch_pemclr_loss = 0
    epoch_tem_loss = 0
    epoch_loss = 0
    n_iter = -1  # guards against an empty loader
    with torch.no_grad():
        for n_iter, (input_data, label_confidence, label_start,
                     label_end) in enumerate(data_loader):
            input_data = input_data.cuda()
            label_start = label_start.cuda()
            label_end = label_end.cuda()
            label_confidence = label_confidence.cuda()
            confidence_map, start, end = model(input_data)
            loss = bmn_loss_func(confidence_map, start, end, label_confidence,
                                 label_start, label_end, bm_mask)
            # loss layout: [0] total, [1] TEM, [2] PEM reg, [3] PEM cls
            epoch_pemreg_loss += loss[2].detach().cpu().numpy()
            epoch_pemclr_loss += loss[3].detach().cpu().numpy()
            epoch_tem_loss += loss[1].detach().cpu().numpy()
            epoch_loss += loss[0].detach().cpu().numpy()
    if n_iter >= 0:
        print(
            "BMN testing loss(epoch %d): tem_loss: %.03f, pem class_loss: %.03f, pem reg_loss: %.03f, total_loss: %.03f" % (
                epoch, epoch_tem_loss / (n_iter + 1),
                epoch_pemclr_loss / (n_iter + 1),
                epoch_pemreg_loss / (n_iter + 1),
                epoch_loss / (n_iter + 1)))
    state = {'epoch': epoch + 1, 'state_dict': model.state_dict()}
    # Always keep a rolling checkpoint; additionally snapshot the best model.
    torch.save(state, opt["checkpoint_path"] + "/BMN_checkpoint.pth.tar")
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(state, opt["checkpoint_path"] + "/BMN_best.pth.tar")
    return best_loss
def train_epoch(self, data_loader, bm_mask, epoch, writer):
    """Train one epoch with gradient accumulation over STEP_PERIOD batches.

    Gradients are accumulated for ``cfg.TRAIN.STEP_PERIOD`` consecutive
    batches (the final, possibly shorter, period included) before a single
    optimizer step. Per-period losses are written to ``writer`` and summed
    into epoch totals for the closing log line.

    Args:
        data_loader: yields (env_features, agent_features, agent_masks,
            label_confidence, label_start, label_end) batches.
        bm_mask: boundary-matching mask passed through to bmn_loss_func
            (presumably already on the correct device — TODO confirm).
        epoch: epoch index, used to derive the TensorBoard global step.
        writer: TensorBoard-style writer with an ``add_scalar`` method.
    """
    cfg = self.cfg
    self.model.train()
    self.optimizer.zero_grad()
    # Index layout shared by `losses`, `period_losses`, `epoch_losses`:
    # [0] total, [1] TEM, [2] PEM regression, [3] PEM classification.
    loss_names = [
        'Loss', 'TemLoss', 'PemLoss Regression', 'PemLoss Classification'
    ]
    epoch_losses = [0] * 4
    period_losses = [0] * 4
    # The last accumulation period may be shorter than STEP_PERIOD;
    # last_period_start is the iteration index where it begins.
    last_period_size = len(data_loader) % cfg.TRAIN.STEP_PERIOD
    last_period_start = cfg.TRAIN.STEP_PERIOD * (len(data_loader) // cfg.TRAIN.STEP_PERIOD)
    for n_iter, (env_features, agent_features, agent_masks,
                 label_confidence, label_start, label_end) in enumerate(tqdm(data_loader)):
        # Env/agent streams are optional; disabled streams are passed as None.
        env_features = env_features.cuda() if cfg.USE_ENV else None
        agent_features = agent_features.cuda() if cfg.USE_AGENT else None
        agent_masks = agent_masks.cuda() if cfg.USE_AGENT else None
        label_start = label_start.cuda()
        label_end = label_end.cuda()
        label_confidence = label_confidence.cuda()
        confidence_map, start, end = self.model(env_features, agent_features, agent_masks)
        losses = bmn_loss_func(confidence_map, start, end, label_confidence,
                               label_start, label_end, bm_mask)
        # Scale the backward pass by the actual period size so the
        # accumulated gradient equals the mean over the period's batches.
        period_size = cfg.TRAIN.STEP_PERIOD if n_iter < last_period_start else last_period_size
        total_loss = losses[0] / period_size
        total_loss.backward()
        # NOTE(review): logging always divides by STEP_PERIOD, even in the
        # shorter final period where backward used `period_size` — the last
        # logged point is slightly deflated. Confirm whether intentional.
        losses = [
            l.cpu().detach().numpy() / cfg.TRAIN.STEP_PERIOD for l in losses
        ]
        period_losses = [l + pl for l, pl in zip(losses, period_losses)]
        # Keep accumulating until a period boundary (or the last batch).
        if (n_iter + 1) % cfg.TRAIN.STEP_PERIOD != 0 and n_iter != (
                len(data_loader) - 1):
            continue
        # Period boundary: apply the accumulated gradient, then reset.
        self.optimizer.step()
        self.optimizer.zero_grad()
        epoch_losses = [
            el + pl for el, pl in zip(epoch_losses, period_losses)
        ]
        # Global step is unique across epochs: epoch offset + iteration.
        write_step = epoch * len(data_loader) + n_iter
        for i, loss_name in enumerate(loss_names):
            writer.add_scalar(loss_name, period_losses[i], write_step)
        period_losses = [0] * 4
    print(
        "BMN training loss(epoch %d): tem_loss: %.03f, pem reg_loss: %.03f, pem cls_loss: %.03f, total_loss: %.03f" % (
            epoch, epoch_losses[1] / (n_iter + 1), epoch_losses[2] / (n_iter + 1),
            epoch_losses[3] / (n_iter + 1), epoch_losses[0] / (n_iter + 1)))