# --- Faster R-CNN training: per-step update + end-of-epoch summary (fragment) ---
# NOTE(review): this is the interior of a larger loop; rpn_loss_*, RCNN_loss_*,
# rois_label, optimizer, args, lr, start, steps, fg_cnt/bg_cnt and the loss_*
# accumulators are all defined by surrounding code outside this view.

# Accumulate scalar copies of each loss component for the epoch statistics.
loss_rpn_cls += rpn_loss_cls.item()
loss_rpn_box += rpn_loss_box.item()
loss_rcnn_cls += RCNN_loss_cls.item()
loss_rcnn_box += RCNN_loss_bbox.item()
# Count foreground (label != 0) and background rois seen so far.
fg_cnt += torch.sum(rois_label.data != 0)
bg_cnt += torch.sum(rois_label.data == 0)
# Total loss = sum of averaged RPN and RCNN classification/regression losses.
loss = rpn_loss_cls.mean() + rpn_loss_box.mean() + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    # VGG16 is prone to exploding gradients here -- clip to 10.
    clip_gradient(fasterRCNN, 10.)
optimizer.step()
steps += 1
end = time.time()
# Convert running sums into per-step averages for reporting.
loss_rpn_cls /= steps
loss_rpn_box /= steps
loss_rcnn_cls /= steps
loss_rcnn_box /= steps
loss_temp /= steps
print("[session %d][epoch %2d] loss: %.4f, lr: %.2e" % (args.session, epoch, loss_temp, lr))
print("\t\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, end - start))
def train():
    """Train a UBR (Universal Bounding-box Regressor) model.

    Parses command-line args, builds datasets/model/optimizer, optionally
    resumes from a checkpoint, runs the epoch loop with periodic logging,
    validates after every epoch, decays the learning rate (patience-based
    with --auto_decay, otherwise on a fixed step), and saves checkpoints.
    No parameters; no return value (returns early on NaN loss).

    FIX(review): removed leftover debug code in the random-box branch
    (`plt.imshow(raw_img); plt.show(); continue`) which displayed every
    image and then skipped the optimizer step, so no training iteration
    ever completed in that branch. Also replaced deprecated `loss.data[0]`
    with `loss.item()` for PyTorch >= 0.4, matching the rest of the file.
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    # Fixed seeds for reproducibility (numpy / torch CPU / torch CUDA).
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dataset = TDetDataset(
        [args.dataset + '_train'], training=True,
        multi_scale=args.multiscale, rotation=args.rotation, pd=args.pd,
        warping=args.warping, prop_method=args.prop_method,
        prop_min_scale=args.prop_min_scale, prop_topk=args.prop_topk)
    val_dataset = TDetDataset([args.dataset + '_val'], training=False)
    # Held-out transfer-validation set from a different domain.
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr
    res_path = 'data/pretrained_model/resnet101_caffe.pth'
    vgg_path = 'data/pretrained_model/vgg16_caffe.pth'
    if args.net == 'UBR_VGG':
        UBR = UBR_VGG(vgg_path, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_RES':
        UBR = UBR_RES(res_path, 1, not args.fc)
    elif args.net == 'UBR_RES_FC2':
        UBR = UBR_RES_FC2(res_path, 1)
    elif args.net == 'UBR_RES_FC3':
        UBR = UBR_RES_FC3(res_path, 1)
    else:
        print("network is not defined")
        pdb.set_trace()
    UBR.create_architecture()

    # Per-parameter optimizer options: biases get doubled lr and no weight
    # decay (Caffe-style fine-tuning recipe); weights get optional decay.
    params = []
    for key, value in dict(UBR.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                params += [{'params': [value], 'lr': lr,
                            'weight_decay': 0 if args.no_wd else 0.0005}]
    optimizer = torch.optim.SGD(params, momentum=0.9)

    # State for patience-based lr decay (--auto_decay).
    patience = 0
    last_optima = 999

    if args.resume:
        load_name = os.path.join(
            output_dir,
            '{}_{}_{}.pth'.format(args.net, args.checksession, args.checkepoch))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_epoch = checkpoint['epoch']
        UBR.load_state_dict(checkpoint['model'])
        if not args.no_optim:
            # Older checkpoints may lack these keys, hence the guards.
            if 'patience' in checkpoint:
                patience = checkpoint['patience']
            if 'last_optima' in checkpoint:
                last_optima = checkpoint['last_optima']
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    # Append when resuming so earlier log history is preserved.
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    UBR.cuda()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)

    if not args.use_prop:
        random_box_generator = NaturalUniformBoxGenerator(
            args.iou_th, pos_th=args.alpha,
            scale_min=1 - args.beta, scale_max=1 + args.beta)

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        # setting to train mode
        UBR.train()
        loss_temp = 0
        effective_iteration = 0
        start = time.time()
        mean_boxes_per_iter = 0
        rand_perm = np.random.permutation(len(train_dataset))
        for step in range(1, len(train_dataset) + 1):
            index = rand_perm[step - 1]
            im_data, gt_boxes, box_labels, proposals, prop_scores, \
                image_level_label, im_scale, raw_img, im_id, _ = train_dataset[index]
            data_height = im_data.size(1)
            data_width = im_data.size(2)
            im_data = Variable(im_data.unsqueeze(0).cuda())
            num_gt_box = gt_boxes.size(0)
            UBR.zero_grad()

            # generate random box from given gt box
            # the shape of rois is (n, 5), the first column is not used
            # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
            num_per_base = 50
            if num_gt_box > 4:
                # Cap the total number of generated boxes at ~200.
                num_per_base = 200 // num_gt_box

            if args.use_prop:
                # Use precomputed proposals that overlap a gt box.
                proposals = sample_pos_prop(proposals, gt_boxes, args.iou_th)
                if proposals is None:
                    # log_file.write('@@@@ no box @@@@\n')
                    # print('@@@@@ no box @@@@@')
                    continue
                rois = torch.zeros((proposals.size(0), 5))
                rois[:, 1:] = proposals
            else:
                # Sample random jittered boxes around each gt box.
                rois = torch.zeros((num_per_base * num_gt_box, 5))
                cnt = 0
                for i in range(num_gt_box):
                    here = random_box_generator.get_rand_boxes(
                        gt_boxes[i, :], num_per_base, data_height, data_width)
                    if here is None:
                        continue
                    rois[cnt:cnt + here.size(0), :] = here
                    cnt += here.size(0)
                if cnt == 0:
                    log_file.write('@@@@ no box @@@@\n')
                    print('@@@@@ no box @@@@@')
                    continue
                rois = rois[:cnt, :]
                # FIX(review): leftover debug code removed here — it showed
                # raw_img with matplotlib and then `continue`d, skipping the
                # update for every sample in this branch.

            mean_boxes_per_iter += rois.size(0)
            rois = Variable(rois.cuda())
            gt_boxes = Variable(gt_boxes.cuda())
            bbox_pred, shared_feat = UBR(im_data, rois)
            loss, num_selected_rois, num_rois, refined_rois = criterion(
                rois[:, 1:5], bbox_pred, gt_boxes)
            if loss is None:
                # No roi matched any gt box: poison the running average so
                # the event is visible in the logs and back-prop a zero loss.
                loss_temp = 1000000
                loss = Variable(torch.zeros(1).cuda())
                print('zero mached')
            loss = loss.mean()
            loss_temp += loss.item()  # .item() replaces deprecated .data[0]

            # backward
            optimizer.zero_grad()
            loss.backward()
            clip_gradient([UBR], 10.0)
            optimizer.step()
            effective_iteration += 1

            if step % args.disp_interval == 0:
                end = time.time()
                loss_temp /= effective_iteration
                mean_boxes_per_iter /= effective_iteration
                print(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                log_file.write(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f\n"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                loss_temp = 0
                effective_iteration = 0
                mean_boxes_per_iter = 0
                start = time.time()
            if math.isnan(loss_temp):
                # Training has diverged — abort the whole run.
                print('@@@@@@@@@@@@@@nan@@@@@@@@@@@@@')
                log_file.write('@@@@@@@nan@@@@@@@@\n')
                return

        val_loss = validate(UBR, None if args.use_prop else random_box_generator,
                            criterion, val_dataset, args)
        tval_loss = validate(UBR, None if args.use_prop else random_box_generator,
                             criterion, tval_dataset, args)
        print('[net %s][session %d][epoch %2d] validation loss: %.4f'
              % (args.net, args.session, epoch, val_loss))
        log_file.write('[net %s][session %d][epoch %2d] validation loss: %.4f\n'
                       % (args.net, args.session, epoch, val_loss))
        print('[net %s][session %d][epoch %2d] transfer validation loss: %.4f'
              % (args.net, args.session, epoch, tval_loss))
        log_file.write('[net %s][session %d][epoch %2d] transfer validation loss: %.4f\n'
                       % (args.net, args.session, epoch, tval_loss))
        log_file.flush()

        if args.auto_decay:
            # Patience-based decay: count epochs without meaningful
            # improvement (>= 0.001) of the validation loss.
            if last_optima - val_loss < 0.001:
                patience += 1
            if last_optima > val_loss:
                last_optima = val_loss
            if patience >= args.decay_patience:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma
                patience = 0
        else:
            # Fixed-schedule decay every lr_decay_step epochs.
            if epoch % args.lr_decay_step == 0:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma

        if epoch % args.save_interval == 0 or lr < 0.000005:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session, epoch))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['epoch'] = epoch + 1  # resume starts at the next epoch
            checkpoint['model'] = UBR.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            checkpoint['patience'] = patience
            checkpoint['last_optima'] = last_optima
            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))
        if lr < 0.000005:
            # Learning rate collapsed below the floor — stop training.
            break
    log_file.close()
# --- one optimization step of a Faster R-CNN model (loop-body fragment) ---
# NOTE(review): loop header, `model`, `optimizer`, `im_data` etc. are defined
# outside this view; the fragment also ends mid-branch (dangling `else:`).
model.zero_grad()
# Forward pass; the model returns rois, predictions, the four loss
# components, and the per-roi class labels.
rois, cls_prob, bbox_pred, \
    rpn_loss_cls, rpn_loss_box, \
    RCNN_loss_cls, RCNN_loss_bbox, \
    rois_label = model(im_data, im_info, gt_boxes, num_boxes)
# Total loss = sum of averaged RPN and RCNN cls/box losses.
loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
    + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
# loss_temp += loss.data[0]
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    # VGG16 needs gradient clipping to stay stable.
    clip_gradient(model, 10.)
optimizer.step()
if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        # Average the accumulated loss over the display window.
        loss_temp /= args.disp_interval
    if args.mGPUs:
        # Multi-GPU: each loss is a per-GPU vector — average before logging.
        loss_rpn_cls = rpn_loss_cls.mean().item()
        loss_rpn_box = rpn_loss_box.mean().item()
        loss_rcnn_cls = RCNN_loss_cls.mean().item()
        loss_rcnn_box = RCNN_loss_bbox.mean().item()
        # Foreground = rois with nonzero label; background = the rest.
        fg_cnt = torch.sum(rois_label.data.ne(0))
        bg_cnt = rois_label.data.numel() - fg_cnt
    else:
        # NOTE(review): single-GPU branch continues outside this fragment.
# --- one training step with weighted RPN/RCNN/SBC losses (loop-body fragment) ---
# NOTE(review): `outputs`, `rpn_loss_*`, `cnn_model`, `optimizer`, `args`,
# `loss_temp`, `step` come from surrounding code outside this view.
# FIX(review): replaced deprecated pre-0.4 `.data[0]` scalar access with
# `.item()` throughout — `.data[0]` raises on 0-dim tensors in torch >= 0.4,
# and sibling loops in this file already use `.item()`.
RCNN_loss_cls, RCNN_loss_bbox, \
    sbc_loss_cls = outputs
# Weighted sum of the five loss components; weights come from the CLI args.
loss = args.rpn_loss_cls_weight * rpn_loss_cls.mean() + \
    args.rpn_loss_box_weight * rpn_loss_box.mean() + \
    args.rcnn_loss_cls_weight * RCNN_loss_cls.mean() + \
    args.rcnn_loss_bbox_weight * RCNN_loss_bbox.mean() + \
    args.sbc_loss_cls_weight * sbc_loss_cls.mean()
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    # VGG16 needs gradient clipping to stay stable.
    clip_gradient(cnn_model, 10.)
optimizer.step()
if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        # Average the accumulated loss over the display window.
        loss_temp /= args.disp_interval
    if args.mGPUs:
        # Multi-GPU: losses are per-GPU vectors — average before logging.
        loss_rpn_cls = rpn_loss_cls.mean().item()
        loss_rpn_box = rpn_loss_box.mean().item()
        loss_rcnn_cls = RCNN_loss_cls.mean().item()
        loss_rcnn_box = RCNN_loss_bbox.mean().item()
        loss_sbc_cls = sbc_loss_cls.mean().item()
    else:
        loss_rpn_cls = rpn_loss_cls.item()
# NOTE(review): this fragment begins mid-statement — the left-hand side of
# the assignment ending in `sorted_indices].data` lies outside this view,
# so the first line below is intentionally left as-is.
sorted_indices].data
# Clamp box coordinates to the image bounds; rois columns 1..4 are
# presumably [x1, y1, x2, y2] — x clamps to width, y clamps to height.
sorted_previous_rois[im_id][:, 1].clamp_(min=0, max=data_width - 1)
sorted_previous_rois[im_id][:, 2].clamp_(min=0, max=data_height - 1)
sorted_previous_rois[im_id][:, 3].clamp_(min=0, max=data_width - 1)
sorted_previous_rois[im_id][:, 4].clamp_(min=0, max=data_height - 1)
loss = loss.mean()
# NOTE(review): `.data[0]` is pre-0.4 PyTorch; other loops in this file use
# `.item()` — confirm the targeted torch version before changing.
loss_temp += loss.data[0]
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    # VGG16 needs gradient clipping to stay stable.
    clip_gradient(UBR, 10.)
optimizer.step()
if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        # Average the accumulated loss over the display window.
        loss_temp /= args.disp_interval
    print(
        "[session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %f"
        % (args.session, epoch, step, loss_temp, lr, end - start))
    loss_temp = 0
    start = time.time()
save_name = os.path.join(
    output_dir, 'ubr_{}_{}_{}.pth'.format(args.session, epoch, step))
# --- one optimization step of heirRCNN (loop-body fragment) ---
# NOTE(review): loop header, `heirRCNN`, `optimizer`, inputs etc. live
# outside this view; the fragment ends mid-branch (dangling `else:`).
heirRCNN.zero_grad()
# Forward pass: predictions, the four loss components, per-roi labels.
rois, cls_prob, bbox_pred, \
    rpn_loss_cls, rpn_loss_box, \
    RCNN_loss_cls, RCNN_loss_bbox, \
    rois_label = heirRCNN(im_data, im_info, gt_boxes, num_boxes)
# Total loss = sum of averaged RPN and RCNN cls/box losses.
loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
    + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    # VGG16 needs gradient clipping to stay stable.
    clip_gradient(heirRCNN, 10.)
optimizer.step()
if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        # NOTE(review): divides by disp_interval + 1 here while a similar
        # loop elsewhere in this file divides by disp_interval — confirm
        # which window length is intended.
        loss_temp /= (args.disp_interval + 1)
    if args.mGPUs:
        # Multi-GPU: losses are per-GPU vectors — average before logging.
        loss_rpn_cls = rpn_loss_cls.mean().item()
        loss_rpn_box = rpn_loss_box.mean().item()
        loss_rcnn_cls = RCNN_loss_cls.mean().item()
        loss_rcnn_box = RCNN_loss_bbox.mean().item()
        # Foreground = rois with nonzero label; background = the rest.
        fg_cnt = torch.sum(rois_label.data.ne(0))
        bg_cnt = rois_label.data.numel() - fg_cnt
    else:
        # NOTE(review): single-GPU branch continues outside this fragment.
def train():
    """Adversarially align target-domain UBR features with source features.

    Loads the same pretrained checkpoint into a frozen source model and a
    trainable target model, then alternates GAN-style updates: a box-feature
    discriminator D learns to separate source (real) from target (fake)
    features, while the target model learns to fool D. Validates every
    args.val_interval steps. No parameters; no return value.

    FIX(review): replaced deprecated pre-0.4 `.data[0]` scalar access with
    `.item()` (raises on 0-dim tensors in torch >= 0.4; other loops in this
    file already use `.item()`).
    """
    args = parse_args()
    print('Called with args:')
    print(args)

    # Fixed seeds for reproducibility (numpy / torch CPU / torch CUDA).
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDetDataset(['coco60_train'], training=False)
    target_train_dataset = TDetDataset(['voc07_trainval'], training=False)
    val_dataset = TDetDataset(['coco60_val'], training=False)
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr
    # Source and target models share the architecture (and, below, the
    # same pretrained weights); only the tanh-layer index differs per net.
    if args.net == 'UBR_TANH0':
        source_model = UBR_TANH(0, None, not args.fc, not args.not_freeze, args.no_dropout)
        target_model = UBR_TANH(0, None, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_TANH1':
        source_model = UBR_TANH(1, None, not args.fc, not args.not_freeze, args.no_dropout)
        target_model = UBR_TANH(1, None, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_TANH2':
        source_model = UBR_TANH(2, None, not args.fc, not args.not_freeze, args.no_dropout)
        target_model = UBR_TANH(2, None, not args.fc, not args.not_freeze, args.no_dropout)
    else:
        print("network is not defined")
        pdb.set_trace()
    D = BoxDiscriminator(args.dim)
    source_model.create_architecture()
    target_model.create_architecture()

    # Generator (target model) parameter groups: biases get doubled lr and
    # no weight decay, Caffe-style.
    paramsG = []
    for key, value in dict(target_model.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                paramsG += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                paramsG += [{'params': [value], 'lr': lr, 'weight_decay': 0.0005}]
    # Discriminator parameter groups, same recipe.
    paramsD = []
    for key, value in dict(D.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                paramsD += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                paramsD += [{'params': [value], 'lr': lr, 'weight_decay': 0.0005}]
    optimizerG = torch.optim.SGD(paramsG, momentum=0.9)
    optimizerD = torch.optim.SGD(paramsD, momentum=0.9)

    # Both models start from the same pretrained UBR checkpoint.
    load_name = args.pretrained_model
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    assert checkpoint['net'] == args.net
    source_model.load_state_dict(checkpoint['model'])
    target_model.load_state_dict(checkpoint['model'])
    print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    source_model.cuda()
    target_model.cuda()
    D.cuda()

    # setting to train mode; the source model stays frozen in eval mode.
    target_model.train()
    source_model.eval()
    D.train()

    lossG_temp = 0
    lossD_real_temp = 0
    lossD_fake_temp = 0
    lossD_temp = 0
    effective_iteration = 0
    start = time.time()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)
    random_box_generator = NaturalUniformBoxGenerator(args.iou_th)

    for step in range(1, args.max_iter + 1):
        # Draw one random image from each domain.
        src_idx = np.random.choice(len(source_train_dataset))
        tar_idx = np.random.choice(len(target_train_dataset))
        src_im_data, src_gt_boxes, _, _, src_im_scale, src_raw_img, src_im_id, _ = source_train_dataset[src_idx]
        tar_im_data, tar_gt_boxes, _, _, tar_im_scale, tar_raw_img, tar_im_id, _ = target_train_dataset[tar_idx]

        # generate random box from given gt box
        # the shape of rois is (n, 5), the first column is not used
        # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
        num_src_gt = src_gt_boxes.size(0)
        # NOTE(review): assumes at least one gt box per image; num_src_gt == 0
        # would raise ZeroDivisionError — confirm the datasets guarantee this.
        num_per_base = 60 // num_src_gt
        src_rois = torch.zeros((num_per_base * num_src_gt, 5))
        cnt = 0
        for i in range(num_src_gt):
            here = random_box_generator.get_rand_boxes(
                src_gt_boxes[i, :], num_per_base, src_im_data.size(1), src_im_data.size(2))
            if here is None:
                continue
            src_rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
        if cnt == 0:
            log_file.write('@@@@ no box @@@@\n')
            print('@@@@@ no box @@@@@')
            continue
        src_rois = src_rois[:cnt, :]
        src_rois = Variable(src_rois.cuda())

        num_tar_gt = tar_gt_boxes.size(0)
        num_per_base = 60 // num_tar_gt
        tar_rois = torch.zeros((num_per_base * num_tar_gt, 5))
        cnt = 0
        for i in range(num_tar_gt):
            here = random_box_generator.get_rand_boxes(
                tar_gt_boxes[i, :], num_per_base, tar_im_data.size(1), tar_im_data.size(2))
            if here is None:
                continue
            tar_rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
        if cnt == 0:
            log_file.write('@@@@ no box @@@@\n')
            print('@@@@@ no box @@@@@')
            continue
        tar_rois = tar_rois[:cnt, :]
        tar_rois = Variable(tar_rois.cuda())

        ##############################################################################################
        # train D with real
        optimizerD.zero_grad()
        src_im_data = Variable(src_im_data.unsqueeze(0).cuda())
        src_feat = source_model.get_tanh_feat(src_im_data, src_rois)
        if args.tanh:
            src_feat = F.tanh(src_feat)
        # detach(): D updates must not back-prop into the feature extractors.
        output_real = D(src_feat.detach())
        label_real = Variable(torch.ones(output_real.size()).cuda())
        loss_real = F.binary_cross_entropy_with_logits(output_real, label_real)
        loss_real.backward()

        # train D with fake
        tar_im_data = Variable(tar_im_data.unsqueeze(0).cuda())
        tar_feat = target_model.get_tanh_feat(tar_im_data, tar_rois)
        if args.tanh:
            tar_feat = F.tanh(tar_feat)
        output_fake = D(tar_feat.detach())
        label_fake = Variable(torch.zeros(output_fake.size()).cuda())
        loss_fake = F.binary_cross_entropy_with_logits(output_fake, label_fake)
        loss_fake.backward()

        lossD_real_temp += loss_real.item()
        lossD_fake_temp += loss_fake.item()
        lossD = loss_real + loss_fake
        clip_gradient([D], 10.0)
        optimizerD.step()

        #############################################################################################
        # train G: make D label the target features as real.
        optimizerG.zero_grad()
        output = D(tar_feat)
        label_real = Variable(torch.ones(output.size()).cuda())
        lossG = F.binary_cross_entropy_with_logits(output, label_real)
        lossG.backward()
        clip_gradient([target_model], 10.0)
        # Warm-up: let D train alone for the first 3000 steps.
        if step > 3000:
            optimizerG.step()

        ##############################################################################################
        effective_iteration += 1
        lossG_temp += lossG.item()
        lossD_temp += lossD.item()

        if step % args.disp_interval == 0:
            end = time.time()
            lossG_temp /= effective_iteration
            lossD_temp /= effective_iteration
            lossD_fake_temp /= effective_iteration
            lossD_real_temp /= effective_iteration
            print("[net %s][session %d][iter %4d] lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f"
                  % (args.net, args.session, step, lossG_temp, lossD_temp, lr, end - start))
            log_file.write("[net %s][session %d][iter %4d] lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f\n"
                           % (args.net, args.session, step, lossG_temp, lossD_temp, lr, end - start))
            #print('%f %f' % (lossD_real_temp, lossD_fake_temp))
            effective_iteration = 0
            lossG_temp = 0
            lossD_temp = 0
            lossD_real_temp = 0
            lossD_fake_temp = 0
            start = time.time()

        if step % args.val_interval == 0:
            val_loss = validate(target_model, random_box_generator, criterion, val_dataset)
            tval_loss = validate(target_model, random_box_generator, criterion, tval_dataset)
            print('[net %s][session %d][step %2d] validation loss: %.4f'
                  % (args.net, args.session, step, val_loss))
            log_file.write('[net %s][session %d][step %2d] validation loss: %.4f\n'
                           % (args.net, args.session, step, val_loss))
            print('[net %s][session %d][step %2d] transfer validation loss: %.4f'
                  % (args.net, args.session, step, tval_loss))
            log_file.write('[net %s][session %d][step %2d] transfer validation loss: %.4f\n'
                           % (args.net, args.session, step, tval_loss))
            log_file.flush()
    log_file.close()
# --- one training step of a frame-pair box regressor (loop-body fragment) ---
# NOTE(review): `cnn_model`, `optimizer`, frame tensors, `args`, `loss_temp`,
# `step`, `iters_per_epoch` come from surrounding code outside this view.
# FIX(review): replaced deprecated pre-0.4 `.data[0]` scalar access with
# `.item()` — `.data[0]` raises on 0-dim tensors in torch >= 0.4, and
# sibling loops in this file already use `.item()`.
outputs = cnn_model(frame_1_box, frame_2, frame_2_box, num_box_1)  # (bbox_pred, loss_bbox)
cnn_model.zero_grad()
bbox_pred, loss_bbox = outputs
loss = loss_bbox.mean()
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.clip_grad is not None:
    # Optional gradient clipping controlled by the CLI.
    clip_gradient(cnn_model, args.clip_grad)
optimizer.step()
if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        # Average the accumulated loss over the display window.
        loss_temp /= args.disp_interval
    if args.mGPUs:
        # Multi-GPU: loss_bbox is a per-GPU vector — average before logging.
        bbox_loss = loss_bbox.mean().item()
    else:
        bbox_loss = loss_bbox.item()
    print(
        "---------------- [session %d][epoch %2d][iter %4d/%4d] ------------------"
        % (args.session, epoch, step, iters_per_epoch))
# --- one SBC training step on a tracklet feature pair (loop-body fragment) ---
# NOTE(review): `tracklet_feature`, `tracklet_label`, `sbc_model`, `optimizer`,
# `loss`, `args` come from surrounding code outside this view.
# take the first feature in this tracklet as the initial hidden state
feature_1 = tracklet_feature[0]
feature_2 = tracklet_feature[1]
cls_score, cls_prob = sbc_model(feature_1, feature_2)
loss = loss + sbc_model.get_loss(
    cls_score, tracklet_label.long(), smooth=args.smooth_loss)
# cls_score = cls_score.detach()
loss = loss / args.batch_size  # normalize the accumulated loss per batch
# update the parameter
optimizer.zero_grad()
loss.backward()
if args.clip_grad is not None:
    #nn.utils.clip_grad_norm(st_rnn.parameters(), args.clip_grad)
    clip_gradient(sbc_model, args.clip_grad)
optimizer.step()
# Grab the classification layer's gradient for a NaN sanity check.
# FIX(review): the original tested `weight.grad.data is not None`, which
# raises AttributeError whenever no gradient exists (.grad is None) and is
# otherwise always true; test `.grad` itself instead.
if sbc_model.conv_cls.weight.grad is not None:
    conv_cls_grad = sbc_model.conv_cls.weight.grad.data
else:
    conv_cls_grad = sbc_model.conv_cls.weight.data.new([0])
# NaN != NaN, so comparing the tensor with itself flags NaN entries.
if (conv_cls_grad != conv_cls_grad).sum() > 0:
    raise RuntimeError('\n there is nan in the grad of one layer\n')
    # print('\n there is nan in the grad of one layer\n')
    # pdb.set_trace()
# A saturated probability of exactly 1 usually signals a numerical problem.
# FIX(review): .item() instead of deprecated .data[0].
if cls_prob.max().item() == 1:
    print('Find probability of 1, maybe there are some thing wrong!')
# bbox_pred.shape=(b,128,4)预测的每个roi的的回归值(针对他们的类别target的回归值) 如果是test:shape=(b,2000,21*4) # rpn_class_cls rpn网络的分类loss(9*w*h个anchor里面正样本和负样本(共k个)的分类loss(不含ignore的分类loss)) # rpn_loss_box=(2.36) 是一个值,代表一个batch里面各个图片上的回归损失求的平均 # RCNN_loss_cls 是128个roi的分类loss # RCNN_loss_bbox 是128个roi的回归loss ##rois_labels.shape=(b*128) 保存了送入rcnn网络的每张图片128个roi的类别标签 loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \ + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean() loss_temp += loss.item() # backward optimizer.zero_grad() loss.backward() if args.net == "vgg16": clip_gradient(fasterRCNN, 10.)#修剪网络各个权重的梯度 optimizer.step() if step % args.disp_interval == 0: end = time.time() if step > 0: loss_temp /= (args.disp_interval + 1) if args.mGPUs: loss_rpn_cls = rpn_loss_cls.mean().item() loss_rpn_box = rpn_loss_box.mean().item() loss_rcnn_cls = RCNN_loss_cls.mean().item() loss_rcnn_box = RCNN_loss_bbox.mean().item() fg_cnt = torch.sum(rois_label.data.ne(0)) bg_cnt = rois_label.data.numel() - fg_cnt else: