예제 #1
0
def ModelTrain(train_data_loader, model, criterion, optimizer, loss_bin,
               config, epoch):
    """Run one training pass over ``train_data_loader`` and report metrics.

    For every non-``None`` batch this runs the model, computes the loss with
    ``criterion``, back-propagates, steps ``optimizer``, accumulates losses
    into ``loss_bin`` and updates the IoU/accuracy metrics for the configured
    algorithm. Every ``config['base']['show_step']`` batches a progress line
    is printed.

    Args:
        train_data_loader: Iterable of batches; ``None`` batches are skipped.
            ``len()`` of it is shown in the progress line.
        model: Callable mapping a batch to ``(pre_batch, gt_batch)`` dicts.
        criterion: Callable ``(pre_batch, gt_batch) -> (loss, metrics)``;
            ``loss`` must support ``backward()``/``item()`` and ``metrics``
            maps names to values supporting ``item()``.
        optimizer: Optimizer exposing ``zero_grad()``, ``step()`` and
            ``param_groups`` (the first group's ``'lr'`` is logged).
        loss_bin: Mapping of name -> accumulator with ``loss_add`` and
            ``loss_mean``. Keys present in ``metrics`` receive that metric;
            every other key receives the total loss.
        config: Nested config; reads ``['base']['algorithm']`` (``'DB'``,
            ``'SAST'``, anything else is handled as PAN/PSE),
            ``['base']['show_step']`` and ``['base']['n_epoch']``.
        epoch: Current epoch number, used only in the progress line.

    Returns:
        ``loss_mean()`` of every ``loss_bin`` entry (in the mapping's
        iteration order) followed by ``acc`` and ``iou`` as returned by the
        last metric helper call. Both are ``0.0`` when no batch was processed.
    """
    algorithm = config['base']['algorithm']

    running_metric_text = runningScore(2)
    # Only the PAN/PSE branch scores a separate kernel prediction.
    if algorithm not in ('DB', 'SAST'):
        running_metric_kernel = runningScore(2)

    # Defaults so the return value exists even if no batch is processed
    # (empty loader, or every batch is None).
    iou, acc = 0.0, 0.0

    for batch_idx, data in enumerate(train_data_loader):
        if data is None:
            continue
        pre_batch, gt_batch = model(data)
        loss, metrics = criterion(pre_batch, gt_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bins without a dedicated metric track the total loss.
        for key in loss_bin:
            value = metrics[key] if key in metrics else loss
            loss_bin[key].loss_add(value.item())

        if algorithm == 'DB':
            iou, acc = cal_DB(pre_batch['binary'], gt_batch['gt'],
                              gt_batch['mask'], running_metric_text)
        elif algorithm == 'SAST':
            iou, acc = cal_DB(pre_batch['f_score'], gt_batch['input_score'],
                              gt_batch['input_mask'], running_metric_text)
        else:
            iou, acc = cal_PAN_PSE(pre_batch['pre_kernel'],
                                   gt_batch['gt_kernel'],
                                   pre_batch['pre_text'], gt_batch['gt_text'],
                                   gt_batch['train_mask'], running_metric_text,
                                   running_metric_kernel)

        if batch_idx % config['base']['show_step'] == 0:
            fields = ['({}/{}/{}/{})'.format(epoch, config['base']['n_epoch'],
                                             batch_idx, len(train_data_loader))]
            fields.extend('{}:{:.4f}'.format(key, loss_bin[key].loss_mean())
                          for key in loss_bin)
            fields.append('ACC:{:.4f}'.format(acc))
            fields.append('IOU:{:.4f}'.format(iou))
            fields.append('lr:{:.8f}'.format(optimizer.param_groups[0]['lr']))
            print(' | '.join(fields))

    loss_write = [loss_bin[key].loss_mean() for key in loss_bin]
    loss_write.extend([acc, iou])
    return loss_write
예제 #2
0
def ModelTrain(train_data_loader, t_model, t_criterion, model, criterion,
               optimizer, loss_bin, args, config, epoch):
    """Run one training pass, optionally distilling from a teacher model.

    For every non-``None`` batch this runs the student ``model``, computes
    the supervised loss with ``criterion`` and, when ``t_model`` is given,
    blends it with a distillation loss against the teacher's predictions.
    It then back-propagates, optionally calls ``updateBN``, steps
    ``optimizer``, accumulates losses into ``loss_bin`` and updates the
    IoU/accuracy metrics. Every ``config['base']['show_step']`` batches a
    progress line is printed.

    Args:
        train_data_loader: Iterable of batches; ``None`` batches are skipped.
            ``len()`` of it is shown in the progress line.
        t_model: Teacher model, or ``None`` to disable distillation. It is
            called on the same batch under ``torch.no_grad()``.
        t_criterion: Callable ``(pre_batch, t_pre_batch) -> distil_loss``;
            only used when ``t_model`` is not ``None``.
        model: Student; callable mapping a batch to ``(pre_batch, gt_batch)``.
        criterion: Callable ``(pre_batch, gt_batch) -> (loss, metrics)``.
        optimizer: Optimizer exposing ``zero_grad()``, ``step()`` and
            ``param_groups`` (the first group's ``'lr'`` is logged).
        loss_bin: Mapping of name -> accumulator with ``loss_add`` and
            ``loss_mean``. Keys present in ``metrics`` (including
            ``'loss_distil'`` when distilling) receive that metric; every
            other key receives the total loss.
        args: Namespace; reads ``t_ratio`` (weight of the supervised loss
            when distilling) and ``sr_lr`` (``updateBN`` runs when not None).
        config: Nested config; reads ``['base']['algorithm']`` (``'DB'``,
            ``'SAST'``, anything else is handled as PAN/PSE),
            ``['base']['show_step']`` and ``['base']['n_epoch']``.
        epoch: Current epoch number, used only in the progress line.

    Returns:
        ``loss_mean()`` of every ``loss_bin`` entry (in the mapping's
        iteration order) followed by ``acc`` and ``iou`` as returned by the
        last metric helper call. Both are ``0.0`` when no batch was processed.
    """
    algorithm = config['base']['algorithm']

    running_metric_text = runningScore(2)
    # Only the PAN/PSE branch scores a separate kernel prediction.
    if algorithm not in ('DB', 'SAST'):
        running_metric_kernel = runningScore(2)

    # Defaults so the return value exists even if no batch is processed
    # (empty loader, or every batch is None).
    iou, acc = 0.0, 0.0

    for batch_idx, data in enumerate(train_data_loader):
        if data is None:
            continue
        pre_batch, gt_batch = model(data)

        if t_model is not None:
            # Teacher outputs serve as fixed targets; no autograd graph is
            # recorded for the teacher forward pass.
            with torch.no_grad():
                t_pre_batch, _ = t_model(data)
            distil_loss = t_criterion(pre_batch, t_pre_batch)

        loss, metrics = criterion(pre_batch, gt_batch)

        if t_model is not None:
            # Weighted blend: t_ratio on the supervised loss, the rest on
            # the distillation loss.
            loss = args.t_ratio * loss + (1 - args.t_ratio) * distil_loss
            metrics['loss_distil'] = distil_loss

        optimizer.zero_grad()
        loss.backward()
        if args.sr_lr is not None:
            # Runs after gradients exist and before the step; presumably an
            # extra BN-weight gradient term (sparsity) — confirm in updateBN.
            updateBN(model, args)
        optimizer.step()

        # Bins without a dedicated metric track the total loss.
        for key in loss_bin:
            value = metrics[key] if key in metrics else loss
            loss_bin[key].loss_add(value.item())

        if algorithm == 'DB':
            iou, acc = cal_DB(pre_batch['binary'], gt_batch['gt'],
                              gt_batch['mask'], running_metric_text)
        elif algorithm == 'SAST':
            iou, acc = cal_DB(pre_batch['f_score'], gt_batch['input_score'],
                              gt_batch['input_mask'], running_metric_text)
        else:
            iou, acc = cal_PAN_PSE(pre_batch['pre_kernel'],
                                   gt_batch['gt_kernel'],
                                   pre_batch['pre_text'], gt_batch['gt_text'],
                                   gt_batch['train_mask'], running_metric_text,
                                   running_metric_kernel)

        if batch_idx % config['base']['show_step'] == 0:
            fields = ['({}/{}/{}/{})'.format(epoch, config['base']['n_epoch'],
                                             batch_idx, len(train_data_loader))]
            fields.extend('{}:{:.4f}'.format(key, loss_bin[key].loss_mean())
                          for key in loss_bin)
            fields.append('ACC:{:.4f}'.format(acc))
            fields.append('IOU:{:.4f}'.format(iou))
            fields.append('lr:{:.8f}'.format(optimizer.param_groups[0]['lr']))
            print(' | '.join(fields))

    loss_write = [loss_bin[key].loss_mean() for key in loss_bin]
    loss_write.extend([acc, iou])
    return loss_write