Esempio n. 1
0
def alternate_train(args, ctx, pretrained, epoch,
                    rpn_epoch, rpn_lr, rpn_lr_step,
                    rcnn_epoch, rcnn_lr, rcnn_lr_step):
    """Run the later stages of alternate RPN / Mask R-CNN training.

    The first stages (RPN1 trained from the ImageNet init, RPN1 proposal
    generation, RCNN1 training) are skipped in this variant. Their product,
    the ``<args.prefix>/rcnn1`` checkpoint at ``rcnn_epoch``, is loaded below
    and so presumably must already exist from an earlier run.

    Stages executed, in order:

    1. train RPN2 initialised from RCNN1 with ``train_shared=True``,
    2. generate RPN2 proposals for every '+'-separated set in ``args.image_set``,
    3. combine RPN2 with RCNN1 into ``rcnn2`` (epoch 0),
    4. train RCNN2 starting from that combined ``rcnn2`` checkpoint,
    5. combine RPN2 with RCNN2 into ``final`` (epoch 0).

    Args:
        args: parsed options; reads ``prefix``, ``network``, ``dataset``,
            ``image_set``, ``root_path``, ``dataset_path``, ``frequent``,
            ``kvstore``, ``work_load_list``, ``no_flip``, ``no_shuffle`` and
            ``resume``.
        ctx: indexable device context(s); proposal generation uses ``ctx[0]``.
        pretrained, epoch: not used by the stages run here (only the skipped
            first stage consumed them); kept for signature compatibility.
        rpn_epoch, rpn_lr, rpn_lr_step: forwarded to the RPN2 training call;
            ``rpn_epoch`` is also the epoch at which ``rpn2`` is reloaded.
        rcnn_epoch, rcnn_lr, rcnn_lr_step: forwarded to the RCNN2 training
            call; ``rcnn_epoch`` is also the epoch at which ``rcnn1`` /
            ``rcnn2`` are reloaded.
    """
    # set up logger (root logger at INFO)
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # basic config
    # NOTE(review): hard-coded to 8, presumably to resume an interrupted run.
    # It is passed to BOTH the RPN2 and the RCNN2 training calls — confirm that
    # the RCNN2 stage is really meant to start at this epoch too.
    begin_epoch = 8
    # presumably the lower IoU bound for background RoIs — confirm in config
    config.TRAIN.BG_THRESH_LO = 0.0

    # model paths
    model_path = args.prefix
    rpn2_prefix = model_path + '/rpn2'
    rcnn1_prefix = model_path + '/rcnn1'
    rcnn2_prefix = model_path + '/rcnn2'
    final_prefix = model_path + '/final'

    # Stages "TRAIN RPN WITH IMAGENET INIT", its "GENERATE RPN DETECTION"
    # pass, and "TRAIN RCNN WITH IMAGENET INIT AND RPN DETECTION" are
    # intentionally skipped here; the rcnn1 checkpoint they produce is reused.

    logger.info('########## TRAIN RPN WITH RCNN INIT')
    train_rpn(args.network, args.dataset, args.image_set, args.root_path, args.dataset_path,
              args.frequent, args.kvstore, args.work_load_list, args.no_flip, args.no_shuffle, args.resume,
              ctx, rcnn1_prefix, rcnn_epoch, rpn2_prefix, begin_epoch, rpn_epoch,
              train_shared=True, lr=rpn_lr, lr_step=rpn_lr_step)

    logger.info('########## GENERATE RPN DETECTION')
    # args.image_set may join several sets with '+'; proposals are generated
    # separately for each one on the first context only.
    for image_set in args.image_set.split('+'):
        test_rpn(args.network, args.dataset, image_set, args.root_path, args.dataset_path,
                 ctx[0], rpn2_prefix, rpn_epoch,
                 vis=False, shuffle=False, thresh=0)

    logger.info('########## COMBINE RPN2 WITH RCNN1')
    # The combined model is written as rcnn2 at epoch 0, which is exactly the
    # checkpoint the next training call starts from.
    combine_model(rpn2_prefix, rpn_epoch, rcnn1_prefix, rcnn_epoch, rcnn2_prefix, 0)

    logger.info('########## TRAIN RCNN WITH RPN INIT AND DETECTION')
    train_maskrcnn(args.network, args.dataset, args.image_set, args.root_path, args.dataset_path,
                   args.frequent, args.kvstore, args.work_load_list, args.no_flip, args.no_shuffle, args.resume,
                   ctx, rcnn2_prefix, 0, rcnn2_prefix, begin_epoch, rcnn_epoch,
                   train_shared=True, lr=rcnn_lr, lr_step=rcnn_lr_step, proposal='rpn', maskrcnn_stage='rcnn2')

    logger.info('########## COMBINE RPN2 WITH RCNN2')
    combine_model(rpn2_prefix, rpn_epoch, rcnn2_prefix, rcnn_epoch, final_prefix, 0)
def alternate_train(args, ctx, pretrained, epoch,
                    rpn_epoch, rpn_lr, rpn_lr_step,
                    rcnn_epoch, rcnn_lr, rcnn_lr_step):
    """Train RCNN2 and merge it with RPN2 into the final model.

    Only the last two steps of alternate training run here:

    1. ``train_maskrcnn`` starting from ``<args.prefix>/rcnn2`` at epoch 0 and
       writing back under the same ``rcnn2`` prefix (stage ``'rcnn2'``),
    2. ``combine_model`` merging ``rpn2`` (at ``rpn_epoch``) with ``rcnn2``
       (at ``rcnn_epoch``) into ``final`` at epoch 0.

    NOTE(review): this presumably relies on the ``rcnn2`` (epoch 0) and
    ``rpn2`` checkpoints produced by an earlier run — confirm.
    ``pretrained``, ``epoch``, ``rpn_lr`` and ``rpn_lr_step`` are accepted
    for signature compatibility but are not used.
    """
    # Root logger at INFO. To log to a per-run file instead, pass
    # filename="mask_rcnn_alternate_train_%d.log" % int(time.time())
    # to basicConfig.
    logging.basicConfig()
    log = logging.getLogger()
    log.setLevel(logging.INFO)

    start_epoch = 0
    config.TRAIN.BG_THRESH_LO = 0.0

    # All checkpoints live under args.prefix.
    prefix = args.prefix
    rcnn2_prefix = prefix + '/rcnn2'

    log.info('########## TRAIN RCNN WITH RPN INIT AND DETECTION')
    train_maskrcnn(
        args.network, args.dataset, args.image_set,
        args.root_path, args.dataset_path,
        args.frequent, args.kvstore, args.work_load_list,
        args.no_flip, args.no_shuffle, args.resume,
        ctx,
        rcnn2_prefix, 0,               # initialise from rcnn2 @ epoch 0
        rcnn2_prefix,                  # save under the same prefix
        start_epoch, rcnn_epoch,
        train_shared=True, lr=rcnn_lr, lr_step=rcnn_lr_step,
        proposal='rpn', maskrcnn_stage='rcnn2',
    )

    log.info('########## COMBINE RPN2 WITH RCNN2')
    rpn2_prefix = prefix + '/rpn2'
    final_prefix = prefix + '/final'
    combine_model(rpn2_prefix, rpn_epoch,
                  rcnn2_prefix, rcnn_epoch,
                  final_prefix, 0)