def eval_define(self, weights_path):
    """Create ``self.evaluator`` around a ShelfNet restored from *weights_path*.

    A ``ShelfNet`` with ``self.n_classes`` outputs is instantiated, its
    parameters are loaded from the file at *weights_path*, and the network
    is moved to CUDA and switched to eval mode.  The resulting model is
    wrapped in an ``MscEval`` configured with a single scale (1.0), no
    flipping and no dataloader, and stored on ``self.evaluator``.

    Args:
        weights_path: Path handed to ``torch.load``; the loaded object is
            passed to ``load_state_dict``.
    """
    model = ShelfNet(n_classes=self.n_classes)

    # Restore the saved parameters before moving the model to the GPU.
    state_dict = torch.load(weights_path)
    model.load_state_dict(state_dict)
    model.cuda()
    model.eval()

    self.evaluator = MscEval(
        model,
        dataloader=None,
        scales=[1.0],
        flip=False,
    )
def evaluate(respth='./res', dspth='/data2/.encoding/data/cityscapes', checkpoint=None):
    """Restore a ShelfNet checkpoint and log its mIOU on the validation data.

    Args:
        respth: Directory searched for ``model_final.pth`` when *checkpoint*
            is ``None``.
        dspth: Dataset root handed to ``CityScapes(..., mode='val')``.
        checkpoint: Explicit path of the weights file to evaluate.  When
            ``None`` the file ``<respth>/model_final.pth`` is used instead.

    The mIOU is only written to the root logger; nothing is returned.
    """
    log = logging.getLogger()

    # ---- model ----
    log.info('\n')
    log.info('====' * 20)
    log.info('evaluating the model ...\n')
    log.info('setup and restore model')
    model = ShelfNet(n_classes=19)
    if checkpoint is not None:
        weights_file = checkpoint
    else:
        weights_file = osp.join(respth, 'model_final.pth')
    model.load_state_dict(torch.load(weights_file))
    model.cuda()
    model.eval()

    # ---- dataset ----
    val_set = CityScapes(dspth, mode='val')
    val_loader = DataLoader(
        val_set,
        batch_size=5,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

    # ---- evaluation (single scale, no flipping) ----
    log.info('compute the mIOU')
    scorer = MscEval(model, val_loader, scales=[1.0], flip=False)
    score = scorer.evaluate()
    log.info('mIOU is: {:.6f}'.format(score))
def train():
    """Continue training ShelfNet from ``./res/model_final_idd.pth``.

    Steps performed:

    1. Parse CLI arguments with ``parse_args`` and join an NCCL process
       group whose rank is ``args.local_rank`` and whose world size is the
       number of visible CUDA devices.
    2. Build the ``CityScapes`` training set (1024x1024 crops) behind a
       ``DistributedSampler``.
    3. Restore the network, then run ``max_iter`` optimisation steps on the
       sum of three ``OhemCELoss`` terms, one per network output.
    4. Every ``msg_iter`` iterations log the mean loss, learning rate, ETA
       and elapsed interval; every 1000 iterations save a checkpoint and
       (except at iteration 0) run ``evaluate`` on it.
    5. Finally move the model to the CPU and, on rank 0 only, save it to
       ``<respth>/model_final.pth``.

    Relies on the module-level names ``respth`` and ``logger``, which are
    used here without being defined locally.
    """
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:33241',
        world_size=torch.cuda.device_count(),
        rank=args.local_rank,
    )
    setup_logger(respth)

    ## dataset
    n_classes = 19
    n_img_per_gpu = 5
    n_workers = 10
    cropsize = [1024, 1024]
    ds = CityScapes('./data/cityscapes', cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(
        ds,
        batch_size=n_img_per_gpu,
        shuffle=False,  # ordering is delegated to the distributed sampler
        sampler=sampler,
        num_workers=n_workers,
        pin_memory=True,
        drop_last=True,
    )

    ## model
    ignore_idx = 255
    device = torch.device("cuda")
    net = ShelfNet(n_classes=n_classes)
    # Deserialize onto the CPU first: without map_location every process
    # would restore the tensors onto whichever device they were saved from,
    # instead of the device selected by set_device() above.
    net.load_state_dict(
        torch.load('./res/model_final_idd.pth', map_location='cpu'))
    net.to(device)
    net.train()
    # NOTE(review): the model is not wrapped in DistributedDataParallel, so
    # this function does not all-reduce gradients across processes. The
    # hasattr(net, 'module') check at the end keeps the final save working
    # if such a wrapper is reinstated.

    score_thres = 0.7
    # n_min is 1/16 of the pixels in one per-GPU batch; presumably the
    # minimum number of pixels OhemCELoss keeps -- confirm in OhemCELoss.
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(
        model=net,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power,
    )

    ## train loop
    msg_iter = 50
    recent_losses = []  # losses accumulated since the last log message
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            # Treat a short batch like an exhausted loader.
            if not im.size()[0] == n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            # Start a new epoch; set_epoch is told the new epoch number
            # before the loader is re-iterated.
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()

        recent_losses.append(loss.item())

        ## print training log message
        if (it + 1) % msg_iter == 0:
            mean_loss = sum(recent_losses) / len(recent_losses)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            # `it` is zero-based, so it + 1 iterations have completed. Using
            # that count avoids the off-by-one in the estimate and a
            # division by zero when logging on the very first iteration.
            done = it + 1
            eta = int((max_iter - done) * (glob_t_intv / done))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=done,
                max_it=max_iter,
                lr=lr,
                loss=mean_loss,
                time=t_intv,
                eta=eta,
            )
            logger.info(msg)
            recent_losses = []
            st = ed

        ## periodic checkpoint (every process writes the same path)
        if it % 1000 == 0:
            save_pth = osp.join(respth, 'shelfnet_model_it_%d.pth' % it)
            torch.save(net.state_dict(), save_pth)
            # Skip evaluating the checkpoint taken at iteration 0.
            if it > 0:
                evaluate(checkpoint=save_pth)

    ## dump the final model
    save_pth = osp.join(respth, 'model_final.pth')
    net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))