def train_mscoco():
    """Train the RPN and the ROI classifier on the ``'train'`` image split.

    The function is self-contained (no parameters, returns ``None``) and
    works in the following stages:

    1. Build a ``config.Config`` with augmentation enabled, 32 ROIs per
       step and ``./model/mscoco_frcnn.hdf5`` as the weight file, load the
       training images with ``get_data('train')`` and pickle the config to
       ``cfg.config_save_file``.
    2. Define three Keras models that share one base network: the RPN,
       the classifier, and a combined model used only to save weights.
    3. Try to load pretrained weights from ``cfg.base_net_weights``; a
       failure is printed and training continues from the current weights.
    4. Run ``cfg.num_epochs`` epochs of ``len(train_imgs)`` iterations.
       Each iteration trains the RPN on one batch, turns its predictions
       into ROIs, samples ``cfg.num_rois`` of them and trains the
       classifier. At the end of an epoch the mean losses are logged and
       the weights are saved when the summed loss improved.

    Side effects: writes the config pickle, ``./log.txt`` (through
    ``Logger``) and the weight file at ``cfg.model_path``; prints progress
    to stdout.
    """
    # ===================== model configuration and loading =====================
    # Config for data augmentation.
    cfg = config.Config()
    cfg.use_horizontal_flips = True
    cfg.use_vertical_flips = True
    cfg.rot_90 = True
    cfg.num_rois = 32
    # Pretrained weights for the base network.
    cfg.base_net_weights = nn.get_weight_path()
    # Where the trained weights are saved.
    cfg.model_path = './model/mscoco_frcnn.hdf5'

    # Make sure the weight directory exists up front; otherwise the first
    # save_weights() call fails only after a whole epoch has been trained.
    model_dir = os.path.dirname(cfg.model_path)
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    # Load the training images.
    train_imgs, class_mapping = get_data('train')
    cfg.class_mapping = class_mapping
    num_classes = len(class_mapping)
    print('Num classes (including bg) = {}'.format(num_classes))

    # Persist the full config so the same settings can be reloaded later.
    with open(cfg.config_save_file, 'wb') as config_f:
        pickle.dump(cfg, config_f)
    print(
        'Config has been written to {}, and can be loaded when testing to ensure correct results'
        .format(cfg.config_save_file))

    # Shuffle the images in place.
    random.shuffle(train_imgs)
    print('Num train samples {}'.format(len(train_imgs)))

    data_gen_train = data_generators.get_anchor_gt(
        train_imgs, class_mapping, cfg, nn.get_img_output_length,
        K.image_dim_ordering(), mode='train')

    # ============================ model definition =============================
    # Channels-last input (the original comment states a TensorFlow backend).
    input_shape_img = (None, None, 3)
    img_input = Input(shape=input_shape_img)
    roi_input = Input(shape=(None, 4))

    # Define the base network (frozen: trainable=False).
    shared_layers = nn.nn_base(img_input, trainable=False)

    # Define the RPN, built on the base layers.
    num_anchors = len(cfg.anchor_box_scales) * len(cfg.anchor_box_ratios)
    rpn = nn.rpn(shared_layers, num_anchors)
    classifier = nn.classifier(shared_layers, roi_input, cfg.num_rois,
                               nb_classes=num_classes, trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)
    # This model holds both the RPN and the classifier; it is only used to
    # load/save weights for the two models together.
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)

    # ================== load pretrained weights into the models =================
    try:
        print('loading base model weights from {}'.format(
            cfg.base_net_weights))
        model_rpn.load_weights(cfg.base_net_weights, by_name=True)
        model_classifier.load_weights(cfg.base_net_weights, by_name=True)
    except Exception as e:
        # Best effort: training proceeds without the pretrained weights.
        print('基本模型加载ImageNet权值: ', e)
        print('Could not load pretrained model weights on ImageNet.')

    # ============================== optimisation ===============================
    # Create the optimizer objects first, then hand them to model.compile().
    optimizer = Adam(lr=1e-5)
    optimizer_classifier = Adam(lr=1e-5)
    model_rpn.compile(optimizer=optimizer,
                      loss=[
                          losses_fn.rpn_loss_cls(num_anchors),
                          losses_fn.rpn_loss_regr(num_anchors)
                      ])
    model_classifier.compile(
        optimizer=optimizer_classifier,
        loss=[
            losses_fn.class_loss_cls,
            losses_fn.class_loss_regr(num_classes - 1)
        ],
        metrics={'dense_class_{}'.format(num_classes): 'accuracy'})
    # Compiled only so that save_weights() can be called on it; this model
    # is never trained directly in this function.
    model_all.compile(optimizer='sgd', loss='mae')

    # ======================== training / output settings =======================
    epoch_length = len(train_imgs)
    num_epochs = int(cfg.num_epochs)
    iter_num = 0

    # Columns: rpn_cls, rpn_regr, detector_cls, detector_regr, detector_acc.
    losses = np.zeros((epoch_length, 5))
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
    best_loss = np.inf

    logger = Logger(os.path.join('.', 'log.txt'))

    print('Starting training')

    for epoch_num in range(num_epochs):
        progbar = generic_utils.Progbar(epoch_length)
        logger.write('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

        while True:
            try:
                if (len(rpn_accuracy_rpn_monitor) == epoch_length
                        and cfg.verbose):
                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_rpn_monitor)) / len(
                            rpn_accuracy_rpn_monitor)
                    rpn_accuracy_rpn_monitor = []
                    print(
                        'Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'
                        .format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print(
                            'RPN is not producing bounding boxes that overlap'
                            ' the ground truth boxes. Check RPN settings or keep training.'
                        )

                # Image batch, RPN targets (cls / regr) and the box metadata.
                X, Y, img_data = next(data_gen_train)

                # Train the RPN on this batch.
                loss_rpn = model_rpn.train_on_batch(X, Y)

                # Feed the regions proposed by the freshly updated RPN to
                # the ROI stage.
                P_rpn = model_rpn.predict_on_batch(X)
                result = roi_helpers.rpn_to_roi(P_rpn[0],
                                                P_rpn[1],
                                                cfg,
                                                K.image_dim_ordering(),
                                                use_regr=True,
                                                overlap_thresh=0.7,
                                                max_boxes=300)

                # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
                # Returns regions, class targets, regression targets and IoUs
                # (the IoUs are not used here).
                X2, Y1, Y2, _ = roi_helpers.calc_iou(
                    result, img_data, cfg, class_mapping)

                if X2 is None:
                    # No usable ROI for this image: record zero overlaps and
                    # move on to the next batch.
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                # A 1 in the last class column marks a negative sample, a 0
                # a positive one (presumably the last column is the
                # background class; confirm against roi_helpers.calc_iou).
                # np.where() with a single argument returns a 1-tuple for
                # this 1-D slice, so [0] is always valid.
                neg_samples = np.where(Y1[0, :, -1] == 1)[0]
                pos_samples = np.where(Y1[0, :, -1] == 0)[0]

                rpn_accuracy_rpn_monitor.append(len(pos_samples))
                rpn_accuracy_for_epoch.append(len(pos_samples))

                if cfg.num_rois > 1:
                    # Use at most half of the ROI budget for positives.
                    if len(pos_samples) < cfg.num_rois // 2:
                        selected_pos_samples = pos_samples.tolist()
                    else:
                        selected_pos_samples = np.random.choice(
                            pos_samples, cfg.num_rois // 2,
                            replace=False).tolist()

                    # Fill the remainder with negatives; fall back to
                    # sampling with replacement when there are too few.
                    num_neg = cfg.num_rois - len(selected_pos_samples)
                    try:
                        selected_neg_samples = np.random.choice(
                            neg_samples, num_neg, replace=False).tolist()
                    except ValueError:
                        # Raised by np.random.choice when the population is
                        # smaller than the requested sample. Narrowed from a
                        # bare except so Ctrl-C is no longer swallowed here.
                        selected_neg_samples = np.random.choice(
                            neg_samples, num_neg, replace=True).tolist()

                    sel_samples = selected_pos_samples + selected_neg_samples
                else:
                    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
                    if np.random.randint(0, 2):
                        sel_samples = random.choice(neg_samples)
                    else:
                        sel_samples = random.choice(pos_samples)

                # Train the classifier on the selected ROIs.
                loss_class = model_classifier.train_on_batch(
                    [X, X2[:, sel_samples, :]],
                    [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num, 0] = loss_rpn[1]
                losses[iter_num, 1] = loss_rpn[2]

                losses[iter_num, 2] = loss_class[1]
                losses[iter_num, 3] = loss_class[2]
                losses[iter_num, 4] = loss_class[3]

                iter_num += 1

                progbar.update(
                    iter_num,
                    [('rpn_cls', np.mean(losses[:iter_num, 0])),
                     ('rpn_regr', np.mean(losses[:iter_num, 1])),
                     ('detector_cls', np.mean(losses[:iter_num, 2])),
                     ('detector_regr', np.mean(losses[:iter_num, 3]))])

                if iter_num == epoch_length:
                    # End of epoch: summarise, log, and checkpoint.
                    loss_rpn_cls = np.mean(losses[:, 0])
                    loss_rpn_regr = np.mean(losses[:, 1])
                    loss_class_cls = np.mean(losses[:, 2])
                    loss_class_regr = np.mean(losses[:, 3])
                    class_acc = np.mean(losses[:, 4])

                    mean_overlapping_bboxes = float(
                        sum(rpn_accuracy_for_epoch)) / len(
                            rpn_accuracy_for_epoch)
                    rpn_accuracy_for_epoch = []

                    # NOTE(review): all summary lines are assumed to be
                    # gated by cfg.verbose; the flattened source does not
                    # show the original nesting.
                    if cfg.verbose:
                        logger.write(
                            'Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'
                            .format(mean_overlapping_bboxes))
                        logger.write(
                            'Classifier accuracy for bounding boxes from RPN: {}'
                            .format(class_acc))
                        logger.write(
                            'Loss RPN classifier: {}'.format(loss_rpn_cls))
                        logger.write(
                            'Loss RPN regression: {}'.format(loss_rpn_regr))
                        logger.write('Loss Detector classifier: {}'.format(
                            loss_class_cls))
                        logger.write('Loss Detector regression: {}'.format(
                            loss_class_regr))
                        logger.write('Elapsed time: {}'.format(time.time() -
                                                               start_time))

                    curr_loss = (loss_rpn_cls + loss_rpn_regr +
                                 loss_class_cls + loss_class_regr)
                    iter_num = 0
                    start_time = time.time()

                    if curr_loss < best_loss:
                        if cfg.verbose:
                            logger.write(
                                'Total loss decreased from {} to {}, saving weights'
                                .format(best_loss, curr_loss))
                        best_loss = curr_loss
                        model_all.save_weights(cfg.model_path)

                    break

            except Exception as e:
                # Best effort: report the failure, checkpoint the current
                # weights and keep training with the next batch.
                # NOTE(review): this overwrites the best-loss checkpoint
                # with the current weights - confirm that is intended.
                print('Exception: {}'.format(e))
                model_all.save_weights(cfg.model_path)
                continue

    print('Training complete, exiting.')
# 생각, 계산 최소화. [(유효한 길이 합) * VOCAB_SIZE] 주의할 점은 테스트뽑을때 배치별로 붙으므로 한줄을 읽으려면 stride를 BATCH_SIZE로 줘야한다. target = nn.utils.rnn.pack_padded_sequence( target, pred_length, batch_first=True) # 마찬가지로 패딩을 모두 없애고 한줄로 핀다. [유효한 길이 합] loss = criterion(preds.data, target.data) loss = loss + args.ATTENTION_COEF * ( (1.0 - coefs.sum(dim=1))**2).mean() # regularization train_loss = train_loss + loss loss.backward() decoder_optimizer.step() encoder_optimizer.step() # GAN과 마찬가지로 역순 train_losses.append(train_loss) if idx % 100 == 0: logger.write("batch : [{}/{}]\n".format( idx, len(train_loader))) logger.write("train_loss : {}\n".format( train_loss / (len(train_loader) * args.NUM_CAPTIONS))) logger.write("-" * 10 + "validation" + "-" * 10 + "\n") encoder.eval() decoder.eval() with torch.no_grad(): for idx, (img, caption_5, caption_lengths_5, _) in enumerate(validation_loader): origin_img = img for i in range(args.NUM_CAPTIONS): img = origin_img.to(device) caption = caption_5[:, i, :].to(device) caption_lengths = caption_lengths_5[:, :, i].to(device)