def ModelTrain(train_data_loader, model, criterion, optimizer, loss_bin, config, epoch):
    """Run one training epoch and return averaged losses plus accuracy/IoU.

    Args:
        train_data_loader: iterable yielding batches (items may be ``None`` and are skipped).
        model: network; called as ``model(data)`` and returns ``(pre_batch, gt_batch)`` dicts.
        criterion: loss; called as ``criterion(pre_batch, gt_batch)`` -> ``(loss, metrics)``.
        optimizer: torch optimizer stepping the model parameters.
        loss_bin: dict of loss accumulators exposing ``loss_add`` / ``loss_mean``.
        config: parsed config; reads ``config['base']['algorithm']``, ``n_epoch``, ``show_step``.
        epoch: current epoch index (for logging only).

    Returns:
        list: mean value of each entry in ``loss_bin`` (in key order) followed by
        ``acc`` and ``iou`` from the last processed batch (0.0 if no batch ran).
    """
    # Original code built the same runningScore(2) in both branches of an
    # if/else on the algorithm name; only the kernel metric is conditional.
    running_metric_text = runningScore(2)
    if config['base']['algorithm'] not in ('DB', 'SAST'):
        running_metric_kernel = runningScore(2)

    # Initialize so the trailing extend([acc, iou]) cannot raise NameError
    # when the loader is empty or yields only None batches.
    acc = 0.0
    iou = 0.0

    for batch_idx, data in enumerate(train_data_loader):
        if data is None:
            continue

        pre_batch, gt_batch = model(data)
        loss, metrics = criterion(pre_batch, gt_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): debug artifact — writes an image to disk EVERY batch,
        # which costs I/O per iteration; consider gating on show_step or removing.
        cv2.imwrite('pre.jpg', pre_batch['f_score'][0, 0].cpu().detach().numpy() * 255)

        # Accumulate per-key metrics; keys the criterion did not report fall
        # back to the total loss (preserves original behavior).
        for key in loss_bin.keys():
            if key in metrics.keys():
                loss_bin[key].loss_add(metrics[key].item())
            else:
                loss_bin[key].loss_add(loss.item())

        # Per-algorithm quality metrics on the current batch.
        if config['base']['algorithm'] == 'DB':
            iou, acc = cal_DB(pre_batch['binary'], gt_batch['gt'],
                              gt_batch['mask'], running_metric_text)
        elif config['base']['algorithm'] == 'SAST':
            iou, acc = cal_DB(pre_batch['f_score'], gt_batch['input_score'],
                              gt_batch['input_mask'], running_metric_text)
        else:
            iou, acc = cal_PAN_PSE(pre_batch['pre_kernel'], gt_batch['gt_kernel'],
                                   pre_batch['pre_text'], gt_batch['gt_text'],
                                   gt_batch['train_mask'],
                                   running_metric_text, running_metric_kernel)

        # Periodic console progress line.
        if batch_idx % config['base']['show_step'] == 0:
            log = '({}/{}/{}/{}) | '.format(
                epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader))
            for key in loss_bin.keys():
                log += key + ':{:.4f}'.format(loss_bin[key].loss_mean()) + ' | '
            log += 'ACC:{:.4f}'.format(acc) + ' | '
            log += 'IOU:{:.4f}'.format(iou) + ' | '
            log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])
            print(log)

    # Epoch summary: mean of every tracked loss, then last-batch acc/iou.
    loss_write = [loss_bin[key].loss_mean() for key in loss_bin.keys()]
    loss_write.extend([acc, iou])
    return loss_write
def ModelTrain(train_data_loader, t_model, t_criterion, model, criterion,
               optimizer, loss_bin, args, config, epoch):
    """Run one training epoch with optional knowledge distillation and
    optional sparsity (BN) regularization; return averaged losses plus acc/IoU.

    NOTE(review): this shadows the 7-argument ``ModelTrain`` above if both
    definitions live in the same module — confirm only one is intended.

    Args:
        train_data_loader: iterable yielding batches (items may be ``None`` and are skipped).
        t_model: teacher network or ``None`` to disable distillation.
        t_criterion: distillation loss; called as ``t_criterion(pre_batch, t_pre_batch)``.
        model: student network; ``model(data)`` -> ``(pre_batch, gt_batch)`` dicts.
        criterion: task loss; ``criterion(pre_batch, gt_batch)`` -> ``(loss, metrics)``.
        optimizer: torch optimizer stepping the student parameters.
        loss_bin: dict of loss accumulators exposing ``loss_add`` / ``loss_mean``.
        args: namespace; reads ``t_ratio`` (distillation mix) and ``sr_lr``
            (``None`` disables the ``updateBN`` sparsity step).
        config: parsed config; reads ``config['base']['algorithm']``, ``n_epoch``, ``show_step``.
        epoch: current epoch index (for logging only).

    Returns:
        list: mean value of each entry in ``loss_bin`` (in key order) followed by
        ``acc`` and ``iou`` from the last processed batch (0.0 if no batch ran).
    """
    # Original code built the same runningScore(2) in both branches of an
    # if/else on the algorithm name; only the kernel metric is conditional.
    running_metric_text = runningScore(2)
    if config['base']['algorithm'] not in ('DB', 'SAST'):
        running_metric_kernel = runningScore(2)

    # Initialize so the trailing extend([acc, iou]) cannot raise NameError
    # when the loader is empty or yields only None batches.
    acc = 0.0
    iou = 0.0

    for batch_idx, data in enumerate(train_data_loader):
        if data is None:
            continue

        pre_batch, gt_batch = model(data)

        # Teacher forward pass is inference-only: no autograd graph needed.
        if t_model is not None:
            with torch.no_grad():
                t_pre_batch, _ = t_model(data)
            distil_loss = t_criterion(pre_batch, t_pre_batch)

        loss, metrics = criterion(pre_batch, gt_batch)
        if t_model is not None:
            # Convex mix of task loss and distillation loss; t_ratio weights
            # the task loss (1.0 -> no distillation influence).
            loss = args.t_ratio * loss + (1 - args.t_ratio) * distil_loss
            metrics['loss_distil'] = distil_loss

        optimizer.zero_grad()
        loss.backward()
        if args.sr_lr is not None:
            # Sparsity regularization on BatchNorm scales (network slimming).
            updateBN(model, args)
        optimizer.step()

        # Accumulate per-key metrics; keys the criterion did not report fall
        # back to the total loss (preserves original behavior).
        for key in loss_bin.keys():
            if key in metrics.keys():
                loss_bin[key].loss_add(metrics[key].item())
            else:
                loss_bin[key].loss_add(loss.item())

        # Per-algorithm quality metrics on the current batch.
        if config['base']['algorithm'] == 'DB':
            iou, acc = cal_DB(pre_batch['binary'], gt_batch['gt'],
                              gt_batch['mask'], running_metric_text)
        elif config['base']['algorithm'] == 'SAST':
            iou, acc = cal_DB(pre_batch['f_score'], gt_batch['input_score'],
                              gt_batch['input_mask'], running_metric_text)
        else:
            iou, acc = cal_PAN_PSE(pre_batch['pre_kernel'], gt_batch['gt_kernel'],
                                   pre_batch['pre_text'], gt_batch['gt_text'],
                                   gt_batch['train_mask'],
                                   running_metric_text, running_metric_kernel)

        # Periodic console progress line.
        if batch_idx % config['base']['show_step'] == 0:
            log = '({}/{}/{}/{}) | '.format(
                epoch, config['base']['n_epoch'], batch_idx, len(train_data_loader))
            for key in loss_bin.keys():
                log += key + ':{:.4f}'.format(loss_bin[key].loss_mean()) + ' | '
            log += 'ACC:{:.4f}'.format(acc) + ' | '
            log += 'IOU:{:.4f}'.format(iou) + ' | '
            log += 'lr:{:.8f}'.format(optimizer.param_groups[0]['lr'])
            print(log)

    # Epoch summary: mean of every tracked loss, then last-batch acc/iou.
    loss_write = [loss_bin[key].loss_mean() for key in loss_bin.keys()]
    loss_write.extend([acc, iou])
    return loss_write