Example #1
    def _run_test_forward(dtype, device, average_frames, reduction):
        x, y, xs, ys, expected = CTCLossTest._create_test_data(
            dtype, device, average_frames, reduction
        )
        # Test function
        loss = torch_baidu_ctc.ctc_loss(
            x, y, xs, ys, average_frames=average_frames, reduction=reduction
        )
        np.testing.assert_array_almost_equal(loss.cpu(), expected.cpu())

        # Test module
        loss = torch_baidu_ctc.CTCLoss(
            average_frames=average_frames, reduction=reduction
        )(x, y, xs, ys)
        np.testing.assert_array_almost_equal(loss.cpu(), expected.cpu())
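The `_create_test_data` fixture is not shown on this page. Below is a minimal sketch of what it might return, using PyTorch's native `torch.nn.functional.ctc_loss` as the reference implementation; the shapes, label values, and blank index are assumptions, not the library's actual test data.

import torch
import torch.nn.functional as F

def _create_test_data(dtype, device, average_frames, reduction):
    # Activations of shape (T, N, C), flat targets, and per-sample lengths.
    x = torch.randn(10, 3, 6, dtype=dtype, device=device)
    xs = torch.tensor([10, 6, 9], dtype=torch.int)  # frames per sample
    ys = torch.tensor([5, 3, 4], dtype=torch.int)   # labels per sample
    y = torch.randint(1, 6, (int(ys.sum()),), dtype=torch.int)
    # Reference losses from the native implementation (blank index 0 assumed).
    expected = F.ctc_loss(F.log_softmax(x, dim=-1), y.long(), xs.long(),
                          ys.long(), blank=0, reduction="none")
    if average_frames:
        expected = expected / xs.to(expected)
    if reduction == "sum":
        expected = expected.sum()
    elif reduction == "mean":
        expected = expected.mean()
    return x, y, xs, ys, expected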
Example #2
    def forward(self, output, target, **kwargs):
        # type: (torch.Tensor, List[List[int]]) -> FloatScalar
        """
        Args:
            output: output of the network, of size seqLength x batchSize
                x outputDim (or any structure that transform_output can
                unpack into the activations and the per-sample sequence
                lengths).
            target: list with the target label sequence of each sample
                in the batch (batchSize lists of ints).
        """
        acts, act_lens = transform_output(output)

        assert act_lens[0] == acts.size(0), "Maximum length does not match"
        assert len(target) == acts.size(1), "Batch size does not match"

        valid_indices, err_indices = get_valids_and_errors(act_lens, target)
        if err_indices:
            if kwargs.get("batch_ids", None) is not None:
                assert isinstance(kwargs["batch_ids"], (list, tuple))
                err_indices = [kwargs["batch_ids"][i] for i in err_indices]
            _logger.warning(
                "The following samples in the batch were ignored for the loss "
                "computation: {}",
                err_indices,
            )

        if not valid_indices:
            _logger.warning("All samples in the batch were ignored!")
            return None

        # TODO(jpuigcerver): We need to change this because CTCPrepare.apply
        # will set requires_grad of *all* outputs to True if *any* of the
        # inputs requires_grad is True.
        acts, labels, act_lens, label_lens = CTCPrepare.apply(
            acts, target, act_lens, valid_indices if err_indices else None
        )

        # TODO(jpuigcerver): Remove the detach() once the previous TODO is
        # fixed.
        return ctc_loss(
            acts=acts,
            labels=labels.detach(),
            acts_lens=act_lens.detach(),
            labels_lens=label_lens.detach(),
            reduction=self._reduction,
            average_frames=self._average_frames,
        )
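`get_valids_and_errors` comes from elsewhere in the library. One plausible implementation, assuming it flags samples whose label sequence cannot fit into the available frames (CTC needs an extra frame between repeated labels):

def get_valids_and_errors(act_lens, labels):
    # A sample is feasible only if it has at least len(y) frames plus one
    # extra frame for every pair of adjacent repeated labels.
    def min_frames(y):
        return len(y) + sum(y[i] == y[i - 1] for i in range(1, len(y)))

    check = [act_lens[i] >= min_frames(labels[i]) for i in range(len(labels))]
    valid_indices = [i for i, ok in enumerate(check) if ok]
    err_indices = [i for i, ok in enumerate(check) if not ok]
    return valid_indices, err_indices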
Example #3
def process_boxes(images,
                  im_data,
                  iou_pred,
                  roi_pred,
                  angle_pred,
                  score_maps,
                  gt_idxs,
                  gtso,
                  lbso,
                  features,
                  net,
                  ctc_loss,
                  opts,
                  debug=False):

    ctc_loss_count = 0
    loss = torch.zeros(1, dtype=torch.float32).cuda()

    for bid in range(iou_pred.size(0)):

        gts = gtso[bid]
        lbs = lbso[bid]

        gt_proc = 0
        gt_good = 0

        gts_count = {}

        iou_pred_np = iou_pred[bid].data.cpu().numpy()
        iou_map = score_maps[bid]
        to_walk = iou_pred_np.squeeze(0) * iou_map * (iou_pred_np.squeeze(0) >
                                                      0.5)
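        # to_walk keeps only pixels with predicted confidence above 0.5 that
        # fall inside ground-truth text regions; these seed the sampling below.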

        roi_p_bid = roi_pred[bid].data.cpu().numpy()
        gt_idx = gt_idxs[bid]

        if debug:
            img = images[bid]
            img += 1
            img *= 128
            img = np.asarray(img, dtype=np.uint8)

        xy_text = np.argwhere(to_walk > 0)
        random.shuffle(xy_text)
        xy_text = xy_text[0:min(xy_text.shape[0], 100)]

        for i in range(0, xy_text.shape[0]):
            if opts.geo_type == 1:
                break
            pos = xy_text[i, :]

            gt_id = gt_idx[pos[0], pos[1]]

            if gt_id not in gts_count:
                gts_count[gt_id] = 0

            if gts_count[gt_id] > 2:
                continue

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]
            if gt_txt.startswith('##'):
                continue

            angle_sin = angle_pred[bid, 0, pos[0], pos[1]]
            angle_cos = angle_pred[bid, 1, pos[0], pos[1]]

            angle = math.atan2(angle_sin, angle_cos)

            angle_gt = (math.atan2(
                (gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2(
                    (gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            if math.fabs(angle_gt - angle) > math.pi / 16:
                continue

            offset = roi_p_bid[:, pos[0], pos[1]]
            posp = pos + 0.25
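            # Project the four regressed edge distances along the predicted
            # angle to recover the box edge midpoints in image coordinates;
            # the factor 4 undoes the stride of the prediction feature map.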
            pos_g = np.array([(posp[1] - offset[0] * math.sin(angle)) * 4,
                              (posp[0] - offset[0] * math.cos(angle)) * 4])
            pos_g2 = np.array([(posp[1] + offset[1] * math.sin(angle)) * 4,
                               (posp[0] + offset[1] * math.cos(angle)) * 4])

            pos_r = np.array([(posp[1] - offset[2] * math.cos(angle)) * 4,
                              (posp[0] - offset[2] * math.sin(angle)) * 4])
            pos_r2 = np.array([(posp[1] + offset[3] * math.cos(angle)) * 4,
                               (posp[0] + offset[3] * math.sin(angle)) * 4])

            center = (pos_g + pos_g2 + pos_r + pos_r2) / 2 - [
                4 * pos[1], 4 * pos[0]
            ]
            #center = (pos_g + pos_g2 + pos_r + pos_r2) / 4
            dw = pos_r - pos_r2
            dh = pos_g - pos_g2

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])

            dhgt = gt[1] - gt[0]

            h_gt = math.sqrt(dhgt[0] * dhgt[0] + dhgt[1] * dhgt[1])
            if h_gt < 10:
                continue

            rect = ((center[0], center[1]), (w, h), angle * 180 / math.pi)
            pts = cv2.boxPoints(rect)

            pred_bbox = cv2.boundingRect(pts)
            pred_bbox = [
                pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3]
            ]
            pred_bbox[2] += pred_bbox[0]
            pred_bbox[3] += pred_bbox[1]

            # Skip boxes outside the input image (x vs. width, y vs. height).
            if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
                continue

            gt_bbox = [
                gt[:, 0].min(), gt[:, 1].min(), gt[:, 0].max(), gt[:, 1].max()
            ]
            inter = intersect(pred_bbox, gt_bbox)

            uni = union(pred_bbox, gt_bbox)
            ratio = area(inter) / float(area(uni))

            if ratio < 0.90:
                continue

            hratio = min(h, h_gt) / max(h, h_gt)
            if hratio < 0.5:
                continue

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            target_gw = (int(w * scale) + target_h)
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            #show pooled image in image layer

            scalex = (w + h) / input_W
            scaley = h / input_H

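            # Build the 2x3 affine matrix for F.affine_grid: it maps the
            # rotated text box onto a horizontal crop of height norm_height,
            # using the normalized [-1, 1] coordinate convention.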
            th11 = scalex * math.cos(angle)
            th12 = -math.sin(angle) * scaley
            th13 = (2 * center[0] - input_W - 1) / (input_W - 1)

            th21 = math.sin(angle) * scalex
            th22 = scaley * math.cos(angle)
            th23 = (2 * center[1] - input_H - 1) / (input_H - 1)

            t = np.asarray([th11, th12, th13, th21, th22, th23],
                           dtype=np.float32)
            t = torch.from_numpy(t).cuda()

            #t = torch.stack((th11, th12, th13, th21, th22, th23), dim=1)
            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(
                theta, torch.Size((1, 3, int(target_h), int(target_gw))))

            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)

            if debug:
                x_c = x.data.cpu().numpy()[0]
                x_data_draw = x_c.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]

                cv2.circle(img, (int(center[0]), int(center[1])), 5,
                           (0, 255, 0))
                cv2.imshow('im_data', x_data_draw)

                draw_box_points(img, pts)
                draw_box_points(img, gt, color=(0, 0, 255))

                cv2.imshow('img', img)
                cv2.waitKey(100)

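            # Encode the transcription for CTC: pad it with the space label on
            # both sides and map each character through codec_rev (label 3
            # serves as the unknown-character fallback).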
            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)

            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]
            gt_labels.append(codec_rev[' '])

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            label_length = [len(gt_labels)]
            probs_sizes = autograd.Variable(
                torch.IntTensor([labels_pred.permute(2, 0, 1).size(0)] *
                                labels_pred.permute(2, 0, 1).size(1)))
            label_sizes = autograd.Variable(
                torch.from_numpy(np.array(label_length)).int())
            labels = autograd.Variable(
                torch.from_numpy(np.array(gt_labels)).int())

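            # warp-ctc consumes activations of shape (T, N, C) together with
            # per-sample frame counts and label lengths as Int tensors.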
            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                ctc_f = labels_pred.data.cpu().numpy()
                ctc_f = ctc_f.swapaxes(1, 2)

                labels = ctc_f.argmax(2)
                det_text, conf, dec_s, splits = print_seq_ext(
                    labels[0, :], codec)

                print('{0} \t {1}'.format(det_text, gt_txt))

            gts_count[gt_id] += 1

            if ctc_loss_count > 64 or debug:
                break

        for gt_id in range(0, len(gts)):

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]

            gt_txt_low = gt_txt.lower()
            if gt_txt.startswith('##'):
                continue

            # Skip boxes outside the input image (x vs. width, y vs. height).
            if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
                continue

            if gt.min() < 0:
                continue

            center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4
            dw = gt[2, :] - gt[1, :]
            dh = gt[1, :] - gt[0, :]

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(
                -2, 2)

            if h < 8:
                #print('too small h!')
                continue

            angle_gt = (math.atan2(
                (gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2(
                    (gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            target_gw = int(w * scale) + random.randint(0, int(target_h))
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            xc = center[0]
            yc = center[1]
            w2 = w
            h2 = h

            #show pooled image in image layer

            scalex = (w2 + random.randint(0, int(h2))) / input_W
            scaley = h2 / input_H

            th11 = scalex * math.cos(angle_gt)
            th12 = -math.sin(angle_gt) * scaley
            th13 = (2 * xc - input_W - 1) / (input_W - 1)

            th21 = math.sin(angle_gt) * scalex
            th22 = scaley * math.cos(angle_gt)
            th23 = (2 * yc - input_H - 1) / (input_H - 1)

            t = np.asarray([th11, th12, th13, th21, th22, th23],
                           dtype=np.float32)
            t = torch.from_numpy(t).cuda()
            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(
                theta, torch.Size((1, 3, int(target_h), int(target_gw))))
            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)

            #score_sampled = F.grid_sample(iou_pred[bid].unsqueeze(0), grid)

            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)
            gt_labels.append(codec_rev[' '])

            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            label_length = [len(gt_labels)]
            probs_sizes = torch.IntTensor(
                [labels_pred.permute(2, 0, 1).size(0)] *
                labels_pred.permute(2, 0, 1).size(1))
            label_sizes = torch.from_numpy(np.array(label_length)).int()
            labels = torch.from_numpy(np.array(gt_labels)).int()

            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                x_d = x.data.cpu().numpy()[0]
                x_data_draw = x_d.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]
                cv2.imshow('im_data_gt', x_data_draw)
                cv2.waitKey(100)

            gt_proc += 1

            ctc_f = labels_pred.data.cpu().numpy()
            ctc_f = ctc_f.swapaxes(1, 2)

            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, splits = print_seq_ext(
                labels[0, :], codec)
            if debug:
                print('{0} \t {1}'.format(det_text, gt_txt))
            if det_text.lower() == gt_txt.lower():
                gt_good += 1

            if ctc_loss_count > 128 or debug:
                break

    if ctc_loss_count > 0:
        loss /= ctc_loss_count

    return loss, gt_good, gt_proc
Example #4
def main(opts):

    model_name = 'OCT-E2E-MLT'
    net = OctMLT(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()

    ctc_loss_val = 0
    box_loss_val = 0
    good_all = 0
    gt_all = 0

    best_step = step_start
    best_loss = 1000000
    best_model = net.state_dict()
    best_optimizer = optimizer.state_dict()
    best_learning_rate = learning_rate
    max_patience = 3000
    early_stop = False

    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
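        # Each batch carries the images plus the detection supervision:
        # score maps, geometry maps, training masks and per-image GT
        # polygons/transcriptions.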
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        start = time.time()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = time.time()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:

            if step > 10000:  # extra augmentation step; in the early stages it just slows down training
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)

            probs_sizes = torch.IntTensor(
                [labels_pred.permute(2, 0, 1).size(0)] *
                labels_pred.permute(2, 0, 1).size(1))
            label_sizes = torch.from_numpy(np.array(label_length)).int()
            labels = torch.from_numpy(np.array(labels)).int()
            loss_ocr = ctc_loss(labels_pred.permute(2, 0,
                                                    1), labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except Exception:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            pass
        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except Exception:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            box_loss_val /= cnt

            if train_loss < best_loss:
                best_step = step
                best_model = net.state_dict()
                best_loss = train_loss
                best_learning_rate = learning_rate
                best_optimizer = optimizer.state_dict()
            if step - best_step > max_patience:
                print("Early stopping criterion reached.")
                save_name = os.path.join(
                    opts.save_path,
                    'BEST_{}_{}.h5'.format(model_name, best_step))
                state = {
                    'step': best_step,
                    'learning_rate': best_learning_rate,
                    'state_dict': best_model,
                    'optimizer': best_optimizer
                }
                torch.save(state, save_name)
                print('save model: {}'.format(save_name))
                opts.max_iters = step
                early_stop = True
            try:
                print(
                    'epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f'
                    % (step / batch_per_epoch, step, train_loss, bbox_loss,
                       seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), end - start))
                print('max_memory_allocated {}'.format(
                    torch.cuda.max_memory_allocated()))
            except Exception:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)
                pass

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        #if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'max_memory_allocated': torch.cuda.max_memory_allocated()
            }
            torch.save(state, save_name)
            print('save model: {}\tmax memory: {}'.format(
                save_name, torch.cuda.max_memory_allocated()))
    if not early_stop:
        save_name = os.path.join(opts.save_path, '{}.h5'.format(model_name))
        state = {
            'step': step,
            'learning_rate': learning_rate,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, save_name)
        print('save model: {}'.format(save_name))
Example #5
    def forward(self, output, target, **kwargs):
        # type: (torch.Tensor, List[List[int]]) -> FloatScalar
        """
        Args:
            output: output of the network, of size seqLength x batchSize
                x outputDim (or any structure that transform_output can
                unpack into the activations and the per-sample sequence
                lengths).
            target: list with the target label sequence of each sample
                in the batch (batchSize lists of ints).
        """
        acts, act_lens = transform_output(output)

        assert act_lens[0] == acts.size(0), "Maximum length does not match"
        assert len(target) == acts.size(1), "Batch size does not match"

        valid_indices, err_indices = get_valids_and_errors(act_lens, target)
        if err_indices:
            if kwargs.get("batch_ids", None) is not None:
                assert isinstance(kwargs["batch_ids"], (list, tuple))
                err_indices = [kwargs["batch_ids"][i] for i in err_indices]
            _logger.warning(
                "The following samples in the batch were ignored for the loss "
                "computation: {}",
                err_indices,
            )

        if not valid_indices:
            _logger.warning("All samples in the batch were ignored!")
            return None

        # TODO(jpuigcerver): We need to change this because CTCPrepare.apply
        # will set requires_grad of *all* outputs to True if *any* of the
        # inputs requires_grad is True.
        acts, labels, act_lens, label_lens = CTCPrepare.apply(
            acts, target, act_lens, valid_indices if err_indices else None
        )
        labels = labels.detach()
        act_lens = act_lens.detach()
        label_lens = label_lens.detach()

        if self._add_logsoftmax:
            acts = torch.nn.functional.log_softmax(acts, dim=-1)

        if self._implementation == CTCLossImpl.PYTORCH:
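            # Disable cuDNN around the call so the native CTC implementation
            # is used; the cuDNN kernel imposes extra restrictions on the
            # inputs.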
            torch.backends.cudnn.enabled = False
            losses = torch.nn.functional.ctc_loss(
                log_probs=acts,
                targets=labels.to(acts.device),
                input_lengths=act_lens,
                target_lengths=label_lens,
                blank=self._blank,
                reduction="none",
            )
            torch.backends.cudnn.enabled = True

            if self._average_frames:
                losses = losses / act_lens.to(losses)

            if self._reduction == "none":
                return losses
            elif self._reduction == "mean":
                return losses.mean()
            elif self._reduction == "sum":
                return losses.sum()
            else:
                raise ValueError(
                    "Reduction {!r} not supported!".format(self._reduction)
                )
        elif self._implementation == CTCLossImpl.BAIDU:
            return torch_baidu_ctc.ctc_loss(
                acts=acts,
                labels=labels,
                acts_lens=act_lens,
                labels_lens=label_lens,
                reduction=self._reduction,
                average_frames=self._average_frames,
            )
        else:
            raise ValueError(
                "Unknown CTC implementation: {!r}".format(self._implementation)
            )
Example #6
def train(
    model,
    epochs=150,
    batch_size=16,
    train_index_path="./train.index",
    dev_index_path="./dev.index",
    labels_path="./labels.json",
    learning_rate=0.1,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
):
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8
    )
    train_dataloader_shuffle = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8, shuffle=True
    )
    dev_dataloader = data.MASRDataLoader(
        dev_dataset, batch_size=batch_size, num_workers=8
    )
    parameters = model.parameters()
    # parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    # ctcloss = CTCLoss(size_average=True)
    # ctcloss = nn.CTCLoss(reduction='mean')
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter()
    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        lr = get_lr(optimizer)
        writer.add_scalar("lr/epoch", lr, epoch)
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to(device)
            out, out_lens = model(x, x_lens)
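            # Rearrange the output to (T, N, num_classes), the layout that
            # the CTC loss expects.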
            out = out.transpose(0, 1).transpose(0, 2)
            # loss = ctcloss(out, y, out_lens, y_lens)
            loss = ctc_loss(out, y, out_lens, y_lens, reduction="mean")
            # loss = ctcloss(nn.functional.log_softmax(out), y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar("loss/step", loss.item(), gstep)
            gstep += 1
            print(
                "[{}/{}][{}/{}]\tLoss = {}".format(
                    epoch + 1, epochs, i, int(batchs), loss.item()
                )
            )
        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        writer.add_scalar("cer/epoch", cer, epoch)
        print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
        torch.save(model, "{}/model_{}.pth".format(save_path, epoch))
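`get_lr` is not defined in this snippet; it presumably just reads the optimizer's current learning rate, along these lines:

def get_lr(optimizer):
    # Assumes a single parameter group, as created above.
    return optimizer.param_groups[0]["lr"]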
Example #7
import torch
from torch_baidu_ctc import ctc_loss

# Activations. Shape T x N x D (frames x batch size x output labels,
# including the CTC blank).
x = torch.rand(10, 3, 6)
# Targets of the whole batch, flattened into one 1-dimensional tensor.
# (The head of this snippet was truncated; the first two label sequences
# below are a plausible reconstruction consistent with the lengths in ys.)
y = torch.tensor(
    [
        # 1st sample
        1, 1, 2, 5, 2,
        # 2nd sample
        1, 5, 2,
        # 3rd sample
        4, 4, 2, 3,
    ],
    dtype=torch.int,
)
# Activations lengths
xs = torch.tensor([10, 6, 9], dtype=torch.int)
# Target lengths
ys = torch.tensor([5, 3, 4], dtype=torch.int)

# By default, the costs (negative log-likelihood) of all samples are summed.
# This is equivalent to:
#   ctc_loss(x, y, xs, ys, average_frames=False, reduction="sum")
loss1 = ctc_loss(x, y, xs, ys)

# You can also average the cost of each sample among the number of frames.
# The averaged costs are then summed.
loss2 = ctc_loss(x, y, xs, ys, average_frames=True)

# Instead of summing the costs of each sample, you can perform
# other `reductions`: "none", "sum", or "mean"
#
# Return an array with the loss of each individual sample
losses = ctc_loss(x, y, xs, ys, reduction="none")
#
# Compute the mean of the individual losses
loss3 = ctc_loss(x, y, xs, ys, reduction="mean")
#
# First normalize the loss by the number of frames, then average the losses
loss4 = ctc_loss(x, y, xs, ys, average_frames=True, reduction="mean")
Example #8
    sampler.shuffle(epoch)

    model.train()

    err = AverageMeter('loss')
    grd = AverageMeter('gradient')

    progress = tqdm(train)
    for xs, ys, xn, yn in progress:

        optimizer.zero_grad()

        xs, xn = model(xs.cuda(non_blocking=True), xn)
        xs = log_softmax(xs, dim=-1)

        loss = ctc_loss(xs, ys, xn, yn, average_frames=False, reduction="mean")
        loss.backward()

        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 100)

        optimizer.step()
        scheduler.step()

        err.update(loss.item())
        grd.update(grad_norm)

        lr = scheduler.get_lr()[0]

        progress.set_description('epoch %d %.6f %s %s' %
                                 (epoch + 1, lr, err, grd))
Example #9
def f_(x_):
    loss = torch_baidu_ctc.ctc_loss(
        x_, y, xs, ys, average_frames=average_frames, reduction=reduction
    )
    return torch.sum(loss / 2.0)
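A closure like this is typically handed to `torch.autograd.gradcheck` to verify the CTC backward pass numerically. A hedged sketch of such a call (the double-precision input is an assumption, and the binding must support float64 for the check to pass):

x = torch.randn(10, 3, 6, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(f_, (x,))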
Example #10
def main(opts):

    model_name = 'OctGatedMLT'
    net = OctMLT(attention=True)
    acc = []

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(
            opts.model,
            net,
            optimizer,
            load_ocr=opts.load_ocr,
            load_detection=opts.load_detection,
            load_shared=opts.load_shared,
            load_optimizer=opts.load_optimizer,
            reset_step=opts.load_reset_step)
    else:
        learning_rate = base_lr

    step_start = 0

    net.train()

    if opts.freeze_shared:
        net_utils.freeze_shared(net)

    if opts.freeze_ocr:
        net_utils.freeze_ocr(net)

    if opts.freeze_detection:
        net_utils.freeze_detection(net)

    #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
    #acc.append([0, acc_test])
    ctc_loss = CTCLoss()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True)

    train_loss = 0
    cnt = 0

    for step in range(step_start, 300000):
        # batch
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # backward
        '''
        acts: Tensor of (seqLength x batch x outputDim) containing the output from the network
        labels: 1-dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing the length of each output sequence
        label_lens: Tensor of size (batch) containing the label length of each example
        '''

        probs_sizes = torch.IntTensor(
            [labels_pred.permute(2, 0, 1).size(0)] *
            labels_pred.permute(2, 0, 1).size(1))
        label_sizes = torch.from_numpy(np.array(label_length)).int()
        labels = torch.from_numpy(np.array(labels)).int()
        loss = ctc_loss(labels_pred.permute(2, 0, 1), labels, probs_sizes,
                        label_sizes) / im_data.size(0)  # change 1.9.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_np = loss.data.cpu().numpy()
        if not np.isinf(loss_np).any():
            train_loss += loss_np[0] if loss_np.ndim > 0 else float(loss_np)
            cnt += 1

        if opts.debug:
            dbg = labels_pred.data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, splits = print_seq_ext(labels[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:

            train_loss /= cnt
            print('epoch %d[%d], loss: %.3f, lr: %.5f ' %
                  (step / batch_per_epoch, step, train_loss, learning_rate))

            train_loss = 0
            cnt = 0

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            #acc_test, ted = test(net, codec, opts,  list_file=opts.valid_list, norm_height=opts.norm_height)
            #acc.append([0, acc_test, ted])
            np.savez('train_acc_{0}'.format(model_name), acc=acc)