def _run_test_forward(dtype, device, average_frames, reduction):
    x, y, xs, ys, expected = CTCLossTest._create_test_data(
        dtype, device, average_frames, reduction
    )
    # Test function
    loss = torch_baidu_ctc.ctc_loss(
        x, y, xs, ys, average_frames=average_frames, reduction=reduction
    )
    np.testing.assert_array_almost_equal(loss.cpu(), expected.cpu())
    # Test module
    loss = torch_baidu_ctc.CTCLoss(
        average_frames=average_frames, reduction=reduction
    )(x, y, xs, ys)
    np.testing.assert_array_almost_equal(loss.cpu(), expected.cpu())
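# A standalone sketch of the ctc_loss call exercised by the test above, with
# illustrative shapes and values (the CTCLossTest._create_test_data fixture is
# not shown here). Activations are T x N x D; y concatenates the per-sample
# target sequences; xs and ys hold per-sample frame and target lengths.
import torch
import torch_baidu_ctc

T, N, D = 10, 3, 6                              # frames, batch size, alphabet size (0 = blank)
x = torch.randn(T, N, D)                        # unnormalized activations
y = torch.tensor([1, 2, 3, 2, 1, 4, 5, 2, 3], dtype=torch.int)  # concatenated targets
xs = torch.tensor([10, 6, 9], dtype=torch.int)  # frames per sample
ys = torch.tensor([5, 2, 2], dtype=torch.int)   # target length per sample (sums to len(y))
loss = torch_baidu_ctc.ctc_loss(x, y, xs, ys, average_frames=False, reduction="mean")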
def forward(self, output, target, **kwargs):
    # type: (torch.Tensor, List[List[int]]) -> (FloatScalar, List[int])
    """
    Args:
        output: tensor of size seqLength x batchSize x outputDim with the
            output activations of the network (or a packed sequence from
            which such a tensor and the per-sample lengths can be derived).
        target: list of size batchSize, where each element is the list of
            label indices of the corresponding sample.
    """
    acts, act_lens = transform_output(output)

    assert act_lens[0] == acts.size(0), "Maximum length does not match"
    assert len(target) == acts.size(1), "Batch size does not match"

    valid_indices, err_indices = get_valids_and_errors(act_lens, target)
    if err_indices:
        if kwargs.get("batch_ids", None) is not None:
            assert isinstance(kwargs["batch_ids"], (list, tuple))
            err_indices = [kwargs["batch_ids"][i] for i in err_indices]
        _logger.warning(
            "The following samples in the batch were ignored for the loss "
            "computation: {}",
            err_indices,
        )

    if not valid_indices:
        _logger.warning("All samples in the batch were ignored!")
        return None

    # TODO(jpuigcerver): We need to change this because CTCPrepare.apply
    # will set requires_grad of *all* outputs to True if *any* of the
    # inputs requires_grad is True.
    acts, labels, act_lens, label_lens = CTCPrepare.apply(
        acts, target, act_lens, valid_indices if err_indices else None
    )

    # TODO(jpuigcerver): Remove the detach() once the previous TODO is
    # fixed.
    return ctc_loss(
        acts=acts,
        labels=labels.detach(),
        acts_lens=act_lens.detach(),
        labels_lens=label_lens.detach(),
        reduction=self._reduction,
        average_frames=self._average_frames,
    )
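# get_valids_and_errors is not included in this excerpt. Given how it is used
# above, it presumably splits batch indices into CTC-feasible samples and
# samples whose labels cannot fit in the available frames: CTC needs at least
# len(y) + number-of-adjacent-repeats frames, since a blank must separate
# repeated labels. A sketch under that assumption (not the verbatim source):
def get_valids_and_errors(act_lens, labels):
    def min_frames(y):
        # One extra frame per repeated adjacent label (for the separating blank).
        return len(y) + sum(y[i] == y[i - 1] for i in range(1, len(y)))

    check = [act_lens[i] >= min_frames(labels[i]) for i in range(len(act_lens))]
    return (
        [i for i, ok in enumerate(check) if ok],
        [i for i, ok in enumerate(check) if not ok],
    )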
def process_boxes(images, im_data, iou_pred, roi_pred, angle_pred, score_maps,
                  gt_idxs, gtso, lbso, features, net, ctc_loss, opts, debug=False):
    ctc_loss_count = 0
    loss = torch.from_numpy(np.asarray([0])).type(torch.FloatTensor).cuda()

    for bid in range(iou_pred.size(0)):

        gts = gtso[bid]
        lbs = lbso[bid]

        gt_proc = 0
        gt_good = 0

        gts_count = {}

        iou_pred_np = iou_pred[bid].data.cpu().numpy()
        iou_map = score_maps[bid]
        to_walk = iou_pred_np.squeeze(0) * iou_map * (iou_pred_np.squeeze(0) > 0.5)

        roi_p_bid = roi_pred[bid].data.cpu().numpy()
        gt_idx = gt_idxs[bid]

        if debug:
            img = images[bid]
            img += 1
            img *= 128
            img = np.asarray(img, dtype=np.uint8)

        xy_text = np.argwhere(to_walk > 0)
        random.shuffle(xy_text)
        xy_text = xy_text[0:min(xy_text.shape[0], 100)]

        for i in range(0, xy_text.shape[0]):
            if opts.geo_type == 1:
                break
            pos = xy_text[i, :]

            gt_id = gt_idx[pos[0], pos[1]]

            if gt_id not in gts_count:
                gts_count[gt_id] = 0

            if gts_count[gt_id] > 2:
                continue

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]
            if gt_txt.startswith('##'):
                continue

            angle_sin = angle_pred[bid, 0, pos[0], pos[1]]
            angle_cos = angle_pred[bid, 1, pos[0], pos[1]]

            angle = math.atan2(angle_sin, angle_cos)
            angle_gt = (math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) +
                        math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            if math.fabs(angle_gt - angle) > math.pi / 16:
                continue

            offset = roi_p_bid[:, pos[0], pos[1]]
            posp = pos + 0.25
            pos_g = np.array([(posp[1] - offset[0] * math.sin(angle)) * 4,
                              (posp[0] - offset[0] * math.cos(angle)) * 4])
            pos_g2 = np.array([(posp[1] + offset[1] * math.sin(angle)) * 4,
                               (posp[0] + offset[1] * math.cos(angle)) * 4])
            pos_r = np.array([(posp[1] - offset[2] * math.cos(angle)) * 4,
                              (posp[0] - offset[2] * math.sin(angle)) * 4])
            pos_r2 = np.array([(posp[1] + offset[3] * math.cos(angle)) * 4,
                               (posp[0] + offset[3] * math.sin(angle)) * 4])

            center = (pos_g + pos_g2 + pos_r + pos_r2) / 2 - [4 * pos[1], 4 * pos[0]]
            #center = (pos_g + pos_g2 + pos_r + pos_r2) / 4

            dw = pos_r - pos_r2
            dh = pos_g - pos_g2

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])

            dhgt = gt[1] - gt[0]

            h_gt = math.sqrt(dhgt[0] * dhgt[0] + dhgt[1] * dhgt[1])
            if h_gt < 10:
                continue

            rect = ((center[0], center[1]), (w, h), angle * 180 / math.pi)
            pts = cv2.boxPoints(rect)

            pred_bbox = cv2.boundingRect(pts)
            pred_bbox = [pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3]]
            pred_bbox[2] += pred_bbox[0]
            pred_bbox[3] += pred_bbox[1]

            # keep boxes inside the image (x against width, y against height)
            if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
                continue

            gt_bbox = [gt[:, 0].min(), gt[:, 1].min(), gt[:, 0].max(), gt[:, 1].max()]
            inter = intersect(pred_bbox, gt_bbox)
            uni = union(pred_bbox, gt_bbox)
            ratio = area(inter) / float(area(uni))

            if ratio < 0.90:
                continue

            hratio = min(h, h_gt) / max(h, h_gt)
            if hratio < 0.5:
                continue

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            target_gw = int(w * scale) + target_h
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            #show pooled image in image layer
            scalex = (w + h) / input_W
            scaley = h / input_H

            th11 = scalex * math.cos(angle)
            th12 = -math.sin(angle) * scaley
            th13 = (2 * center[0] - input_W - 1) / (input_W - 1)  #* torch.cos(angle_var) - (2 * yc - input_H - 1) / (input_H - 1) * torch.sin(angle_var)

            th21 = math.sin(angle) * scalex
            th22 = scaley * math.cos(angle)
            th23 = (2 * center[1] - input_H - 1) / (input_H - 1)  #* torch.cos(angle_var) + (2 * xc - input_W - 1) / (input_W - 1) * torch.sin(angle_var)

            # np.float is a removed alias of the builtin float in NumPy >= 1.24
            t = np.asarray([th11, th12, th13, th21, th22, th23], dtype=np.float64)
            t = torch.from_numpy(t).type(torch.FloatTensor).cuda()
            #t = torch.stack((th11, th12, th13, th21, th22, th23), dim=1)

            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(theta, torch.Size((1, 3, int(target_h), int(target_gw))))
            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)

            if debug:
                x_c = x.data.cpu().numpy()[0]
                x_data_draw = x_c.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]

                cv2.circle(img, (int(center[0]), int(center[1])), 5, (0, 255, 0))
                cv2.imshow('im_data', x_data_draw)

                draw_box_points(img, pts)
                draw_box_points(img, gt, color=(0, 0, 255))

                cv2.imshow('img', img)
                cv2.waitKey(100)

            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)

            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]
            gt_labels.append(codec_rev[' '])

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            label_length = []
            label_length.append(len(gt_labels))
            probs_sizes = autograd.Variable(
                torch.IntTensor([(labels_pred.permute(2, 0, 1).size()[0])] *
                                (labels_pred.permute(2, 0, 1).size()[1])))
            label_sizes = autograd.Variable(
                torch.IntTensor(torch.from_numpy(np.array(label_length)).int()))
            labels = autograd.Variable(
                torch.IntTensor(torch.from_numpy(np.array(gt_labels)).int()))

            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                ctc_f = labels_pred.data.cpu().numpy()
                ctc_f = ctc_f.swapaxes(1, 2)
                labels = ctc_f.argmax(2)
                det_text, conf, dec_s, splits = print_seq_ext(labels[0, :], codec)
                print('{0} \t {1}'.format(det_text, gt_txt))

            gts_count[gt_id] += 1

            if ctc_loss_count > 64 or debug:
                break

        for gt_id in range(0, len(gts)):

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]
            gt_txt_low = gt_txt.lower()
            if gt_txt.startswith('##'):
                continue

            # keep boxes inside the image (x against width, y against height)
            if gt[:, 0].max() > im_data.size(3) or gt[:, 1].max() > im_data.size(2):
                continue

            if gt.min() < 0:
                continue

            center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4

            dw = gt[2, :] - gt[1, :]
            dh = gt[1, :] - gt[0, :]

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(-2, 2)

            if h < 8:
                #print('too small h!')
                continue

            angle_gt = (math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) +
                        math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            target_gw = int(w * scale) + random.randint(0, int(target_h))
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            xc = center[0]
            yc = center[1]
            w2 = w
            h2 = h

            #show pooled image in image layer
            scalex = (w2 + random.randint(0, int(h2))) / input_W
            scaley = h2 / input_H

            th11 = scalex * math.cos(angle_gt)
            th12 = -math.sin(angle_gt) * scaley
            th13 = (2 * xc - input_W - 1) / (input_W - 1)  #* torch.cos(angle_var) - (2 * yc - input_H - 1) / (input_H - 1) * torch.sin(angle_var)

            th21 = math.sin(angle_gt) * scalex
            th22 = scaley * math.cos(angle_gt)
            th23 = (2 * yc - input_H - 1) / (input_H - 1)  #* torch.cos(angle_var) + (2 * xc - input_W - 1) / (input_W - 1) * torch.sin(angle_var)

            t = np.asarray([th11, th12, th13, th21, th22, th23], dtype=np.float64)
            t = torch.from_numpy(t).type(torch.FloatTensor)
            t = t.cuda()

            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(theta, torch.Size((1, 3, int(target_h), int(target_gw))))
            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)
            #score_sampled = F.grid_sample(iou_pred[bid].unsqueeze(0), grid)

            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)
            gt_labels.append(codec_rev[' '])

            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            label_length = []
            label_length.append(len(gt_labels))
            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(gt_labels)).int())

            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                x_d = x.data.cpu().numpy()[0]
                x_data_draw = x_d.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]
                cv2.imshow('im_data_gt', x_data_draw)
                cv2.waitKey(100)

            gt_proc += 1
            if True:
                ctc_f = labels_pred.data.cpu().numpy()
                ctc_f = ctc_f.swapaxes(1, 2)
                labels = ctc_f.argmax(2)
                det_text, conf, dec_s, splits = print_seq_ext(labels[0, :], codec)
                if debug:
                    print('{0} \t {1}'.format(det_text, gt_txt))
                if det_text.lower() == gt_txt.lower():
                    gt_good += 1

            if ctc_loss_count > 128 or debug:
                break

    if ctc_loss_count > 0:
        loss /= ctc_loss_count

    return loss, gt_good, gt_proc
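# The IoU test in the first loop relies on intersect, union, and area helpers
# that are not part of this excerpt. A plausible sketch for axis-aligned
# [x1, y1, x2, y2] boxes follows (an assumption about the helpers, not their
# verbatim source):
def intersect(a, b):
    # Overlapping region of two [x1, y1, x2, y2] boxes.
    return [max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])]

def union(a, b):
    # Smallest box enclosing both inputs.
    return [min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3])]

def area(box):
    # Zero for degenerate boxes (e.g. an empty intersection).
    w, h = box[2] - box[0], box[3] - box[1]
    return 0.0 if w < 0 or h < 0 else float(w * h)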
def main(opts):

    model_name = 'OCT-E2E-MLT'
    net = OctMLT(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()

    ctc_loss_val = 0
    box_loss_val = 0
    good_all = 0
    gt_all = 0

    best_step = step_start
    best_loss = 1000000
    best_model = net.state_dict()
    best_optimizer = optimizer.state_dict()
    best_learning_rate = learning_rate
    max_patience = 3000
    early_stop = False

    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
        start = timeit.default_timer()  # wall-clock start; timeit.timeit() would benchmark an empty statement instead
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.default_timer()

        # backward
        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var,
                            angle_pred, angle_gt, roi_pred, geo_gt)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:
            if step > 10000:  # this is just an extra augmentation step; in the early stage it just slows down training
                ctcl, gt_b_good, gt_b_all = process_boxes(images, im_data, seg_pred[0],
                                                          roi_pred[0], angle_pred[0],
                                                          score_maps, gt_idxs, gtso, lbso,
                                                          features, net, ctc_loss, opts,
                                                          debug=opts.debug)
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso, is_cuda=opts.cuda).permute(0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)
            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
            loss_ocr = ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                probs_sizes, label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            pass

        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0], (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            box_loss_val /= cnt

            if train_loss < best_loss:
                best_step = step
                best_model = net.state_dict()
                best_loss = train_loss
                best_learning_rate = learning_rate
                best_optimizer = optimizer.state_dict()
            if step - best_step > max_patience:
                print("Early stopping criterion reached.")
                save_name = os.path.join(opts.save_path,
                                         'BEST_{}_{}.h5'.format(model_name, best_step))
                state = {
                    'step': best_step,
                    'learning_rate': best_learning_rate,
                    'state_dict': best_model,
                    'optimizer': best_optimizer
                }
                torch.save(state, save_name)
                print('save model: {}'.format(save_name))
                opts.max_iters = step
                early_stop = True

            try:
                print('epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, '
                      'ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f' % (
                          step / batch_per_epoch, step, train_loss, bbox_loss,
                          seg_loss, angle_loss, ctc_loss_val,
                          good_all / max(1, gt_all), end - start))
                print('max_memory_allocated {}'.format(torch.cuda.max_memory_allocated()))
            except:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)
                pass

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        #if step % valid_interval == 0:
        #    validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'max_memory_allocated': torch.cuda.max_memory_allocated()
            }
            torch.save(state, save_name)
            print('save model: {}\tmax memory: {}'.format(
                save_name, torch.cuda.max_memory_allocated()))

    if not early_stop:
        save_name = os.path.join(opts.save_path, '{}.h5'.format(model_name))
        state = {
            'step': step,
            'learning_rate': learning_rate,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, save_name)
        print('save model: {}'.format(save_name))
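# net_utils.np_to_variable is used throughout this script but not defined in
# the excerpt. A minimal sketch, assuming it simply wraps a numpy array as an
# (optionally CUDA) float tensor:
import numpy as np
import torch

def np_to_variable(x, is_cuda=True, dtype=torch.FloatTensor):
    # Hypothetical reimplementation: numpy array -> torch tensor, optionally on GPU.
    v = torch.from_numpy(np.ascontiguousarray(x)).type(dtype)
    if is_cuda:
        v = v.cuda()
    return v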
def forward(self, output, target, **kwargs):
    # type: (torch.Tensor, List[List[int]]) -> (FloatScalar, List[int])
    """
    Args:
        output: tensor of size seqLength x batchSize x outputDim with the
            output activations of the network (or a packed sequence from
            which such a tensor and the per-sample lengths can be derived).
        target: list of size batchSize, where each element is the list of
            label indices of the corresponding sample.
    """
    acts, act_lens = transform_output(output)

    assert act_lens[0] == acts.size(0), "Maximum length does not match"
    assert len(target) == acts.size(1), "Batch size does not match"

    valid_indices, err_indices = get_valids_and_errors(act_lens, target)
    if err_indices:
        if kwargs.get("batch_ids", None) is not None:
            assert isinstance(kwargs["batch_ids"], (list, tuple))
            err_indices = [kwargs["batch_ids"][i] for i in err_indices]
        _logger.warning(
            "The following samples in the batch were ignored for the loss "
            "computation: {}",
            err_indices,
        )

    if not valid_indices:
        _logger.warning("All samples in the batch were ignored!")
        return None

    # TODO(jpuigcerver): We need to change this because CTCPrepare.apply
    # will set requires_grad of *all* outputs to True if *any* of the
    # inputs requires_grad is True.
    acts, labels, act_lens, label_lens = CTCPrepare.apply(
        acts, target, act_lens, valid_indices if err_indices else None
    )
    labels = labels.detach()
    act_lens = act_lens.detach()
    label_lens = label_lens.detach()

    if self._add_logsoftmax:
        acts = torch.nn.functional.log_softmax(acts, dim=-1)

    if self._implementation == CTCLossImpl.PYTORCH:
        torch.backends.cudnn.enabled = False
        losses = torch.nn.functional.ctc_loss(
            log_probs=acts,
            targets=labels.to(acts.device),
            input_lengths=act_lens,
            target_lengths=label_lens,
            blank=self._blank,
            reduction="none",
        )
        torch.backends.cudnn.enabled = True
        if self._average_frames:
            losses = losses / act_lens.to(losses)
        if self._reduction == "none":
            return losses
        elif self._reduction == "mean":
            return losses.mean()
        elif self._reduction == "sum":
            return losses.sum()
        else:
            raise ValueError(
                "Reduction {!r} not supported!".format(self._reduction)
            )
    elif self._implementation == CTCLossImpl.BAIDU:
        return torch_baidu_ctc.ctc_loss(
            acts=acts,
            labels=labels,
            acts_lens=act_lens,
            labels_lens=label_lens,
            reduction=self._reduction,
            average_frames=self._average_frames,
        )
    else:
        raise ValueError(
            "Unknown CTC implementation: {!r}".format(self._implementation)
        )
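# CTCLossImpl is not defined in this excerpt. Given the two branches
# dispatched on above, it is presumably a small enum along these lines
# (the member values here are illustrative):
from enum import Enum

class CTCLossImpl(Enum):
    PYTORCH = "pytorch"  # torch.nn.functional.ctc_loss backend
    BAIDU = "baidu"      # torch_baidu_ctc (warp-ctc) backend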
def train(
    model,
    epochs=150,
    batch_size=16,
    train_index_path="./train.index",
    dev_index_path="./dev.index",
    labels_path="./labels.json",
    learning_rate=0.1,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
):
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8
    )
    train_dataloader_shuffle = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8, shuffle=True
    )
    dev_dataloader = data.MASRDataLoader(
        dev_dataset, batch_size=batch_size, num_workers=8
    )
    parameters = model.parameters()
    # parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    # ctcloss = CTCLoss(size_average=True)
    # ctcloss = nn.CTCLoss(reduction='mean')
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter()
    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        lr = get_lr(optimizer)
        writer.add_scalar("lr/epoch", lr, epoch)
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to(device)
            out, out_lens = model(x, x_lens)
            out = out.transpose(0, 1).transpose(0, 2)
            # loss = ctcloss(out, y, out_lens, y_lens)
            loss = ctc_loss(out, y, out_lens, y_lens, reduction="mean")
            # loss = ctcloss(nn.functional.log_softmax(out), y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar("loss/step", loss.item(), gstep)
            gstep += 1
            print(
                "[{}/{}][{}/{}]\tLoss = {}".format(
                    epoch + 1, epochs, i, int(batchs), loss.item()
                )
            )
        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        writer.add_scalar("cer/epoch", cer, epoch)
        print("Epoch {}: Loss = {}, CER = {}".format(epoch, epoch_loss, cer))
        torch.save(model, "{}/model_{}.pth".format(save_path, epoch))
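# get_lr is not defined in this excerpt; a minimal sketch, assuming it simply
# reads the learning rate of the optimizer's first parameter group:
def get_lr(optimizer):
    return optimizer.param_groups[0]["lr"]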
        4, 4, 2, 3,
    ],
    dtype=torch.int,
)
# Activations lengths
xs = torch.tensor([10, 6, 9], dtype=torch.int)
# Target lengths
ys = torch.tensor([5, 3, 4], dtype=torch.int)

# By default, the costs (negative log-likelihood) of all samples are summed.
# This is equivalent to:
# ctc_loss(x, y, xs, ys, average_frames=False, reduction="sum")
loss1 = ctc_loss(x, y, xs, ys)

# You can also average the cost of each sample among the number of frames.
# The averaged costs are then summed.
loss2 = ctc_loss(x, y, xs, ys, average_frames=True)

# Instead of summing the costs of each sample, you can perform
# other `reductions`: "none", "sum", or "mean"
#
# Return an array with the loss of each individual sample
losses = ctc_loss(x, y, xs, ys, reduction="none")
#
# Compute the mean of the individual losses
loss3 = ctc_loss(x, y, xs, ys, reduction="mean")
#
# First, normalize loss by number of frames, later average losses
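# The functional calls above also have a module counterpart, exercised by the
# tests earlier in this document; a minimal sketch of its use (x, y, xs, ys
# as defined above):
from torch_baidu_ctc import CTCLoss

ctc = CTCLoss(average_frames=False, reduction="mean")
loss = ctc(x, y, xs, ys)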
sampler.shuffle(epoch)
model.train()
err = AverageMeter('loss')
grd = AverageMeter('gradient')
progress = tqdm(train)
for xs, ys, xn, yn in progress:
    optimizer.zero_grad()
    xs, xn = model(xs.cuda(non_blocking=True), xn)
    xs = log_softmax(xs, dim=-1)
    loss = ctc_loss(xs, ys, xn, yn, average_frames=False, reduction="mean")
    loss.backward()
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 100)
    optimizer.step()
    scheduler.step()
    err.update(loss.item())
    grd.update(grad_norm)
    lr = scheduler.get_lr()[0]
    progress.set_description('epoch %d %.6f %s %s' % (epoch + 1, lr, err, grd))
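# AverageMeter is assumed to track a running mean and render it via __str__
# for the tqdm description above; a hypothetical minimal implementation:
class AverageMeter:
    def __init__(self, name):
        self.name, self.sum, self.count = name, 0.0, 0

    def update(self, value, n=1):
        # Accept floats or 0-dim tensors (e.g. the clipped gradient norm).
        self.sum += float(value) * n
        self.count += n

    def __str__(self):
        return '%s %.4f' % (self.name, self.sum / max(1, self.count))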
def f_(x_):
    loss = torch_baidu_ctc.ctc_loss(
        x_, y, xs, ys, average_frames=average_frames, reduction=reduction
    )
    return torch.sum(loss / 2.0)
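# f_ has the scalar signature torch.autograd.gradcheck expects, which is
# presumably how the surrounding test uses it. A usage sketch (y, xs, ys and
# the input x come from the enclosing test scope; if the Baidu backend only
# supports float32, looser tolerances like these are needed):
xg = x.detach().requires_grad_(True)
torch.autograd.gradcheck(f_, (xg,), eps=1e-2, atol=1e-3, rtol=1e-3)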
def main(opts):

    model_name = 'OctGatedMLT'
    net = OctMLT(attention=True)
    acc = []

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(
            opts.model,
            net,
            optimizer,
            load_ocr=opts.load_ocr,
            load_detection=opts.load_detection,
            load_shared=opts.load_shared,
            load_optimizer=opts.load_optimizer,
            reset_step=opts.load_reset_step)
    else:
        learning_rate = base_lr
        step_start = 0

    net.train()

    if opts.freeze_shared:
        net_utils.freeze_shared(net)

    if opts.freeze_ocr:
        net_utils.freeze_ocr(net)

    if opts.freeze_detection:
        net_utils.freeze_detection(net)

    #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
    #acc.append([0, acc_test])

    ctc_loss = CTCLoss()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True)

    train_loss = 0
    cnt = 0

    for step in range(step_start, 300000):
        # batch
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # backward
        '''
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of size (batch) containing label length of each example
        '''

        probs_sizes = torch.IntTensor(
            [(labels_pred.permute(2, 0, 1).size()[0])] *
            (labels_pred.permute(2, 0, 1).size()[1]))
        label_sizes = torch.IntTensor(
            torch.from_numpy(np.array(label_length)).int())
        labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
        loss = ctc_loss(labels_pred.permute(2, 0, 1), labels,
                        probs_sizes, label_sizes) / im_data.size(0)  # change 1.9.
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        if not np.isinf(loss.data.cpu().numpy()):
            train_loss += loss.data.cpu().numpy()[0] if isinstance(
                loss.data.cpu().numpy(), list) else loss.data.cpu().numpy()  #net.bbox_loss.data.cpu().numpy()[0]
            cnt += 1

        if opts.debug:
            dbg = labels_pred.data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, splits = print_seq_ext(labels[0, :], codec)
            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:
            train_loss /= cnt
            print('epoch %d[%d], loss: %.3f, lr: %.5f ' % (
                step / batch_per_epoch, step, train_loss, learning_rate))

            train_loss = 0
            cnt = 0

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            #acc_test, ted = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
            #acc.append([0, acc_test, ted])
            np.savez('train_acc_{0}'.format(model_name), acc=acc)
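# print_seq_ext, used above for debug decoding, is not included in this
# excerpt. Greedy CTC decoding, which it presumably performs, collapses runs
# of identical argmax labels and drops blanks; a minimal hypothetical sketch
# (codec is assumed to map a label index to its character):
def greedy_ctc_decode(frame_labels, codec, blank=0):
    # frame_labels: 1-D array of per-frame argmax label indices for one sample.
    chars = []
    prev = blank
    for k in frame_labels:
        if k != blank and k != prev:
            chars.append(codec[k])
        prev = k
    return ''.join(chars)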