예제 #1
0
def train_CRN_Net(args):
    """Train a CRN model as configured by ``../config/<args.config>``.

    ``args`` is expected to be an argparse-style namespace. This function reads
    its ``config``, ``gpu``, ``resume``, ``model_dir``, ``pretrained_prefix``
    and ``pretrained_epoch`` attributes, and merges *all* of its attributes
    into the loaded config (same-named config entries are overridden).

    Training runs for epochs ``cfg.begin_epoch`` .. ``cfg.end_epoch - 1``,
    accumulating gradients over ``cfg.updateIter`` batches per optimizer step.
    Whenever ``epoch % cfg.frequence == 0`` the weights are saved to
    ``<model_dir>/<cfg.prefix>/<YYYY_mm_dd_HH_MM>/<cfg.network>-<epoch>.pth``.

    Raises:
        ValueError: if more than one GPU id is given in ``args.gpu``, or if
            ``cfg.backbone`` / ``cfg.optimizer`` is not a supported value.
    """
    cfg = Params(os.path.join("../config", args.config))

    # GPU ids come from a comma-separated string; only one is supported and it
    # is used as the CUDA device for the model.
    cfg.ctx = [int(i) for i in args.gpu.split(',')]
    if len(cfg.ctx) > 1:
        raise ValueError('Only for 1-GPU mode')

    # vars(cfg) is cfg's own __dict__, so this update also copies every
    # command-line argument onto cfg (CLI values win over config values).
    cfg_attr = vars(cfg)
    cfg_attr.update(vars(args))

    # Validate choices after the CLI merge (which may override them) but before
    # any data loading or directory creation, so a typo fails fast.
    backbone_depths = {'CRN_Res101': 101, 'CRN_Res50': 50}
    if cfg.backbone not in backbone_depths:
        raise ValueError('Unsupported backbone {!r}: expected one of {}'.format(
            cfg.backbone, sorted(backbone_depths)))
    if cfg.optimizer not in ('SGD', 'Adam'):
        raise ValueError(
            'Unsupported optimizer {!r}: SGD or Adam'.format(cfg.optimizer))

    # set up data loader
    train_iter = dataIter(cfg)

    # set up model
    model = CRN(mtype=backbone_depths[cfg.backbone], num_classes=1)

    if args.resume:
        # NOTE(review): snapshots written below are named
        # '<network>-<epoch>.pth' inside a timestamped directory, while this
        # loads '<pretrained_prefix>/<network>_<epoch>.pth' ('_' vs '-').
        # Confirm the intended layout before resuming from this script's output.
        weights_path = os.path.join(
            args.model_dir, args.pretrained_prefix,
            cfg.network + '_' + str(args.pretrained_epoch) + '.pth')
        model.load_state_dict(torch.load(weights_path))

    if cfg.use_global_stats == True:
        # The model stays in eval mode for the whole run: nothing below calls
        # model.train() again.
        model.eval()

    if cfg.ctx:
        model.cuda(cfg.ctx[0])

    # set up optimizer (cfg.optimizer was validated above)
    if cfg.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.learning_rate,
                              momentum=cfg.momentum,
                              weight_decay=cfg.wd)
    else:
        optimizer = optim.Adam(model.parameters(),
                               lr=cfg.learning_rate,
                               weight_decay=cfg.wd)
    optimizer.zero_grad()

    # set up model path: <model_dir>/<prefix>/<timestamp>
    model_path = os.path.join(args.model_dir, cfg.prefix)
    model_full_path = os.path.join(model_path,
                                   datetime.now().strftime('%Y_%m_%d_%H_%M'))
    if not os.path.isdir(model_full_path):
        # makedirs also creates model_dir / prefix when they are missing.
        os.makedirs(model_full_path)

    # set up log; raise the root level first so the header below is emitted.
    util.save_log(cfg.prefix, model_full_path)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info(
        '---------------------------TIME-------------------------------')
    logging.info('-------------------{}------------------------'.format(
        datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    for key in sorted(cfg_attr):
        logging.info("%s : %s", key, cfg_attr[key])

    # training phase
    for epoch in range(cfg.begin_epoch, cfg.end_epoch):
        train_iter.reset()
        total_loss = 0.0
        tic = time.time()
        for step in tqdm(range(train_iter.iter_cnt)):
            img, _msk_16, msk_32, gt0, gt1, gt2, gt3, gt4, gt5 = train_iter.next()

            out = model([img, msk_32])
            # One loss term per model output, each against its own target.
            loss = Loss_calc(out[0], gt0)
            for head, target in enumerate((gt1, gt2, gt3, gt4, gt5), start=1):
                loss = loss + Loss_calc(out[head], target)

            loss.backward()
            # Gradients accumulate over cfg.updateIter batches; step after the
            # last batch of each group (stepping at step == 0 would make the
            # first update of each epoch use a single batch).
            if (step + 1) % cfg.updateIter == 0:
                optimizer.step()
                optimizer.zero_grad()
            total_loss += loss.data.cpu().numpy() / train_iter.iter_cnt

        logger.info('Epoch[%d] Train-Loss=%.5f', epoch, total_loss)
        toc = time.time()
        logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        if epoch % cfg.frequence == 0:
            snapshot_path = os.path.join(
                model_full_path, cfg.network + '-' + str(epoch) + '.pth')
            print('taking snapshot: ' + snapshot_path)
            # Save CPU tensors so the checkpoint does not reference a CUDA
            # device, then move the model back for further training.
            torch.save(model.cpu().state_dict(), snapshot_path)
            if cfg.ctx:
                model.cuda(cfg.ctx[0])
예제 #2
0
def train_Tube_Net(args):
    """Fine-tune a CRN model on each test sequence and evaluate mask propagation.

    ``args`` is expected to be an argparse-style namespace. This function reads
    its ``config``, ``gpu``, ``resume``, ``model_dir`` and (when
    ``args.resume == 1``) ``trained_path`` attributes, and merges *all* of its
    attributes into the loaded config (same-named config entries are
    overridden).

    For every sequence returned by ``read_vos_test_list(cfg)``:

    1. A fresh model is built (initialised from ``args.trained_path`` when
       resuming). It is fine-tuned for ``cfg.finetune_epoch`` optimizer steps
       on batches from ``get_training_batch(cfg, f1_img[i], f1_gt[i])``
       (presumably the annotated first frame -- confirm against the loader).
    2. The fine-tuned weights are saved to
       ``<model_dir>/<cfg.prefix>/<YYYY_mm_dd_HH_MM>/<seq>_iter<finetune_epoch>.pth``.
    3. Frames are processed in order, starting from a running mask equal to
       the label of frame 0. For each frame:

       - the running mask is dilated (8x8), resized to 16x16 and fed to the
         model as guidance;
       - the model's first output is thresholded at 0.5;
       - that prediction, restricted to a 64x64 dilation of the running mask,
         becomes the next running mask.

       IoU and absolute error against each frame's label are accumulated.

    Per-sequence and average mIoU / MAE are logged.

    Raises:
        ValueError: if more than one GPU id is given in ``args.gpu``, or if
            ``cfg.backbone`` / ``cfg.optimizer`` is not a supported value.
    """
    cfg = Params(os.path.join("../config", args.config))

    # GPU ids come from a comma-separated string; only one is supported.
    cfg.ctx = [int(i) for i in args.gpu.split(',')]
    if len(cfg.ctx) > 1:
        raise ValueError('Only for 1-GPU mode')

    # vars(cfg) is cfg's own __dict__, so this update also copies every
    # command-line argument onto cfg (CLI values win over config values).
    cfg_attr = vars(cfg)
    cfg_attr.update(vars(args))

    # Validate choices once, after the CLI merge, instead of per sequence.
    backbone_depths = {'CRN_Res101': 101, 'CRN_Res50': 50}
    if cfg.backbone not in backbone_depths:
        raise ValueError('Unsupported backbone {!r}: expected one of {}'.format(
            cfg.backbone, sorted(backbone_depths)))
    if cfg.optimizer not in ('SGD', 'Adam'):
        raise ValueError(
            'Unsupported optimizer {!r}: SGD or Adam'.format(cfg.optimizer))

    # set up model path: <model_dir>/<prefix>/<timestamp>
    model_path = os.path.join(args.model_dir, cfg.prefix)
    model_full_path = os.path.join(
        model_path, datetime.now().strftime('%Y_%m_%d_%H_%M'))
    if not os.path.isdir(model_full_path):
        # makedirs also creates model_dir / prefix when they are missing.
        os.makedirs(model_full_path)

    # set up log; raise the root level first so the header below is emitted.
    util.save_log(cfg.prefix, model_full_path)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info(
        '---------------------------TIME-------------------------------')
    logging.info('-------------------{}------------------------'.format(
        datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    for key in sorted(cfg_attr):
        logging.info("%s : %s", key, cfg_attr[key])

    name_list, f1_img, f1_gt, img_list, label_list = read_vos_test_list(cfg)
    seq_num = len(name_list)

    # Loop-invariant values used during evaluation.
    guide_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (8, 8))
    gate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (64, 64))
    # MAE normaliser; the expression is kept verbatim so integer-division
    # behaviour under Python 2 is unchanged.
    # NOTE(review): confirm frame_num * (img_size / 8) ** 2 is the intended
    # per-frame pixel count.
    mae_norm = (cfg.frame_num * cfg.img_size / 8 * cfg.img_size / 8)

    miou_total = 0.0
    mae_total = 0.0

    for seq, f1_i, f1_g, imgs, labels in zip(name_list, f1_img, f1_gt,
                                              img_list, label_list):
        # --- fine-tune a fresh model on this sequence ---------------------
        model = CRN(mtype=backbone_depths[cfg.backbone], num_classes=1)

        if args.resume == 1:
            print('load model:' + args.trained_path)
            logging.info('load model:' + args.trained_path)
            model.load_state_dict(torch.load(args.trained_path))

        if cfg.use_global_stats == True:
            model.eval()  # use_global_stats = True

        if cfg.ctx:
            model.cuda(cfg.ctx[0])

        if cfg.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=cfg.learning_rate,
                                  momentum=cfg.momentum, weight_decay=cfg.wd)
        else:
            optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate,
                                   weight_decay=cfg.wd)
        optimizer.zero_grad()

        total_loss = 0.0
        tic = time.time()

        logger.info('Finetuning on the set- %s', seq)

        for _step in tqdm(range(cfg.finetune_epoch)):
            img, _msk_16, msk_32, gt0, gt1, gt2, gt3, gt4, gt5 = \
                get_training_batch(cfg, f1_i, f1_g)

            out = model([img, msk_32])
            # One loss term per model output, each against its own target.
            loss = Loss_calc(out[0], gt0)
            for head, target in enumerate((gt1, gt2, gt3, gt4, gt5), start=1):
                loss = loss + Loss_calc(out[head], target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.data.cpu().numpy() / cfg.finetune_epoch

        # Save CPU tensors, then move the model back for evaluation.
        model.eval()
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_full_path,
                                seq + '_iter' + str(cfg.finetune_epoch) + '.pth'))
        if cfg.ctx:
            model.cuda(cfg.ctx[0])

        toc = time.time()
        logger.info('Train-Loss=%.5f, Time cost=%.3f', total_loss, (toc - tic))

        # --- propagate the mask through the sequence ----------------------
        iou_sum = 0.0
        mae_sum = 0.0

        # Seed the running mask with the label of frame 0.
        _img0, lbl0, _mask0 = get_testing_batch(cfg, imgs[0], labels[0])
        running_mask = lbl0[0, 0].cpu().data.numpy()

        for img_, label_ in zip(imgs, labels):
            torch.cuda.empty_cache()
            img_test, lbl_test, _mask_test = get_testing_batch(cfg, img_, label_)

            # Guidance for this frame: dilated running mask at 16x16.
            running_mask = cv2.dilate(running_mask, guide_kernel)
            guide = cv2.resize(running_mask.astype(np.float64), (16, 16))
            guide_t = Variable(torch.from_numpy(guide).float().view(1, 16, 16))

            out = model([img_test, guide_t.cuda(cfg.ctx[0])])
            pred = (out[0][0, 0].cpu().data.numpy() > 0.5).astype(np.float64)

            # Keep only predicted pixels within a wide dilation of the running
            # mask; the result is the guidance source for the next frame.
            running_mask = cv2.dilate(running_mask, gate_kernel)
            running_mask = pred * running_mask

            lbl_np = lbl_test.cpu().data.numpy()
            iou_sum += compute_iou_for_binary_segmentation(running_mask > 0.5,
                                                           lbl_np > 0.5)
            mae_sum += (np.abs(running_mask - lbl_np)).sum()

        seq_iou = iou_sum / len(imgs)
        seq_mae = mae_sum / len(imgs) / mae_norm
        miou_total += seq_iou
        mae_total += seq_mae

        logger.info('Testing: mIoU-%.5f,  MAE-%.5f', seq_iou, seq_mae)
        logger.info('---------------------------------')

    if seq_num == 0:
        # Avoid dividing by zero when the test list is empty.
        logger.warning('No test sequences found; skipping total metrics.')
        return

    logger.info('####################################')
    logger.info('Total Testing: mIoU-%.5f,  MAE-%.5f',
                miou_total / seq_num, mae_total / seq_num)
    logger.info('####################################')