def alternate_train(args, ctx, pretrained, epoch, rpn_epoch, rpn_lr,
                    rpn_lr_step, rcnn_epoch, rcnn_lr, rcnn_lr_step):
    """Run the second round of alternating RPN / Mask R-CNN training.

    Steps performed, in order:

    1. ``train_rpn`` with ``train_shared=True``, passing ``<prefix>/rcnn1``
       and ``rcnn_epoch`` as the initial model and writing under
       ``<prefix>/rpn2``.
    2. ``test_rpn`` once per '+'-separated entry of ``args.image_set``,
       using ``<prefix>/rpn2`` at ``rpn_epoch`` on ``ctx[0]`` only.
    3. ``combine_model`` on rpn2 and rcnn1, targeting ``<prefix>/rcnn2``.
    4. ``train_maskrcnn`` with ``train_shared=True``, passing
       ``<prefix>/rcnn2`` both as the initial model and as the output prefix.
    5. ``combine_model`` on rpn2 and rcnn2, targeting ``<prefix>/final``.

    The ImageNet-initialised first round (rpn1 / rcnn1) is not run by this
    function, so ``pretrained`` and ``epoch`` are accepted but unused;
    presumably the ``<prefix>/rcnn1`` checkpoint must already exist.

    Side effects: configures the root logger at INFO level and sets
    ``config.TRAIN.BG_THRESH_LO`` to ``0.0``.

    NOTE(review): a second ``alternate_train`` is defined later in this
    file; if both live in the same module the later one replaces this one.

    Args:
        args: Parsed options; reads ``prefix``, ``network``, ``dataset``,
            ``image_set``, ``root_path``, ``dataset_path``, ``frequent``,
            ``kvstore``, ``work_load_list``, ``no_flip``, ``no_shuffle``
            and ``resume``.
        ctx: Device contexts; the full sequence goes to the training
            calls, ``ctx[0]`` to ``test_rpn``.
        pretrained: Unused.
        epoch: Unused.
        rpn_epoch: Epoch argument for the RPN training/testing/combining.
        rpn_lr: Learning rate forwarded to ``train_rpn``.
        rpn_lr_step: Learning-rate step spec forwarded to ``train_rpn``.
        rcnn_epoch: Epoch argument for the R-CNN training/combining.
        rcnn_lr: Learning rate forwarded to ``train_maskrcnn``.
        rcnn_lr_step: Learning-rate step spec forwarded to
            ``train_maskrcnn``.
    """
    # The root logger defaults to WARNING; raise it so the stage banners show.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Global training configuration shared by every stage.
    config.TRAIN.BG_THRESH_LO = 0.0
    # Hard-coded value passed in the begin-epoch position of both training
    # calls. NOTE(review): looks like a resume offset -- confirm.
    first_epoch = 8

    checkpoint_root = args.prefix

    def stage_prefix(stage):
        """Return the checkpoint prefix for *stage* under ``args.prefix``."""
        return checkpoint_root + '/' + stage

    logging.info('########## TRAIN RPN WITH RCNN INIT')
    train_rpn(
        args.network,
        args.dataset,
        args.image_set,
        args.root_path,
        args.dataset_path,
        args.frequent,
        args.kvstore,
        args.work_load_list,
        args.no_flip,
        args.no_shuffle,
        args.resume,
        ctx,
        stage_prefix('rcnn1'),
        rcnn_epoch,
        stage_prefix('rpn2'),
        first_epoch,
        rpn_epoch,
        train_shared=True,
        lr=rpn_lr,
        lr_step=rpn_lr_step,
    )

    logging.info('########## GENERATE RPN DETECTION')
    # args.image_set may name several sets joined by '+'; test each one.
    for subset in args.image_set.split('+'):
        test_rpn(
            args.network,
            args.dataset,
            subset,
            args.root_path,
            args.dataset_path,
            ctx[0],
            stage_prefix('rpn2'),
            rpn_epoch,
            vis=False,
            shuffle=False,
            thresh=0,
        )

    logger.info('########## COMBINE RPN2 WITH RCNN1')
    # The trailing 0 is presumably the epoch index of the combined
    # checkpoint; the training call below passes 0 alongside the same prefix.
    combine_model(
        stage_prefix('rpn2'),
        rpn_epoch,
        stage_prefix('rcnn1'),
        rcnn_epoch,
        stage_prefix('rcnn2'),
        0,
    )

    logger.info('########## TRAIN RCNN WITH RPN INIT AND DETECTION')
    train_maskrcnn(
        args.network,
        args.dataset,
        args.image_set,
        args.root_path,
        args.dataset_path,
        args.frequent,
        args.kvstore,
        args.work_load_list,
        args.no_flip,
        args.no_shuffle,
        args.resume,
        ctx,
        stage_prefix('rcnn2'),
        0,
        stage_prefix('rcnn2'),
        first_epoch,
        rcnn_epoch,
        train_shared=True,
        lr=rcnn_lr,
        lr_step=rcnn_lr_step,
        proposal='rpn',
        maskrcnn_stage='rcnn2',
    )

    logger.info('########## COMBINE RPN2 WITH RCNN2')
    combine_model(
        stage_prefix('rpn2'),
        rpn_epoch,
        stage_prefix('rcnn2'),
        rcnn_epoch,
        stage_prefix('final'),
        0,
    )
def alternate_train(args, ctx, pretrained, epoch, rpn_epoch, rpn_lr,
                    rpn_lr_step, rcnn_epoch, rcnn_lr, rcnn_lr_step):
    """Run only the final Mask R-CNN stage of alternating training.

    Calls ``train_maskrcnn`` with ``train_shared=True``, passing
    ``<prefix>/rcnn2`` (with ``0`` in the following position) as the initial
    model and the same ``<prefix>/rcnn2`` as the output prefix, then calls
    ``combine_model`` on ``<prefix>/rpn2`` and ``<prefix>/rcnn2`` targeting
    ``<prefix>/final``. Presumably the rpn2 and rcnn2 checkpoints must
    already exist; nothing here creates them.

    ``pretrained``, ``epoch``, ``rpn_lr`` and ``rpn_lr_step`` are accepted
    but unused.

    Side effects: configures the root logger at INFO level and sets
    ``config.TRAIN.BG_THRESH_LO`` to ``0.0``.

    NOTE(review): an earlier ``alternate_train`` is defined in this file;
    if both live in the same module this definition replaces it.

    Args:
        args: Parsed options; reads ``prefix``, ``network``, ``dataset``,
            ``image_set``, ``root_path``, ``dataset_path``, ``frequent``,
            ``kvstore``, ``work_load_list``, ``no_flip``, ``no_shuffle``
            and ``resume``.
        ctx: Device contexts forwarded to ``train_maskrcnn``.
        pretrained: Unused.
        epoch: Unused.
        rpn_epoch: Epoch argument paired with ``<prefix>/rpn2`` in
            ``combine_model``.
        rpn_lr: Unused.
        rpn_lr_step: Unused.
        rcnn_epoch: Epoch argument for ``train_maskrcnn`` and for
            ``<prefix>/rcnn2`` in ``combine_model``.
        rcnn_lr: Learning rate forwarded to ``train_maskrcnn``.
        rcnn_lr_step: Learning-rate step spec forwarded to
            ``train_maskrcnn``.
    """
    # Log to the default handler. To log to a file instead, pass
    # ``filename=`` to ``logging.basicConfig``.
    logging.basicConfig()
    root_log = logging.getLogger()
    root_log.setLevel(logging.INFO)

    # Global training configuration used by the stage below.
    config.TRAIN.BG_THRESH_LO = 0.0
    start_epoch = 0

    out_dir = args.prefix

    root_log.info('########## TRAIN RCNN WITH RPN INIT AND DETECTION')
    train_maskrcnn(
        args.network,
        args.dataset,
        args.image_set,
        args.root_path,
        args.dataset_path,
        args.frequent,
        args.kvstore,
        args.work_load_list,
        args.no_flip,
        args.no_shuffle,
        args.resume,
        ctx,
        out_dir + '/rcnn2',
        0,
        out_dir + '/rcnn2',
        start_epoch,
        rcnn_epoch,
        train_shared=True,
        lr=rcnn_lr,
        lr_step=rcnn_lr_step,
        proposal='rpn',
        maskrcnn_stage='rcnn2',
    )

    root_log.info('########## COMBINE RPN2 WITH RCNN2')
    combine_model(
        out_dir + '/rpn2',
        rpn_epoch,
        out_dir + '/rcnn2',
        rcnn_epoch,
        out_dir + '/final',
        0,
    )