def extract_feature(self, data_loader):
    """Run the CNN + self-pooling models over ``data_loader`` and collect features.

    Every batch is passed through ``self.cnn_model`` and then
    ``self.att_model.selfpooling_model``; the two returned tensors are
    concatenated along dim 0 across batches.

    A batch smaller than ``data_loader.batch_size`` (other than the first one)
    is padded with samples from the first batch so the models always see a
    full batch; the outputs belonging to the padding are then discarded.

    Args:
        data_loader: yields ``(imgs, flows, _, _)`` tuples and exposes
            ``batch_size``.

    Returns:
        ``(allfeatures, allfeatures_raw)`` — the concatenated pooled outputs.
        Both stay ``0`` if ``data_loader`` yields nothing.
    """
    print_freq = 50
    self.cnn_model.eval()
    self.att_model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()

    allfeatures = 0
    allfeatures_raw = 0
    # Loop-invariant: resolve the target device once.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for i, (imgs, flows, _, _) in enumerate(data_loader):
        data_time.update(time.time() - end)
        imgs = to_torch(imgs).to(device)
        flows = to_torch(flows).to(device)

        with torch.no_grad():
            real_batchsize = imgs.size(0)
            padded = i > 0 and real_batchsize < data_loader.batch_size
            if padded:
                # Fill the short batch up to batch_size with samples from the
                # first batch, so the models always see a full batch.
                cat_batchsize = data_loader.batch_size - real_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)

            out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
            out_feat, out_raw = self.att_model.selfpooling_model(out_feat, out_raw)

            if padded:
                # Drop the outputs that belong to the padding samples.
                out_feat = out_feat[0:real_batchsize]
                # Previously sliced from out_feat, which overwrote the raw
                # features with pooled ones.
                out_raw = out_raw[0:real_batchsize]

            if i == 0:
                allfeatures = out_feat
                allfeatures_raw = out_raw
                # Keep the first batch as the padding source for short batches.
                preimgs = imgs
                preflows = flows
            else:
                allfeatures = torch.cat((allfeatures, out_feat), 0)
                allfeatures_raw = torch.cat((allfeatures_raw, out_raw), 0)

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  .format(i + 1, len(data_loader),
                          batch_time.val, batch_time.avg,
                          data_time.val, data_time.avg))

    return allfeatures, allfeatures_raw
class Train_classifier: def __init__(self, opt): self.opt = opt torch.manual_seed(opt.seed) print('=========user config==========') pprint(opt._state_dict()) print('============end===============') self.trainloader, self.valloader = load_dataloaders(self.opt) self.use_gpu = opt.use_gpu self.device = torch.device('cuda') self._init_model() self._init_criterion() self._init_optimizer() self.model_dir = os.path.join( self.opt.save_dir, str(self.opt.model_type) + "_" + str(self.opt.model_name)) Path(self.model_dir).mkdir(parents=True, exist_ok=True) if not os.path.exists( os.path.join(self.model_dir, 'reconstructed_images')): os.makedirs(os.path.join(self.model_dir, 'reconstructed_images')) if (self.opt.debug == False): self.experiment = wandb.init(project="cycle_consistent_vae") hyper_params = self.opt._state_dict() self.experiment.config.update(hyper_params) wandb.watch(self.encoder) wandb.watch(self.classifier_model) def _init_model(self): self.encoder, self.classifier_model = load_model(self.opt) print("LEARNING RATE: ", self.opt.base_learning_rate) self.X_1 = torch.FloatTensor(self.opt.batch_size, self.opt.num_channels, self.opt.image_size, self.opt.image_size) self.X_2 = torch.FloatTensor(self.opt.batch_size, self.opt.num_channels, self.opt.image_size, self.opt.image_size) self.X_3 = torch.FloatTensor(self.opt.batch_size, self.opt.num_channels, self.opt.image_size, self.opt.image_size) self.style_latent_space = torch.FloatTensor(self.opt.batch_size, self.opt.style_dim) if self.use_gpu: self.device = torch.device('cuda') self.encoder.cuda() self.classifier_model.cuda() self.X_1 = self.X_1.cuda() self.X_2 = self.X_2.cuda() self.X_3 = self.X_3.cuda() self.style_latent_space = self.style_latent_space.cuda() self.load_encoder() def _init_optimizer(self): """ optimizer and scheduler definition """ self.optimizer = optim.Adam(self.classifier_model.parameters(), lr=self.opt.base_learning_rate, betas=(self.opt.beta1, self.opt.beta2)) # divide the learning rate by a factor 
of 10 after 80 epochs self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=80, gamma=0.1) def _init_criterion(self): """ loss definitions """ self.cross_entropy_loss = nn.CrossEntropyLoss() self.cross_entropy_loss.cuda() def train(self): sys.stdout = Logger(osp.join(self.model_dir, 'log_train.txt')) print("TRAINING CLASSIFIER...") start_epoch = self.opt.start_epoch best_acc = 0.0 best_epoch = 0 for epoch in range(start_epoch, self.opt.total_epochs): print('') print( 'Epoch #' + str(epoch) + '..........................................................................' ) self.train_one_epoch(epoch) val_acc = self.evaluation(epoch) self.scheduler.step() if (val_acc > best_acc): best_acc = val_acc self.save_model(best_model=True) # break self.save_model(best_model=False) if (self.opt.debug == False): print("UPLOADING FINAL FILES ...") wandb.save(self.model_dir + "/*") # wandb.save(os.path.join(self.opt.save_dir, self.opt.model_name, # str(self.opt.model_type)+ "_classifier.pth")) # wandb.save(os.path.join(self.opt.save_dir, self.opt.model_name, # str(self.opt.model_type)+ "_best_classifier.pth")) def train_one_epoch(self, epoch): # self.encoder.eval() self.classifier_model.train() self.cross_entropy_losses = AverageMeter() self.accuracy = AverageMeter() correct = 0 total = 0 for batch_idx, data in enumerate(self.trainloader): image_batch_1, image_batch_2, labels = data labels = labels.cuda() # labels = torch.FloatTensor(labels).cuda() self.optimizer.zero_grad() self.X_1.copy_(image_batch_1) self.style_mu_1, self.style_logvar_1, self.class_latent_space_1 = self.encoder( Variable(self.X_1)) style_latent_space_1 = reparameterize(training=False, mu=self.style_mu_1, logvar=self.style_logvar_1) if (self.opt.model_type == "specified"): outputs = self.classifier_model(style_latent_space_1) else: outputs = self.classifier_model(self.class_latent_space_1) self.loss = self.cross_entropy_loss(outputs, labels) self.loss.backward() self.optimizer.step() _, predicted = 
outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() self.cross_entropy_losses.update(self.loss.item()) del (image_batch_1) del (image_batch_2) del (outputs) del (labels) # break self.accuracy.update(correct / total) self._print_values(epoch) def _print_values(self, epoch): if (self.opt.debug == False): self.experiment.log( {'Cross Entropy loss': self.cross_entropy_losses.mean}, step=epoch) self.experiment.log({'Train Accuracy': self.accuracy.mean}, step=epoch) print('Cross Entropy loss: ' + str(self.cross_entropy_losses.mean)) print('Train Accuracy: ' + str(self.accuracy.mean)) def evaluation(self, epoch): print("Evaluating Model ...") # self.encoder.eval() self.classifier_model.eval() self.val_accuracy = AverageMeter() correct = 0 total = 0 with torch.no_grad(): for batch_idx, data in enumerate(self.valloader): image_batch_1, image_batch_2, labels = data self.X_1.copy_(image_batch_1) labels = labels.cuda() self.style_mu_1, self.style_logvar_1, self.class_latent_space_1 = self.encoder( Variable(self.X_1)) style_latent_space_1 = reparameterize( training=False, mu=self.style_mu_1, logvar=self.style_logvar_1) if (self.opt.model_type == "specified"): outputs = self.classifier_model(style_latent_space_1) else: outputs = self.classifier_model(self.class_latent_space_1) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() # break self.val_accuracy.update(correct / total) if (self.opt.debug == False): self.experiment.log( {'Validation Accuracy': self.val_accuracy.mean}, step=epoch) print('Validation Accuracy: ' + str(self.val_accuracy.mean)) print("Correct: ", correct) print("Total: ", total) return self.val_accuracy.mean def visualization(self): self.load_classifier(True) self.classifier_model.eval() val_iter = iter(self.valloader) data = val_iter.next() with torch.no_grad(): image_batch_1, image_batch_2, labels = data self.X_1.copy_(image_batch_1) labels = labels.cuda() self.style_mu_1, 
self.style_logvar_1, self.class_latent_space_1 = self.encoder( Variable(self.X_1)) style_latent_space_1 = reparameterize(training=False, mu=self.style_mu_1, logvar=self.style_logvar_1) if (self.opt.model_type == "specified"): outputs = self.classifier_model(style_latent_space_1) else: outputs = self.classifier_model(self.class_latent_space_1) _, predicted = outputs.max(1) image_batch = (np.transpose(self.X_1.cpu().numpy(), (0, 2, 3, 1))) labels = labels.detach().cpu().numpy() predicted = predicted.detach().cpu().numpy() shape = [2, 8] fig = plt.figure(1) grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05) size = shape[0] * shape[1] for i in range(size): current_image = image_batch[i] current_image = current_image * 255 current_image = current_image.astype("uint8") current_image = cv2.UMat(current_image).get() if (labels[i] == predicted[i]): cv2.rectangle(current_image, (0, 0), (60, 60), (0, 255, 0), 2) else: cv2.rectangle(current_image, ((0), 0), (60, 60), (255, 0, 0), 2) grid[i].axis('off') grid[i].imshow( current_image) # The AxesGrid object work as a list of axes. 
print("SAVING IMAGES") path = os.path.join(self.model_dir, 'missclassification.png') plt.savefig(path) plt.clf() def save_model(self, best_model): print("SAVING MODEL ...") if (best_model): torch.save(self.classifier_model.state_dict(), os.path.join(self.model_dir, "best_classifier.pth")) else: torch.save(self.classifier_model.state_dict(), os.path.join(self.model_dir, "classifier.pth")) def load_encoder(self): print("[*] LOADING ENCODER: {}".format( os.path.join(self.opt.save_dir, "vae", "encoder.pth"))) self.encoder.load_state_dict( torch.load(os.path.join(self.opt.save_dir, "vae", "encoder.pth"))) self.encoder.cuda() def load_decoder(self): print("[*] LOADING DECODER: {}".format( os.path.join(self.opt.save_dir, "vae", "decoder.pth"))) self.decoder.load_state_dict( torch.load(os.path.join(self.opt.save_dir, "vae", "decoder.pth"))) self.decoder.cuda() def load_classifier(self, best_model): if (best_model): print("LOADING BEST MODEL") self.classifier_model.load_state_dict( torch.load(os.path.join(self.model_dir, "best_classifier.pth"))) else: self.classifier_model.load_state_dict( torch.load(os.path.join(self.model_dir, "_classifier.pth"))) self.classifier_model.cuda()
def train(self, epoch, data_loader, optimizer1, optimizer2, optimizer3, print_freq=1):
    """Train the image, diff and depth sub-models for one epoch.

    A single loss from ``self._forward`` is back-propagated and all three
    optimizers step on it. Progress is printed every ``print_freq`` batches.
    """
    for sub_model in (self.img_model, self.diff_model, self.depth_model):
        sub_model.train()
    optimizers = (optimizer1, optimizer2, optimizer3)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()

    tick = time.time()
    for step, batch in enumerate(data_loader):
        data_time.update(time.time() - tick)

        batch_inputs, batch_targets = self._parse_data(batch)
        loss, prec1 = self._forward(batch_inputs, batch_targets)
        batch_len = batch_targets.size(0)
        losses.update(loss.item(), batch_len)
        precisions.update(prec1, batch_len)

        # Clear every optimizer, back-propagate once, then step each in order.
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if (step + 1) % print_freq != 0:
            continue
        print('Epoch: [{}][{}/{}]\t'
              'Time {:.3f} ({:.3f})\t'
              'Data {:.3f} ({:.3f})\t'
              'Loss {:.3f} ({:.3f})\t'
              'Prec {:.2%} ({:.2%})\t'.format(
                  epoch, step + 1, len(data_loader),
                  batch_time.val, batch_time.avg,
                  data_time.val, data_time.avg,
                  losses.val, losses.avg,
                  precisions.val, precisions.avg))
def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False):
    """Train the full model on YouTube-VOS and DAVIS-2017, alternating loaders.

    Each outer ``while`` iteration runs one full pass over the YouTube-VOS
    loader and then one over the DAVIS loader. ``step`` is compared against
    ``cfg.TRAIN_TOTAL_STEPS`` only between full passes.

    Args:
        damage_initial_previous_frame_mask: if True, perturb the previous-frame
            masks (``label1s``) with ``damage_masks`` before the forward pass.
        lossfunc: ``'cross_entropy'`` or ``'bce'``; selects the criterion and
            how per-sequence labels are encoded.
        model_resume: if True, load ``save_step_60000.pth`` from
            ``self.save_res_dir`` and continue from step 60000.

    Raises:
        ValueError: if ``lossfunc`` is not a supported name.
    """
    # Validate up front. Previously an unknown name only printed a warning and
    # then crashed later on an unbound ``criterion``, after the datasets were built.
    if lossfunc not in ('bce', 'cross_entropy'):
        raise ValueError(
            'unsupported loss function {!r}. Please choose from [cross_entropy,bce]'.format(lossfunc))

    ###################
    self.model.train()
    running_loss = AverageMeter()
    optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM,
                          weight_decay=cfg.TRAIN_WEIGHT_DECAY)
    ###################
    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    composed_transforms_ytb = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale([0.5, 1, 1.25]),
                                                  tr.RandomCrop((800, 800)),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
    print('dataset processing...')
    train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT, transform=composed_transforms_ytb)
    trainloader_davis = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                   shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
    trainloader_ytb = DataLoader(ytb_train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                 shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
    # YouTube-VOS pass first, then DAVIS.
    trainloader = [trainloader_ytb, trainloader_davis]
    print('dataset processing finished.')

    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_60000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 60000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        for train_dataloader in trainloader:
            # sample['meta'] = {'seq_name': ..., 'frame_num': ..., 'obj_num': ...}
            for ii, sample in enumerate(train_dataloader):
                now_lr = self._adjust_lr(optimizer, step, max_itr)
                ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                img1s = sample['img1']
                img2s = sample['img2']
                ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                label1s = sample['label1']
                label2s = sample['label2']
                seq_names = sample['meta']['seq_name']
                obj_nums = sample['meta']['obj_num']
                bs, _, h, w = img2s.size()
                # Reference, previous and current frames stacked along the batch dim.
                inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                if damage_initial_previous_frame_mask:
                    try:
                        label1s = damage_masks(label1s)
                    except Exception:
                        # Best effort: keep the undamaged mask.
                        print('damage_error')
                ##########
                if self.use_gpu:
                    inputs = inputs.cuda()
                    ref_scribble_labels = ref_scribble_labels.cuda()
                    label1s = label1s.cuda()
                    label2s = label2s.cuda()
                ##########
                tmp_dic = self.model(inputs, ref_scribble_labels, label1s, seq_names=seq_names,
                                     gt_ids=obj_nums, k_nearest_neighbors=cfg.KNNS)

                label_and_obj_dic = {}
                label_dic = {}
                for i, seq_ in enumerate(seq_names):
                    label_and_obj_dic[seq_] = (label2s[i], obj_nums[i])
                for seq_ in tmp_dic.keys():
                    tmp_pred_logits = tmp_dic[seq_]
                    tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                                mode='bilinear', align_corners=True)
                    tmp_dic[seq_] = tmp_pred_logits
                    label_tmp, obj_num = label_and_obj_dic[seq_]
                    obj_ids = np.arange(1, obj_num + 1)
                    obj_ids = torch.from_numpy(obj_ids)
                    obj_ids = obj_ids.int()
                    # Follow the same device decision as the labels above.
                    if self.use_gpu:
                        obj_ids = obj_ids.cuda()
                    if lossfunc == 'bce':
                        # One binary mask per object id.
                        label_tmp = label_tmp.permute(1, 2, 0)
                        label = (label_tmp.float() == obj_ids.float())
                        label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                        label_dic[seq_] = label.float()
                    elif lossfunc == 'cross_entropy':
                        label_dic[seq_] = label_tmp.long()

                loss = criterion(tmp_dic, label_dic, step)
                loss = loss / bs
                ######################################
                # Debug dump on a diverging loss: save logits/labels and stop.
                if loss.item() > 10000:
                    print(tmp_dic)
                    for k, v in tmp_dic.items():
                        v = v.cpu()
                        v = v.detach().numpy()
                        np.save(k + '.npy', v)
                        l = label_dic[k]
                        l = l.cpu().detach().numpy()
                        np.save('lab' + k + '.npy', l)
                    exit()
                ##########################################
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss.update(loss.item(), bs)

                # NOTE: step % 1 is always 0, so this logs on every step.
                if step % 1 == 0:
                    print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                         running_loss.val, running_loss.avg))
                    show_ref_img = ref_imgs.cpu().numpy()[0]
                    show_img1 = img1s.cpu().numpy()[0]
                    show_img2 = img2s.cpu().numpy()[0]

                    # Undo the ImageNet mean/std normalization for display.
                    mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                    sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                    show_ref_img = show_ref_img * sigma + mean
                    show_img1 = show_img1 * sigma + mean
                    show_img2 = show_img2 * sigma + mean

                    show_gt = label2s.cpu()[0]
                    show_gt = show_gt.squeeze(0).numpy()
                    show_gtf = label2colormap(show_gt).transpose((2, 0, 1))
                    ##########
                    show_preds = tmp_dic[seq_names[0]].cpu()
                    show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    show_preds = show_preds.squeeze(0)
                    if lossfunc == 'bce':
                        show_preds = (torch.sigmoid(show_preds) > 0.5)
                        show_preds_s = torch.zeros((h, w))
                        for i in range(show_preds.size(0)):
                            show_preds_s[show_preds[i]] = i + 1
                    elif lossfunc == 'cross_entropy':
                        show_preds_s = torch.argmax(show_preds, dim=0)
                    show_preds_s = show_preds_s.numpy()
                    show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                    pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                    tblogger.add_scalar('loss', running_loss.avg, step)
                    tblogger.add_scalar('pix_acc', pix_acc, step)
                    tblogger.add_scalar('now_lr', now_lr, step)
                    tblogger.add_image('Reference image', show_ref_img, step)
                    tblogger.add_image('Previous frame image', show_img1, step)
                    tblogger.add_image('Current frame image', show_img2, step)
                    tblogger.add_image('Groud Truth', show_gtf, step)
                    tblogger.add_image('Predict label', show_preds_sf, step)

                ###########TODO
                if step % 5000 == 0 and step != 0:
                    self.save_network(self.model, step)

                step += 1
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one pass over ``data_loader``, training or evaluating ``model``.

    Args:
        data_loader: iterable of ``(inputs, target)`` batches.
        model: network to run. It is wrapped in ``torch.nn.DataParallel`` when
            ``args.gpus`` lists more than one device.
        criterion: loss called as ``criterion(output, target)``.
        epoch: current epoch, used for logging and ``optimizer.update``.
        training: if True, back-propagate and step ``optimizer``. If False,
            run with autograd disabled.
        optimizer: object exposing ``update(epoch, step)``, ``zero_grad()``
            and ``step()``. It is only used when ``training`` is True.

    Returns:
        ``(average loss, average top-1 precision, average top-5 precision)``.
    """
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    # Replaces the removed Variable(volatile=...) flag: with training=False,
    # no autograd graph is built, which saves memory during evaluation.
    with torch.set_grad_enabled(training):
        for i, (inputs, target) in enumerate(data_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if args.gpus is not None:
                # `async=` is a SyntaxError on Python 3.7+; non_blocking is its replacement.
                target = target.cuda(non_blocking=True)
            input_var = inputs.type(args.type)
            target_var = target

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)
            if isinstance(output, list):
                output = output[0]

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
            # loss is 0-dim in modern PyTorch; .item() replaces loss.data[0].
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

            if training:
                optimizer.update(epoch, epoch * len(data_loader) + i)
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                             'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                             'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                             'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                             'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                             'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                                 epoch, i, len(data_loader),
                                 phase='TRAINING' if training else 'EVALUATING',
                                 batch_time=batch_time,
                                 data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False,
          eval_total=False, init_prev=False):
    """Train the interaction segmentation head over repeated scribble rounds.

    Each outer iteration runs ``round_`` rounds. Each round trains for
    ``epoch_per_round`` epochs over DAVIS-2017 reference frames. Afterwards
    (except on the last round), new scribbles are generated with
    ``InteractiveScribblesRobot`` and the dataset's reference frames and
    previous-round labels are updated for the next round. If
    ``cfg.TRAIN_INTER_USE_TRUE_RESULT`` is set, scribbles are derived from
    model predictions. Otherwise they come from damaged ground-truth masks.

    Only ``self.model.inter_seghead`` parameters are optimized. The feature
    extractor and semantic embedding run under ``torch.no_grad()``.

    Args:
        damage_initial_previous_frame_mask: unused; kept for interface compatibility.
        lossfunc: ``'cross_entropy'`` or ``'bce'``.
        model_resume: if True, load ``save_step_75000.pth`` from
            ``self.save_res_dir`` and continue from step 75000.
        eval_total, init_prev: unused; kept for interface compatibility.

    Raises:
        ValueError: if ``lossfunc`` is not a supported name.
    """
    # Validate up front instead of printing a warning and crashing later on an
    # unbound ``criterion``.
    if lossfunc not in ('bce', 'cross_entropy'):
        raise ValueError(
            'unsupported loss function {!r}. Please choose from [cross_entropy,bce]'.format(lossfunc))

    ###################
    interactor = interactive_robot.InteractiveScribblesRobot()
    self.model.train()
    running_loss = AverageMeter()
    optimizer = optim.SGD(self.model.inter_seghead.parameters(), lr=cfg.TRAIN_LR,
                          momentum=cfg.TRAIN_MOMENTUM, weight_decay=cfg.TRAIN_WEIGHT_DECAY)
    ###################
    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP), 10),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    print('dataset processing...')
    train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    print('dataset processing finished.')

    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    round_ = 3
    epoch_per_round = 30
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_75000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 75000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        if step > 100001:
            break

        for r in range(round_):
            if r == 0:
                print('start new')
                global_map_tmp_dic = {}
                train_dataset.transform = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                              tr.RandomScale(),
                                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                                              tr.Resize(cfg.DATA_RESCALE),
                                                              tr.ToTensor()])
                train_dataset.init_ref_frame_dic()

            trainloader = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                     sampler=RandomIdentitySampler(train_dataset.sample_list),
                                     shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
            print('round:{} start'.format(r))

            for _epoch in range(epoch_per_round):
                for ii, sample in enumerate(trainloader):
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                    ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    ref_frame_nums = sample['meta']['ref_frame_num']
                    ref_frame_gts = sample['ref_frame_gt']
                    bs, _, h, w = ref_imgs.size()

                    ##########
                    # Previously ``inputs`` was only bound inside the GPU branch,
                    # so running with use_gpu=False raised NameError.
                    inputs = ref_imgs
                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        ref_frame_gts = ref_frame_gts.cuda()
                    ##########
                    # The backbone is frozen: features are extracted in eval mode without grads.
                    with torch.no_grad():
                        self.model.feature_extracter.eval()
                        self.model.semantic_embedding.eval()
                        ref_frame_embedding = self.model.extract_feature(inputs)

                    if r == 0:
                        first_inter = True
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=None,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums, first_inter=first_inter)
                    else:
                        first_inter = False
                        prev_round_label = sample['prev_round_label']
                        if self.use_gpu:
                            prev_round_label = prev_round_label.cuda()
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=prev_round_label,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums, first_inter=first_inter)

                    label_and_obj_dic = {}
                    label_dic = {}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_] = (ref_frame_gts[i], obj_nums[i])
                    for seq_ in tmp_dic.keys():
                        tmp_pred_logits = tmp_dic[seq_]
                        tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                                    mode='bilinear', align_corners=True)
                        tmp_dic[seq_] = tmp_pred_logits
                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        # Ids include 0 here, so the background gets its own channel.
                        obj_ids = np.arange(0, obj_num + 1)
                        obj_ids = torch.from_numpy(obj_ids)
                        obj_ids = obj_ids.int()
                        # Same device decision as the labels above.
                        if self.use_gpu:
                            obj_ids = obj_ids.cuda()
                        if lossfunc == 'bce':
                            label_tmp = label_tmp.permute(1, 2, 0)
                            label = (label_tmp.float() == obj_ids.float())
                            label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                            label_dic[seq_] = label.float()
                        elif lossfunc == 'cross_entropy':
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step)
                    loss = loss / bs
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    running_loss.update(loss.item(), bs)

                    if step % 50 == 0:
                        torch.cuda.empty_cache()
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                             running_loss.val, running_loss.avg))
                        show_ref_img = ref_imgs.cpu().numpy()[0]

                        # Undo the ImageNet mean/std normalization for display.
                        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                        show_ref_img = show_ref_img * sigma + mean

                        show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2, 0, 1))
                        show_scrbble = ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                        show_scrbble = label2colormap(show_scrbble).transpose((2, 0, 1))
                        if r != 0:
                            show_prev_round_label = prev_round_label.cpu()[0].squeeze(0).numpy()
                            show_prev_round_label = label2colormap(show_prev_round_label).transpose((2, 0, 1))
                        else:
                            show_prev_round_label = np.zeros_like(show_gt)
                            show_prev_round_label = label2colormap(show_prev_round_label).transpose((2, 0, 1))
                        ##########
                        show_preds = tmp_dic[seq_names[0]].cpu()
                        show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                               mode='bilinear', align_corners=True)
                        show_preds = show_preds.squeeze(0)
                        if lossfunc == 'bce':
                            # Drop the background channel, then label each pixel
                            # with the arg-max foreground object above threshold.
                            show_preds = show_preds[1:]
                            show_preds = (torch.sigmoid(show_preds) > 0.5)
                            marker = torch.argmax(show_preds, dim=0)
                            show_preds_s = torch.zeros((h, w))
                            for i in range(show_preds.size(0)):
                                tmp_mask = (marker == i) & (show_preds[i] > 0.5)
                                show_preds_s[tmp_mask] = i + 1
                        elif lossfunc == 'cross_entropy':
                            show_preds_s = torch.argmax(show_preds, dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                        pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                        if cfg.TRAIN_TBLOG:
                            tblogger.add_scalar('loss', running_loss.avg, step)
                            tblogger.add_scalar('pix_acc', pix_acc, step)
                            tblogger.add_scalar('now_lr', now_lr, step)
                            tblogger.add_image('Reference image', show_ref_img, step)
                            tblogger.add_image('Groud Truth', show_gtf, step)
                            tblogger.add_image('Predict label', show_preds_sf, step)
                            tblogger.add_image('Scribble', show_scrbble, step)
                            tblogger.add_image('prev_round_label', show_prev_round_label, step)

                    ###########TODO
                    if step % 20000 == 0 and step != 0:
                        self.save_network(self.model, step)

                    step += 1

            print('trainset evaluating...')
            print('*' * 100)

            if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                # Scribbles for the next round come from the model's own predictions.
                if r != round_ - 1:
                    if r == 0:
                        prev_round_label_dic = {}
                    self.model.eval()
                    with torch.no_grad():
                        round_scribble = {}
                        frame_num_dic = {}
                        train_dataset.transform = transforms.Compose([tr.Resize(cfg.DATA_RESCALE), tr.ToTensor()])
                        trainloader = DataLoader(train_dataset, batch_size=1,
                                                 sampler=RandomIdentitySampler(train_dataset.sample_list),
                                                 shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                            if r == 0:
                                ref_scribble_labels = self.rough_ROI(ref_scribble_labels)
                            print(seq_names[0])

                            label1s_tocat = torch.Tensor()
                            for i in range(bs):
                                l = label1s[i]
                                l = l.unsqueeze(0)
                                l = mask_damager(l, 0.0)
                                l = torch.from_numpy(l)
                                l = l.unsqueeze(0).unsqueeze(0)
                                label1s_tocat = torch.cat((label1s_tocat, l.float()), 0)
                            label1s = label1s_tocat

                            if self.use_gpu:
                                inputs = inputs.cuda()
                                ref_scribble_labels = ref_scribble_labels.cuda()
                                label1s = label1s.cuda()

                            tmp_dic, global_map_tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                                                     seq_names=seq_names, gt_ids=obj_nums,
                                                                     k_nearest_neighbors=cfg.KNNS,
                                                                     global_map_tmp_dic=global_map_tmp_dic,
                                                                     frame_num=frame_nums)
                            pred_label = tmp_dic[seq_names[0]].detach().cpu()
                            pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                                   mode='bilinear', align_corners=True)
                            pred_label = torch.argmax(pred_label, dim=1)
                            pred_label = pred_label.unsqueeze(0)
                            try:
                                pred_label = damage_masks(pred_label)
                            except Exception:
                                # Best effort: keep the undamaged prediction.
                                pass
                            pred_label = pred_label.squeeze(0)
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0], pred_label.numpy(),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            pred_label = pred_label.unsqueeze(0)
                            # Resize the prediction back to the original frame resolution.
                            img_ww = Image.open(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_names[0], '00000.jpg'))
                            img_ww = np.array(img_ww)
                            or_h, or_w = img_ww.shape[:2]
                            pred_label = torch.nn.functional.interpolate(pred_label.float(), (or_h, or_w), mode='nearest')
                            prev_round_label_dic[seq_names[0]] = pred_label.squeeze(0)
                        train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    self.model.train()
                    print('updating ref frame and label')
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')

            else:
                # Scribbles for the next round come from damaged ground-truth masks.
                if r != round_ - 1:
                    round_scribble = {}
                    if r == 0:
                        prev_round_label_dic = {}
                    frame_num_dic = {}
                    train_dataset.transform = tr.ToTensor()
                    trainloader = DataLoader(train_dataset, batch_size=1,
                                             sampler=RandomIdentitySampler(train_dataset.sample_list),
                                             shuffle=False, num_workers=1, pin_memory=True)
                    self.model.eval()
                    with torch.no_grad():
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            print(seq_names[0])
                            label2s_ = mask_damager(label2s, 0.1)
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0],
                                                                               np.expand_dims(label2s_, axis=0),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            label2s__ = torch.from_numpy(label2s_)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            prev_round_label_dic[seq_names[0]] = label2s__

                    print('trainset evaluating finished!')
                    print('*' * 100)
                    print('updating ref frame and label')
                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    self.model.train()
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')
class Solver(object): def __init__(self, opt, net): self.opt = opt self.net = net self.loss = AverageMeter('loss') self.acc = AverageMeter('acc') def fit(self, train_data, test_data, num_query, optimizer, criterion, lr_scheduler): best_rank1 = -np.inf for epoch in range(self.opt.train.num_epochs): self.loss.reset() self.acc.reset() self.net.train() # update learning rate lr = lr_scheduler.update(epoch) for param_group in optimizer.param_groups: param_group['lr'] = lr logging.info('learning rate update to {:.3e}'.format(lr)) tic = time.time() btic = time.time() for i, inputs in enumerate(train_data): data, pids, _ = inputs label = pids.cuda() score, feat = self.net(data) loss = criterion(score, feat, label) optimizer.zero_grad() loss.backward() optimizer.step() self.loss.update(loss.item()) acc = (score.max(1)[1] == label.long()).float().mean().item() self.acc.update(acc) log_interval = self.opt.misc.log_interval if log_interval and not (i + 1) % log_interval: loss_name, loss_value = self.loss.get() metric_name, metric_value = self.acc.get() logging.info( 'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\t' '%s=%f' % (epoch, i + 1, train_data.batch_size * log_interval / (time.time() - btic), loss_name, loss_value, metric_name, metric_value)) btic = time.time() loss_name, loss_value = self.loss.get() metric_name, metric_value = self.acc.get() throughput = int(train_data.batch_size * len(train_data) / (time.time() - tic)) logging.info( '[Epoch %d] training: %s=%f\t%s=%f' % (epoch, loss_name, loss_value, metric_name, metric_value)) logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' % (epoch, throughput, time.time() - tic)) is_best = False if test_data is not None and self.opt.misc.eval_step and not ( epoch + 1) % self.opt.misc.eval_step: rank1 = self.test_func(test_data, num_query) is_best = rank1 > best_rank1 if is_best: best_rank1 = rank1 state_dict = self.net.module.state_dict() if not (epoch + 1) % self.opt.misc.save_step: save_checkpoint( { 
'state_dict': state_dict, 'epoch': epoch + 1, }, is_best=is_best, save_dir=self.opt.misc.save_dir, filename=self.opt.network.name + str(epoch + 1) + '.pth.tar') def test_func(self, test_data, num_query): self.net.eval() feat, person, camera = list(), list(), list() for inputs in test_data: data, pids, camids = inputs with torch.no_grad(): outputs = self.net(data).cpu() feat.append(outputs) person.extend(pids.numpy()) camera.extend(camids.numpy()) feat = torch.cat(feat, 0) qf = feat[:num_query] q_pids = np.asarray(person[:num_query]) q_camids = np.asarray(camera[:num_query]) gf = feat[num_query:] g_pids = np.asarray(person[num_query:]) g_camids = np.asarray(camera[num_query:]) logging.info( "Extracted features for query set, obtained {}-by-{} matrix". format(qf.shape[0], qf.shape[1])) logging.info( "Extracted features for gallery set, obtained {}-by-{} matrix". format(gf.shape[0], gf.shape[1])) logging.info("Computing distance matrix") m, n = qf.shape[0], gf.shape[0] distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() distmat.addmm_(1, -2, qf, gf.t()) distmat = distmat.numpy() logging.info("Computing CMC and mAP") cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids) print("Results ----------") print("mAP: {:.1%}".format(mAP)) print("CMC curve") for r in [1, 5, 10]: print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1])) print("------------------") return cmc[0] @staticmethod def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): """Evaluation with market1501 metric Key: for each query identity, its gallery images from the same camera view are discarded. """ num_q, num_g = distmat.shape if num_g < max_rank: max_rank = num_g print("Note: number of gallery samples is quite small, got {}". 
format(num_g)) indices = np.argsort(distmat, axis=1) matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) # compute cmc curve for each query all_cmc = [] all_AP = [] num_valid_q = 0. # number of valid query for q_idx in range(num_q): # get query pid and camid q_pid = q_pids[q_idx] q_camid = q_camids[q_idx] # remove gallery samples that have the same pid and camid with query order = indices[q_idx] remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) keep = np.invert(remove) # compute cmc curve # binary vector, positions with value 1 are correct matches orig_cmc = matches[q_idx][keep] if not np.any(orig_cmc): # this condition is true when query identity does not appear in gallery continue cmc = orig_cmc.cumsum() cmc[cmc > 1] = 1 all_cmc.append(cmc[:max_rank]) num_valid_q += 1. # compute average precision # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision num_rel = orig_cmc.sum() tmp_cmc = orig_cmc.cumsum() tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] tmp_cmc = np.asarray(tmp_cmc) * orig_cmc AP = tmp_cmc.sum() / num_rel all_AP.append(AP) assert num_valid_q > 0, "Error: all query identities do not appear in gallery" all_cmc = np.asarray(all_cmc).astype(np.float32) all_cmc = all_cmc.sum(0) / num_valid_q mAP = np.mean(all_AP) return all_cmc, mAP
def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
    """Score query tracklets against gallery tracklets with streamed gallery features.

    Gallery sequences go through `self.cnn_model` batch by batch. Feature rows
    are buffered until `galleryinfo.tranum[k]` rows (the k-th gallery
    tracklet) are available. The tracklet is then scored against every query
    tracklet and dropped from the buffer. A query/gallery tracklet distance is
    the mean of the pairwise distances lying below their 20th percentile.

    Args:
        query_loader: loader passed to `self.extract_feature`.
        gallery_loader: yields `(imgs, flows, _, _)` batches and exposes
            `batch_size`.
        queryinfo, galleryinfo: objects exposing `pid`, `camid` and `tranum`
            (per-tracklet row counts in loader order, as consumed below).

    Returns:
        The result of `evaluate_seq` on the (query x gallery) distance matrix.
    """
    self.cnn_model.eval()

    querypid = queryinfo.pid
    querycamid = queryinfo.camid
    querytranum = queryinfo.tranum
    gallerypid = galleryinfo.pid
    gallerycamid = galleryinfo.camid
    gallerytranum = galleryinfo.tranum

    query_features = self.extract_feature(self.cnn_model, query_loader)

    querylen = len(querypid)
    gallerylen = len(gallerypid)

    # online gallery extraction
    single_distmat = np.zeros((querylen, gallerylen))
    gallery_resize = 0  # number of buffered, not yet scored feature rows
    gallery_popindex = 0  # index of the gallery tracklet being assembled
    gallery_popsize = gallerytranum[gallery_popindex]
    gallery_resfeatures = 0
    gallery_empty = True
    preimgs = 0
    preflows = 0

    # time
    gallery_time = AverageMeter()
    end = time.time()

    for i, (imgs, flows, _, _) in enumerate(gallery_loader):
        seqnum = imgs.size(0)
        if i == 0:
            # Kept to pad a short final batch up to the loader's batch size.
            preimgs = imgs
            preflows = flows

        # torch.no_grad() replaces `Variable(..., volatile=True)`; `volatile`
        # no longer disables autograd in current PyTorch.
        with torch.no_grad():
            if gallery_empty:
                out_feat = self.cnn_model(imgs, flows, self.mode)
                gallery_resfeatures = out_feat.data
                gallery_empty = False
            elif imgs.size(0) < gallery_loader.batch_size:
                flaw_batchsize = imgs.size(0)
                cat_batchsize = gallery_loader.batch_size - flaw_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
                out_feat = self.cnn_model(imgs, flows, self.mode)
                out_feat = out_feat[0:flaw_batchsize]
                gallery_resfeatures = torch.cat(
                    (gallery_resfeatures, out_feat.data), 0)
            else:
                out_feat = self.cnn_model(imgs, flows, self.mode)
                gallery_resfeatures = torch.cat(
                    (gallery_resfeatures, out_feat.data), 0)

        gallery_resize = gallery_resize + seqnum

        # Pop every gallery tracklet whose rows are now fully buffered.
        while gallery_popsize <= gallery_resize:
            if (gallery_popindex + 1) % 50 == 0:
                print('gallery--{:04d}'.format(gallery_popindex))
            # Only the first `gallery_popsize` rows belong to this tracklet.
            # The former `popsize == 1` shortcut took the whole buffer, which
            # mixed in rows of following tracklets.
            gallery_popfeatures = gallery_resfeatures[0:gallery_popsize, :]
            if gallery_popsize < gallery_resize:
                gallery_resfeatures = gallery_resfeatures[
                    gallery_popsize:gallery_resize, :]
            else:
                gallery_resfeatures = 0
                gallery_empty = True
            gallery_resize = gallery_resize - gallery_popsize

            distmat_qall_g = pairwise_distance_tensor(query_features,
                                                      gallery_popfeatures)

            q_start = 0
            for qind, qnum in enumerate(querytranum):
                distmat_qg = distmat_qall_g[q_start:q_start + qnum, :]
                distmat_qg = distmat_qg.cpu().numpy()
                percile = np.percentile(distmat_qg, 20)
                lowest = distmat_qg[distmat_qg < percile]
                # `lowest` is empty when no distance is strictly below the
                # percentile (e.g. all equal). Fall back to the plain mean
                # instead of np.mean([]) == nan.
                if lowest.size > 0:
                    distmean = np.mean(lowest)
                else:
                    distmean = np.mean(distmat_qg)

                single_distmat[qind, gallery_popindex] = distmean
                q_start = q_start + qnum

            gallery_popindex = gallery_popindex + 1

            if gallery_popindex < gallerylen:
                gallery_popsize = gallerytranum[gallery_popindex]
        gallery_time.update(time.time() - end)
        end = time.time()

    return evaluate_seq(single_distmat, querypid, querycamid, gallerypid,
                        gallerycamid)
def train(self, epoch, data_loader, optimizer1, optimizer2):
    """Run one training epoch, stepping both optimizers on the shared loss.

    `self._forward` returns the loss plus three precision values (OIM, score,
    final score). These are averaged per sample and printed every 50 batches.

    Args:
        epoch: epoch number, used for logging only.
        data_loader: iterable of raw batches, unpacked by `self._parse_data`.
        optimizer1, optimizer2: both are zeroed and stepped every batch.
    """
    self.model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    precisions1 = AverageMeter()
    precisions2 = AverageMeter()

    print_freq = 50
    end = time.time()
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)

        inputs, targets = self._parse_data(inputs)
        loss, prec_oim, prec_score, prec_finalscore = self._forward(
            inputs, targets)
        batch_size = targets.size(0)
        # loss.item() replaces loss.data[0], which raises IndexError on the
        # 0-dim loss tensors returned by PyTorch >= 0.4.
        losses.update(loss.item(), batch_size)
        precisions.update(prec_oim, batch_size)
        precisions1.update(prec_score, batch_size)
        precisions2.update(prec_finalscore, batch_size)

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        loss.backward()
        optimizer1.step()
        optimizer2.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'prec_oim {:.2%} ({:.2%})\t'
                  'prec_score {:.2%} ({:.2%})\t'
                  'prec_finalscore(total) {:.2%} ({:.2%})\t'.format(
                      epoch, i + 1, len(data_loader), losses.val, losses.avg,
                      precisions.val, precisions.avg, precisions1.val,
                      precisions1.avg, precisions2.val, precisions2.avg))
def validate(val_loader, model, criterion):
    """Evaluate `model` on `val_loader` and report loss plus Prec@1 / Prec@5.

    Reads `args.device` and `args.print_freq` from the module-level `args`.

    Returns:
        (average loss, average Prec@1, average Prec@5), weighted by batch size.
    """
    batch_time, losses = AverageMeter(), AverageMeter()
    top1, top5 = AverageMeter(), AverageMeter()

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(args.device)
            labels = labels.to(args.device)

            logits = model(images)
            batch_loss = criterion(logits, labels)

            prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
            losses.update(batch_loss.item(), images.size(0))
            top1.update(float(prec1), images.size(0))
            top5.update(float(prec5), images.size(0))

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      step, len(val_loader), batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg
def TrainOneEpoch(train_loader, model, optimizer, criterion, epoch_num,
                  vis_tool, record_value):
    """Train `model` for one epoch with a per-batch weighted criterion.

    Uses the module-level `opt` (`out_dim`, `use_cuda`, `train_plotfreq`).
    Every `opt.train_plotfreq` batches the loss, Dice and Recall are sent to
    `vis_tool` and printed. `epoch_num` is not used.

    NOTE(review): a second `TrainOneEpoch` is defined later in this file. At
    import time it replaces this one — confirm which is intended.

    Returns:
        (average loss, mean of per-batch Dice, mean of per-batch Recall,
        record_value unchanged)
    """
    losses = AverageMeter()
    train_eval = ConfusionMeter(num_class=opt.out_dim)

    # calculate the final result
    train_dice = AverageMeter()
    train_recall = AverageMeter()
    best_value = record_value

    model.train()
    for batch_ids, (data, target) in enumerate(train_loader):
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)

        # calculate the weight of the batch
        weight = GetWeight(opt, target, slr=0, is_t=0)
        loss = criterion(output, target, weight=weight)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.update(loss_value)

        # calculate the metrics for evaluation
        _, pred = torch.max(output, 1)
        train_eval.update(pred, target)
        dice_value = train_eval.get_scores('Dice')
        recall_value = train_eval.get_scores('Recall')
        train_dice.update(dice_value)
        train_recall.update(recall_value)

        # for visualization
        if batch_ids % opt.train_plotfreq == 0:
            vis_tool.plot('Train_Loss', loss_value)
            vis_tool.plot('Train_Dice', dice_value)
            vis_tool.plot('Train_Recall', recall_value)
            print('Train:Batch_Num:{} Loss:{:.3f} Dice:{:.3f} Recall:{:.3f}'.
                  format(batch_ids, loss_value, dice_value, recall_value))

    # Read the averages after the loop: previously `avg_loss` was only bound
    # inside it, so an empty loader raised UnboundLocalError.
    return losses.avg, train_dice.avg, train_recall.avg, best_value
def _hstack_max_projections(volume):
    """Side-by-side max projections of a 3-D array along axes 0, 1 and 2."""
    return np.hstack(
        [np.max(volume, 0), np.max(volume, 1), np.max(volume, 2)])


def TrainOneEpoch(train_loader, model, optimizer, criterion, epoch_num,
                  vis_tool, prefix):
    """Train `model` for one epoch and periodically visualise predictions.

    Uses the module-level `opt` (`out_dim`, `use_cuda`, `train_plotfreq`) and
    the module-level `all_mean1` / `all_std1` to de-normalise the input.

    Every `opt.train_plotfreq` batches this function:
      - plots the loss, Dice and Recall;
      - sends `vis_tool` an image with three rows of max projections for the
        first sample: the input (windowed to [150, 350]), the prediction and
        the label.

    `epoch_num` and `prefix` are not used.

    NOTE(review): this redefines a `TrainOneEpoch` declared earlier in this
    file, and that earlier definition is replaced at import time.

    Returns:
        (average loss, mean of per-batch Dice, mean of per-batch Recall)
    """
    losses = AverageMeter()
    train_eval = ConfusionMeter(num_class=opt.out_dim)

    # calculate the final result
    train_dice = AverageMeter()
    train_recall = AverageMeter()

    model.train()
    for batch_ids, (data, target) in enumerate(train_loader):
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)

        # calculate the weight of the batch
        weight = GetWeight(opt, target, slr=0.00001, is_t=0)
        loss = criterion(output, target, weight=weight)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.update(loss_value)

        # calculate the metrics for evaluation
        _, pred = torch.max(output, 1)
        train_eval.update(pred, target)
        dice_value = train_eval.get_scores('Dice')
        recall_value = train_eval.get_scores('Recall')
        train_dice.update(dice_value)
        train_recall.update(recall_value)

        # begin to show the results
        if batch_ids % opt.train_plotfreq == 0:
            vis_tool.plot('Train_Loss', loss_value)
            vis_tool.plot('Train_Dice', dice_value)
            vis_tool.plot('Train_Recall', recall_value)

            # De-normalise the first sample's first channel, clip to
            # [150, 350] and rescale to [0, 1].
            image1 = data.cpu().numpy()[0, 0, ...]
            image1 = image1 * all_std1 + all_mean1
            image1 = np.clip(image1, 150, 350)
            image1 = (image1 - 150) / 200
            image1_mip = _hstack_max_projections(image1)
            # see the pred
            mip1 = _hstack_max_projections(pred.cpu().numpy()[0, ...])
            # see the label
            mip2 = _hstack_max_projections(target.cpu().numpy()[0, ...])

            mip3 = np.vstack([image1_mip, mip1, mip2])
            vis_tool.img('pred_label', np.uint8(255 * mip3))
            print('Train:Batch_Num:{} Loss:{:.3f} Dice:{:.3f} Recall:{:.3f}'.
                  format(batch_ids, loss_value, dice_value, recall_value))

    # Read the averages after the loop: previously `avg_loss` was only bound
    # inside it, so an empty loader raised UnboundLocalError.
    return losses.avg, train_dice.avg, train_recall.avg
def __eval(self, topk):
    """Evaluate `self.model` on `self.dataloader` without gradients.

    Args:
        topk: forwarded to `accuracy` and `getclassAccuracy`. Three values are
            unpacked from each, so presumably a 3-tuple such as (1, 2, 5) —
            confirm with the caller.

    Returns:
        (top1.avg, top2.avg, top5.avg, losses.avg, ClassTPDic). ClassTPDic maps
        'Top1' / 'Top2' / 'Top5' to (1, num_classes) numpy arrays of per-class
        true-positive counts.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top2 = AverageMeter()
    top5 = AverageMeter()

    # Per-class true-positive counters. They use int64 because uint8 silently
    # wraps after 255 hits per class. They live on the same device as the
    # batches, so use_cuda=False no longer needs a GPU.
    device = torch.device('cuda' if self.use_cuda else 'cpu')
    num_classes = len(self.classes)
    ClassTPs_Top1 = torch.zeros(1, num_classes, dtype=torch.long, device=device)
    ClassTPs_Top2 = torch.zeros(1, num_classes, dtype=torch.long, device=device)
    ClassTPs_Top5 = torch.zeros(1, num_classes, dtype=torch.long, device=device)

    # Collect per-batch arrays and concatenate once (np.append per batch is
    # quadratic in the dataset size).
    y_pred_batches = []
    y_true_batches = []

    # Start data time
    data_time_start = time.time()

    with torch.no_grad():
        for i, (images, labels, orig_attrs) in enumerate(self.dataloader):
            start_time = time.time()
            if self.use_cuda:
                images, labels = images.cuda(), labels.cuda()

            if self.ten_crops:
                bs, ncrops, c, h, w = images.size()
                images = images.view(-1, c, h, w)

            if self.with_attribute:
                orig_attrs = orig_attrs.to(device)
                # NOTE(review): the binarised `attrs` is never used; the model
                # receives `orig_attrs`. Confirm which one is intended.
                attrs = orig_attrs.detach().clone()
                attrs[attrs > self.xi] = 1.
                attrs[attrs <= self.xi] = 0.
                outputs, _ = self.model(images, orig_attrs)
            else:
                outputs = self.model(images, orig_attrs)

            if self.ten_crops:
                # average the crop predictions of each sample
                outputs = outputs.view(bs, ncrops, -1).mean(1)

            loss = self.criterion(outputs, labels)
            y_pred = outputs.argmax(dim=1)

            y_true_batches.append(labels.cpu().numpy())
            y_pred_batches.append(y_pred.cpu().numpy())

            # Compute class accuracy
            ClassTPs = getclassAccuracy(outputs, labels, num_classes, topk)
            ClassTPs_Top1 += ClassTPs[0]
            ClassTPs_Top2 += ClassTPs[1]
            ClassTPs_Top5 += ClassTPs[2]

            # Measure Top1, Top2 and Top5 accuracy
            prec1, prec2, prec5 = accuracy(outputs.data, labels.data, topk)

            losses.update(loss.item(), labels.size(0))
            top1.update(prec1.item(), labels.size(0))
            top2.update(prec2.item(), labels.size(0))
            top5.update(prec5.item(), labels.size(0))

            batch_time.update(time.time() - start_time)

            if (i + 1) % 10 == 0:
                print('Testing batch: [{}/{}]\t'
                      'Loss {loss.val:.3f} (avg: {loss.avg:.3f})\t'
                      'Prec@1 {top1.val:.3f} (avg: {top1.avg:.3f})\t'
                      'Prec@2 {top2.val:.3f} (avg: {top2.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} (avg: {top5.avg:.3f})'.format(
                          i + 1, len(self.dataloader), batch_time=batch_time,
                          loss=losses, top1=top1, top2=top2, top5=top5))

    y_trues = np.concatenate(y_true_batches) if y_true_batches else np.array([])
    y_preds = np.concatenate(y_pred_batches) if y_pred_batches else np.array([])

    ClassTPDic = {
        'Top1': ClassTPs_Top1.cpu().numpy(),
        'Top2': ClassTPs_Top2.cpu().numpy(),
        'Top5': ClassTPs_Top5.cpu().numpy()
    }

    # NOTE(review): unless a module-level `set` exists, this formats the
    # builtin type (prints "<class 'set'>"); a split name was presumably meant.
    print('Elapsed time for {} set evaluation {time:.3f} seconds'.format(
        set, time=time.time() - data_time_start))
    print("")
    print(
        metrics.precision_score(y_true=y_trues,
                                y_pred=y_preds,
                                average='micro'))

    return top1.avg, top2.avg, top5.avg, losses.avg, ClassTPDic
def train_using_metriclearning_with_inception3(model,
                                               optimizer,
                                               criterion,
                                               epoch,
                                               train_root,
                                               train_pictures,
                                               prefix,
                                               distance_dict=None,
                                               class_to_nearest_class=None):
    """Train `model` for one epoch on (anchor, positive, negative) image triplets.

    Every picture is used once as an anchor.

    - Hard sampling: with probability 1/2, and only when both `distance_dict`
      and `class_to_nearest_class` are given, the positive comes from
      `distance_dict[cls]` and the negative from the class
      `class_to_nearest_class[cls]`.
    - Otherwise both are drawn at random from `train_pictures`.

    Triplets are accumulated and the model is updated every 128 anchors and
    once more at the end.

    Args:
        model: maps a CUDA input batch to embeddings.
        optimizer: optimizer for `model`.
        criterion: called as `criterion(anchor, positive, negative)`.
        epoch: epoch number, used for logging only.
        train_root: directory containing the pictures.
        train_pictures: file names inside `train_root`, or None to list it.
        prefix: unused.
        distance_dict: optional mapping of class id to a list of entries whose
            first element is a picture name (hard positives).
        class_to_nearest_class: optional mapping of class id to class id
            (hard negatives).

    Returns:
        bool: True if any mini-batch loss fell below 1e-5.
    """
    start = time.time()
    model.train()
    losses = AverageMeter()
    is_add_margin = False
    feature_util = FeatureUtil(G.WIDTH, G.HEIGHT)
    if train_pictures is None:
        train_pictures = os.listdir(train_root)
    num_pictures = len(train_pictures)
    # max(1, ...) avoids ZeroDivisionError in `% log_freq` with < 6 pictures.
    log_freq = max(1, int(num_pictures / 6))
    use_mining = distance_dict is not None and class_to_nearest_class is not None

    # Class id is the last '_'-separated field of the file name minus a
    # 4-character extension (e.g. '.jpg'). It is parsed once, not per comparison.
    labelled = [(name, name.split('_')[-1][:-4]) for name in train_pictures]

    def load(name):
        return feature_util.get_proper_input(os.path.join(train_root, name),
                                             ls_form=True)

    anchor_ls, positive_ls, negative_ls = [], [], []
    for i, (picture_path, cls_idx) in enumerate(labelled):
        anchor_ls.append(load(picture_path))

        # decide if use random sample or hard sample
        hard = random.randint(1, 2) % 2 == 0 and use_mining

        if hard:
            candidates = distance_dict[cls_idx]
            random_int = min(random.randint(0, 39), len(candidates) - 1)
            # An empty candidate list (random_int == -1) or an empty entry
            # falls back to the anchor instead of raising IndexError.
            if random_int >= 0 and len(candidates[random_int]) > 0:
                p_input_pic = candidates[random_int][0]
            else:
                p_input_pic = picture_path
        else:
            p_pictures = [
                x for x, c in labelled if c == cls_idx and x != picture_path
            ]
            random.shuffle(p_pictures)
            p_input_pic = p_pictures[0] if len(p_pictures) > 0 else picture_path
        positive_ls.append(load(p_input_pic))

        n_pictures = []
        if hard:
            nearest = class_to_nearest_class[cls_idx]
            n_pictures = [x for x, c in labelled if c == nearest]
        if not n_pictures:
            n_pictures = [x for x, c in labelled if c != cls_idx]
        random.shuffle(n_pictures)
        n_input_pic = n_pictures[0]
        negative_ls.append(load(n_input_pic))

        if (i + 1) == num_pictures or (i + 1) % 128 == 0:
            anchor_t = torch.Tensor(anchor_ls).cuda()
            positive_t = torch.Tensor(positive_ls).cuda()
            negative_t = torch.Tensor(negative_ls).cuda()

            anchor_out = model(anchor_t)
            positive_out = model(positive_t)
            negative_out = model(negative_t)

            loss = criterion(anchor_out, positive_out, negative_out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())
            if losses.val < 1e-5:
                is_add_margin = True

            # reset for next training "batch"
            anchor_ls, positive_ls, negative_ls = [], [], []

        if (i + 1) % log_freq == 0:
            print('Epoch: {}[{}/{}]\t'
                  'Loss {:.6f} ({:.6f})\t'.format(epoch, i + 1, num_pictures,
                                                  losses.val, losses.mean))

    time_token = time.time() - start
    param_group = optimizer.param_groups
    print('Epoch: [{}]\tEpoch Time {:.1f} s\tLoss {:.6f}\t'
          'Lr {:.2e}'.format(epoch, time_token, losses.mean,
                             param_group[0]['lr']))
    return is_add_margin
def validate(val_loader, model, criterion, device):
    """Evaluate `model` on `val_loader` with four top-k precision meters.

    Inputs with more than 4 dims have their first two dims folded into the
    batch; targets are reshaped to match and cast to float. Reads
    `args.print_freq` from the module-level `args`.

    Returns:
        The average of the first precision meter (Prec@1).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    topk = [AverageMeter() for _ in range(4)]

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if images.dim() > 4:
                # fold the first two dims into the batch dim
                images = images.reshape(images.shape[0] * images.shape[1],
                                        images.shape[2], images.shape[3],
                                        images.shape[4])
                target = target.reshape(target.shape[0] * target.shape[1],
                                        target.shape[2])
            images = images.to(device)
            target = target.to(device)
            target = target.float()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            output = output.cpu()
            target = target.cpu()
            prec = accuracy(output, target, topk=4)
            # Iterate with `k`: the previous inner `for i in range(4)`
            # overwrote the batch index `i` used by the print check below.
            for k, meter in enumerate(topk):
                meter.update(prec[k], images.size(0))
            losses.update(loss.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\n'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@2 {top2.val:.3f} ({top2.avg:.3f})\t'
                      'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
                      'Prec@4 {top4.val:.3f} ({top4.avg:.3f})\t'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=topk[0], top2=topk[1],
                          top3=topk[2], top4=topk[3]))

    print(
        ' * Prec@1 {top1.avg:.3f} Prec@2 {top2.avg:.3f} Prec@3 {top3.avg:.3f} Prec@4 {top4.avg:.3f}'
        .format(top1=topk[0], top2=topk[1], top3=topk[2], top4=topk[3]))

    return topk[0].avg
class ReidSystem():
    """Person re-identification trainer: data, model, losses, train/test loop.

    Everything is built from `cfg`. Epoch-based training runs with periodic
    evaluation (CMC / mAP) and checkpointing under
    `OUTPUT_DIR/TEST_NAMES/VERSION/ckpts`.
    """

    def __init__(self, cfg, logger, writer):
        self.cfg, self.logger, self.writer = cfg, logger, writer
        # Define dataloader
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(
            cfg)
        # networks
        self.model = build_model(cfg, self.num_classes)
        self.base_type = self.model.base_type
        # loss function
        if cfg.SOLVER.LABEL_SMOOTH:
            self.ce_loss = CrossEntropyLabelSmooth(self.num_classes)
        else:
            self.ce_loss = nn.CrossEntropyLoss()
        self.triplet = TripletLoss(cfg.SOLVER.MARGIN)
        self.aligned_triplet = TripletLossAlignedReID(
            margin=cfg.SOLVER.MARGIN)
        self.of_penalty = OFPenalty(beta=1e-6,
                                    penalty_position=['intermediate'])
        # optimizer and scheduler
        self.opt = make_optimizer(self.cfg, self.model)
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

        self._construct()

    def _construct(self):
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        self.use_dp = False
        self.use_ddp = False

    def _bare_model(self):
        """Return the model without its nn.DataParallel wrapper, if any."""
        return self.model.module if self.use_dp else self.model

    def loss_fns(self, outputs, labels):
        if self.cfg.MODEL.FINE_TUNE:
            triplet_loss = self.triplet(outputs, labels)[0]
            return {'global_triplet_loss': triplet_loss}
        elif self.cfg.SOLVER.TRIPLET_ONLY:
            triplet_loss = self.triplet(outputs[1], labels)[0]
            return {'global_triplet_loss': triplet_loss}
        else:
            ce_loss = self.ce_loss(outputs[0], labels)
            triplet_loss = self.triplet(outputs[1], labels)[0]
            return {'ce_loss': ce_loss, 'global_triplet_loss': triplet_loss}

    def aligned_loss_fns(self, outputs, labels):
        """
        :param outputs: [cls_score, global_feature, local_feature]
        :param labels: person IDs
        :return: dict of 'ce_loss', 'global_triplet_loss', 'local_triplet_loss'
        """
        ce_loss = self.ce_loss(outputs[0], labels)
        global_triplet_loss, local_triplet_loss = self.aligned_triplet(
            outputs[1], labels, outputs[2])
        return {
            'ce_loss': ce_loss,
            'global_triplet_loss': global_triplet_loss,
            'local_triplet_loss': local_triplet_loss
        }

    def mgn_loss_fns(self, outputs, labels):
        triplet_loss = [
            self.triplet(output, labels)[0] for output in outputs[1]
        ]
        triplet_loss = sum(triplet_loss) / len(triplet_loss)
        ce_loss = [self.ce_loss(output, labels) for output in outputs[2]]
        ce_loss = sum(ce_loss) / len(ce_loss)
        return {'ce_loss': ce_loss, 'global_triplet_loss': triplet_loss}

    def on_train_begin(self):
        self.start_epoch = 0
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        self.running_CE_loss = AverageMeter()
        self.running_GT_loss = AverageMeter()
        self.running_LT_loss = AverageMeter()
        self.running_OF_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR,
                                    self.cfg.DATASETS.TEST_NAMES,
                                    self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        os.makedirs(self.model_save_dir, exist_ok=True)

        # Empty entries are dropped, so an empty or unset
        # CUDA_VISIBLE_DEVICES no longer counts as one visible GPU or raises
        # KeyError.
        visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
        self.gpus = [g for g in visible.split(',') if g.strip()]
        self.use_dp = (len(self.gpus) > 0) and (self.cfg.MODEL.DIST_BACKEND
                                                == 'dp')

        if self.use_dp:
            self.model = nn.DataParallel(self.model)

        self.model = self.model.cuda()

        # Load checkpoints after moving to the GPU, so restored optimizer
        # state is cast to the parameters' final device. `!=`, not `is not`:
        # an identity test against a str literal is unreliable.
        cfg = self.cfg
        if cfg.MODEL.CHECKPOINT != '':
            self.load_checkpoint(cfg.MODEL.CHECKPOINT,
                                 with_optimizer=not cfg.MODEL.FINE_TUNE)

        self.model.train()

    def on_epoch_begin(self):
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        self.running_loss.reset()
        self.running_CE_loss.reset()
        self.running_GT_loss.reset()
        self.running_LT_loss.reset()
        self.running_OF_loss.reset()

        self.tng_prefetcher = data_prefetcher(self.tng_dataloader, self.cfg)

    def training_step(self, batch):
        inputs, labels, _ = batch
        outputs = self.model(inputs, labels)
        if self.base_type in [
                BASE_ALIGNED_RESNET50, BASE_ALIGNED_RESNET101,
                BASE_ALIGNED_RESNEXT101, BASE_ALIGNED_RESNEXT50,
                BASE_ALIGNED_SE_RESNET101, BASE_ALIGNED_DENSENET169,
                BASE_ALIGNED_MPNCOV_RESNET50, BASE_ALIGNED_MPNCOV_RESNET101,
                BASE_ALIGNED_MPNCOV_RESNEXT101
        ]:
            loss_dict = self.aligned_loss_fns(outputs, labels)
        elif self.base_type in [
                BASE_ALIGNED_RESNET50_ABD, BASE_ALIGNED_RESNET101_ABD,
                BASE_ALIGNED_RESNEXT101_ABD
        ]:
            loss_dict = self.aligned_loss_fns(outputs[:3], labels)
            # add the OF penalty from cfg.MODEL.OF_START_EPOCH on
            # (original note: from epoch 33)
            if self.current_epoch >= self.cfg.MODEL.OF_START_EPOCH:
                loss_dict['of_loss'] = self.of_penalty(outputs[3])
        elif self.base_type in [BASE_RESNET101_ABD, BASE_RESNEXT101_ABD]:
            loss_dict = self.loss_fns(outputs, labels)
            # add the OF penalty from cfg.MODEL.OF_START_EPOCH on
            # (original note: from epoch 33)
            if self.current_epoch >= self.cfg.MODEL.OF_START_EPOCH:
                loss_dict['of_loss'] = self.of_penalty(outputs[2])
        elif self.base_type in [MGN_RESNET50, MGN_RESNET101, MGN_RESNEXT101]:
            loss_dict = self.mgn_loss_fns(outputs, labels)
        else:
            loss_dict = self.loss_fns(outputs, labels)

        total_loss = 0
        for loss_value in loss_dict.values():
            total_loss += loss_value

        # Python floats for printing, logging and the running meters. Feeding
        # the graph-attached tensors into the meters kept every step's
        # autograd graph alive for the whole epoch.
        loss_values = {name: value.item() for name, value in loss_dict.items()}
        loss_values['total_loss'] = total_loss.item()

        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_dataloader)} '
        for loss_name in loss_dict:
            print_str += (loss_name + f': {loss_values[loss_name]:.3f} ')
        print_str += f'Total loss: {loss_values["total_loss"]:.3f} '
        print(print_str, end=' ')

        if self.writer and (self.global_step + 1) % self.log_interval == 0:
            if 'ce_loss' in loss_values:
                self.writer.add_scalar('cross_entropy_loss',
                                       loss_values['ce_loss'],
                                       self.global_step)
            self.writer.add_scalar('global_triplet_loss',
                                   loss_values['global_triplet_loss'],
                                   self.global_step)
            if 'local_triplet_loss' in loss_values:
                self.writer.add_scalar('local_triplet_loss',
                                       loss_values['local_triplet_loss'],
                                       self.global_step)
            self.writer.add_scalar('total_loss', loss_values['total_loss'],
                                   self.global_step)

        self.running_loss.update(loss_values['total_loss'])
        if 'ce_loss' in loss_values:
            self.running_CE_loss.update(loss_values['ce_loss'])
        self.running_GT_loss.update(loss_values['global_triplet_loss'])
        if 'local_triplet_loss' in loss_values:
            self.running_LT_loss.update(loss_values['local_triplet_loss'])
        if 'of_loss' in loss_values:
            self.running_OF_loss.update(loss_values['of_loss'])

        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        print('')
        self.logger.info(
            f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} CE loss: {self.running_CE_loss.avg:.3f} '
            f'GT loss: {self.running_GT_loss.avg:.3f} LT loss: {self.running_LT_loss.avg:.3f} OF loss: {self.running_OF_loss.avg:.3f} '
            f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s'
        )
        # update learning rate
        self.lr_sched.step()

    def test(self):
        """Evaluate on the validation set and return {'rank1', 'mAP'}.

        Similarity is the inner product of query and gallery features (L2
        normalised first when cfg.TEST.NORM is set).
        """
        # convert to eval mode
        self.model.eval()

        feats, pids, camids = [], [], []
        val_prefetcher = data_prefetcher(self.val_dataloader, self.cfg)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = self.model(img)
            if isinstance(feat, tuple):
                feats.append(feat[0])
            else:
                feats.append(feat)
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))

            batch = val_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(pids[:self.num_query])
        q_camids = np.asarray(camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        # TODO: add re-ranking evaluation results
        distmat = -torch.mm(qf, gf.t()).cpu().numpy()
        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        # guarded like training_step: the writer may be None
        if self.writer:
            self.writer.add_scalar('rank1', cmc[0], self.global_step)
            self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def train(self):
        self.on_train_begin()
        for epoch in range(self.start_epoch, self.max_epochs):
            self.on_epoch_begin()
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            if (epoch + 1) % self.eval_period == 0:
                metric_dict = self.test()
                if metric_dict['mAP'] > self.best_mAP:
                    is_best = True
                    self.best_mAP = metric_dict['mAP']
                else:
                    is_best = False
                # always save the last checkpoint as the best
                is_best = True
                self.save_checkpoint(is_best)

                torch.cuda.empty_cache()

    def save_checkpoint(self, is_best):
        """Write model weights and optimizer/scheduler state for this epoch."""
        state_dict = self._bare_model().state_dict()

        filepath = os.path.join(self.model_save_dir,
                                f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)

        # Save state dicts, not the optimizer/scheduler objects. A pickled
        # optimizer reloads holding copies of the parameters and would no
        # longer update self.model.
        optpath = os.path.join(self.model_save_dir,
                               f'optimizer_epoch{self.current_epoch}.pth')
        opt_dict = {
            'optimizer': self.opt.state_dict(),
            'lr_scheduler': self.lr_sched.state_dict(),
            'epoch': self.current_epoch,
        }
        torch.save(opt_dict, optpath)

        if is_best:
            best_filepath = os.path.join(self.model_save_dir,
                                         'model_best.pth')
            shutil.copyfile(filepath, best_filepath)

    def load_checkpoint(self, checkpoint_path, with_optimizer=True):
        """Load weights (dropping unknown keys) and optionally optimizer state."""
        ## load weights
        self.logger.info('Loading checkpoints from ' + checkpoint_path)
        state_dict = torch.load(checkpoint_path)

        # Compare against the unwrapped model: under DataParallel its
        # state_dict keys carry a 'module.' prefix and every key would be
        # dropped.
        target = self._bare_model()
        model_keys = set(target.state_dict().keys())
        new_state_dict = state_dict.copy()
        for k in state_dict:
            if k not in model_keys:
                new_state_dict.pop(k)
                self.logger.info(f'Remove key {k} from checkpoint.')
        target.load_state_dict(new_state_dict)

        ## load optimizer
        if with_optimizer:
            opt_path = checkpoint_path.replace('model_epoch',
                                               'optimizer_epoch')
            self.logger.info('Loading optimizer from ' + opt_path)
            opt_dict = torch.load(opt_path)
            opt_state = opt_dict['optimizer']
            sched_state = opt_dict['lr_scheduler']
            # Older checkpoints pickled the objects themselves; take their
            # state so the live optimizer keeps pointing at self.model.
            if hasattr(opt_state, 'state_dict'):
                opt_state = opt_state.state_dict()
            if hasattr(sched_state, 'state_dict'):
                sched_state = sched_state.state_dict()
            self.opt.load_state_dict(opt_state)
            self.lr_sched.load_state_dict(sched_state)
            self.start_epoch = opt_dict['epoch']
            self.current_epoch = opt_dict['epoch']
def validate(val_loader, model, criterion):
    """Evaluate `model` on `val_loader`, optionally dumping one batch.

    When `args.dump_dir` is set, the function first calls `QM().disable()` and
    `DM(args.dump_dir)`. It then runs batch index 5 inside a `DM` context
    tagged 'batch5' and stops right after that forward pass; that batch is not
    scored. Reads `args.device`, `args.dump_dir` and `args.print_freq` from the
    module-level `args`.

    Returns:
        (average loss, average Prec@1, average Prec@5), weighted by batch size.
    """
    batch_time, losses = AverageMeter(), AverageMeter()
    top1, top5 = AverageMeter(), AverageMeter()

    model.eval()

    dumping = args.dump_dir is not None
    if dumping:
        QM().disable()
        DM(args.dump_dir)

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(args.device)
            labels = labels.to(args.device)

            if dumping and step == 5:
                with DM(args.dump_dir):
                    DM().set_tag('batch%d' % step)
                    model(images)
                break

            logits = model(images)
            batch_loss = criterion(logits, labels)

            prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
            losses.update(batch_loss.item(), images.size(0))
            top1.update(float(prec1), images.size(0))
            top5.update(float(prec5), images.size(0))

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      step, len(val_loader), batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg
def train(self, epoch, data_loader, optimizer1, optimizer2):
    """Run one training epoch that steps two optimizers on a single loss.

    Per batch:
    - parse the batch;
    - run ``self._forward``, which returns (loss, OIM precision, score precision);
    - zero both optimizers, backprop, and step both;
    - log running values to ``self.writer`` at iteration
      ``len(data_loader) * epoch + i``.

    A console line is printed every 60 batches.
    """
    self.model.train()
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()    # OIM precision
    precisions1 = AverageMeter()   # pair-score precision
    optimizers = (optimizer1, optimizer2)
    print_freq = 60
    steps_per_epoch = len(data_loader)

    last = time.time()
    for i, batch in enumerate(data_loader):
        data_time.update(time.time() - last)

        inputs, targets = self._parse_data(batch)
        loss, prec_oim, prec_score = self._forward(inputs, targets)
        n = targets.size(0)
        losses.update(loss.item(), n)
        precisions.update(prec_oim, n)
        precisions1.update(prec_score, n)

        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()

        batch_time.update(time.time() - last)
        last = time.time()

        global_iter = steps_per_epoch * epoch + i
        for tag, value in (('train/loss_step', losses.val),
                           ('train/loss_avg', losses.avg),
                           ('train/prec_pairloss', precisions1.avg),
                           ('train/prec_oimloss', precisions.avg)):
            self.writer.add_scalar(tag, value, global_iter)

        if (i + 1) % print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'prec_oim {:.2%} ({:.2%})\t'
                  'prec_score {:.2%} ({:.2%})\t'
                  .format(epoch, i + 1, steps_per_epoch,
                          losses.val, losses.avg,
                          precisions.val, precisions.avg,
                          precisions1.val, precisions1.avg))
def val_model(opt, val_loader, model, criterion, vis_tool, name1='1'):
    """Evaluate a segmentation model and report loss, Dice and Recall.

    Args:
        opt: options object; reads ``out_dim``, ``use_cuda`` and ``val_plotfreq``.
        val_loader: iterable of ``(data, target)`` batches.
        model: network whose output is arg-maxed over dim 1 to get predictions.
        criterion: loss function ``criterion(output, target)``.
        vis_tool: plotting helper exposing ``plot(name, value)``.
        name1: suffix appended to the plotted series names.

    Returns:
        ``(average loss, average Dice, average Recall)``. Dice and Recall are
        read from ``val_eval`` after every batch update. Whether those scores
        are per-batch or cumulative depends on ``ConfusionMeter``
        (NOTE(review): confirm).
    """
    model.eval()
    val_eval = ConfusionMeter(num_class=opt.out_dim)
    val_dice = AverageMeter()
    val_loss = AverageMeter()
    val_recall = AverageMeter()

    for batch_ids, (data, target) in enumerate(val_loader):
        if opt.use_cuda:
            data, target = data.cuda(), target.cuda()
        # The forward pass must be inside no_grad as well; otherwise a full
        # autograd graph is built (and held in memory) for every val batch.
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, dim=1)

        val_loss.update(loss.item())
        val_eval.update(pred, target)
        dice_value = val_eval.get_scores('Dice')
        recall_value = val_eval.get_scores('Recall')
        val_recall.update(recall_value)
        val_dice.update(dice_value)

        if batch_ids % opt.val_plotfreq == 0:
            vis_tool.plot('Val_Loss' + name1, loss.item())
            vis_tool.plot('Val_Dice' + name1, dice_value)
            vis_tool.plot('Val_Recall' + name1, recall_value)
            print(
                'Val: Batch_Num:{} Loss:{:.3f} Dice:{:.3f} Recall:{:.3f}'
                .format(batch_ids, loss.item(), dice_value, recall_value))

    # Read the average after the loop so an empty loader does not hit an
    # unbound local.
    return val_loss.avg, val_dice.avg, val_recall.avg
def __init__(self, opt, net):
    """Keep the options and network, and create named loss/accuracy meters."""
    self.opt, self.net = opt, net
    # Running meters; the label is passed straight to AverageMeter.
    self.loss = AverageMeter('loss')
    self.acc = AverageMeter('acc')
def forward(self, data_loader, num_steps=None, training=False, duplicates=1,
            average_output=False, chunk_batch=1, rec=False):
    """Run one pass over ``data_loader`` through ``self._step``; return metrics.

    Args:
        data_loader: iterable of ``(inputs, target)`` batches.
        num_steps: stop once the batch index reaches this value (None = no limit).
        training: forwarded to ``self._step``. Also enables:
            - adaptive grad-norm rescaling;
            - pruning-mask updates;
            - the early stop when ``self.update_only_th`` is set.
        duplicates: when > 1, each sample carries several versions along dim 1.
            They are flattened into the batch with ``_flatten_duplicates``.
        average_output: forwarded to ``self._step``. Requires
            ``duplicates > 1`` and a batch-first layout.
        chunk_batch: forwarded to ``self._step``. Values > 1 make the duplicate
            layout non-batch-first.
        rec: when true, store ``output[k]`` keyed by ``target[k].tolist()`` for
            each processed sample. The dict is written with ``torch.save`` to
            ``'output_embed_calib'`` at the end.

    Returns:
        dict of meter averages plus ``error1`` / ``error5`` (100 - precision).
    """
    if rec:
        output_embed = {}
    meters = {
        name: AverageMeter()
        for name in ['step', 'data', 'loss', 'prec1', 'prec5']
    }
    if training and self.grad_clip > 0:
        meters['grad'] = AverageMeter()

    batch_first = True
    if (training and isinstance(self.model, nn.DataParallel)) or chunk_batch > 1:
        batch_first = False
    if average_output:
        assert duplicates > 1 and batch_first, "duplicates must be > 1 for output averaging"

    def meter_results(meters):
        # Collapse meters to their averages and derive top-1/top-5 error.
        results = {name: meter.avg for name, meter in meters.items()}
        results['error1'] = 100. - results['prec1']
        results['error5'] = 100. - results['prec5']
        return results

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        if training and duplicates > 1 and self.adapt_grad_norm is not None \
                and i % self.adapt_grad_norm == 0:
            # self.grad_scale = mean of _grad_norm over the individual
            # duplicates / _grad_norm of the flattened batch.
            grad_mean = 0
            num = inputs.size(1)
            for j in range(num):
                grad_mean += float(
                    self._grad_norm(inputs.select(1, j), target))
            grad_mean /= num
            grad_all = float(
                self._grad_norm(
                    *_flatten_duplicates(inputs, target, batch_first)))
            self.grad_scale = grad_mean / grad_all
            logging.info('New loss scale: %s', self.grad_scale)

        # measure data loading time
        meters['data'].update(time.time() - end)
        if duplicates > 1:  # multiple versions for each sample (dim 1)
            inputs, target = _flatten_duplicates(
                inputs, target, batch_first,
                expand_target=not average_output)

        output, loss, grad = self._step(inputs, target,
                                        training=training,
                                        average_output=average_output,
                                        chunk_batch=chunk_batch)

        if rec:
            with torch.no_grad():
                # Use a dedicated index: reusing `i` here clobbered the batch
                # counter used below for pruning, logging and the break test.
                for k in range(target.shape[0]):
                    output_embed[target[k].tolist()] = output[k]

        if self.pruner is not None:
            with torch.no_grad():
                if training:
                    compression_rate = self.pruner.calc_param_masks(
                        self.model, i % self.print_freq == 0,
                        i + self.epoch * len(data_loader))
                    if i % self.print_freq == 0:
                        logging.info('Total compression ratio is: ' + str(compression_rate))
                    self.model = self.pruner.prune_layers(self.model)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        meters['loss'].update(float(loss), inputs.size(0))
        meters['prec1'].update(float(prec1), inputs.size(0))
        meters['prec5'].update(float(prec5), inputs.size(0))
        if grad is not None:
            meters['grad'].update(float(grad), inputs.size(0))

        # measure elapsed time
        meters['step'].update(time.time() - end)
        end = time.time()

        if i % self.print_freq == 0:
            report = str(
                '{phase} - Epoch: [{0}][{1}/{2}]\t'
                'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                'Loss {meters[loss].val:.7f} ({meters[loss].avg:.7f})\t'
                'Prec@1 {meters[prec1].val:.6f} ({meters[prec1].avg:.6f})\t'
                'Prec@5 {meters[prec5].val:.6f} ({meters[prec5].avg:.6f})\t'
                .format(self.epoch, i, len(data_loader),
                        phase='TRAINING' if training else 'EVALUATING',
                        meters=meters))
            if 'grad' in meters.keys():
                report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})'\
                    .format(meters=meters)
            logging.info(report)

        if (num_steps is not None and i >= num_steps) or \
                (self.update_only_th and training and i > 2):
            break

    if self.pruner is not None:
        self.pruner.save_eps(epoch=self.epoch + 1)
        self.pruner.save_masks(epoch=self.epoch + 1)
    if rec:
        torch.save(output_embed, 'output_embed_calib')
    return meter_results(meters)
def train(self, epoch, data_loader):
    """Train for one epoch using the trainer's parse/forward/backward hooks.

    Per batch:
    - ``self._parse_data`` and ``self._forward`` prepare the step;
    - the optimizer is zeroed, ``self._backward`` runs, and the optimizer steps;
    - ``self.loss`` and the current learning rate are logged to
      ``self.summary_writer``.

    A progress line is printed every ``self.opt.print_freq`` batches, and an
    epoch summary is printed at the end.
    """
    self.model.train()
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    n_batches = len(data_loader)

    tic = time.time()
    for i, batch in enumerate(data_loader):
        data_time.update(time.time() - tic)

        self._parse_data(batch)
        self._forward()
        self.optimizer.zero_grad()
        self._backward()
        self.optimizer.step()

        batch_time.update(time.time() - tic)
        loss_value = self.loss.item()
        losses.update(loss_value)

        step = epoch * n_batches + i
        self.summary_writer.add_scalar('loss', loss_value, step)
        self.summary_writer.add_scalar(
            'lr', self.optimizer.param_groups[0]['lr'], step)
        tic = time.time()

        if (i + 1) % self.opt.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Batch Time {:.3f} ({:.3f})\t'
                  'Data Time {:.3f} ({:.3f})\t'
                  'Loss {:.3f} ({:.3f})\t'.format(
                      epoch, i + 1, n_batches,
                      batch_time.val, batch_time.mean,
                      data_time.val, data_time.mean,
                      losses.val, losses.mean))

    current_lr = self.optimizer.param_groups[0]['lr']
    print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
          'Lr {:.2e}'.format(epoch, batch_time.sum, losses.mean, current_lr))
    print()
def eval(self, epoch):
    """Evaluate on ``self.val_loader`` without gradients.

    With ``self.with_attribute``, the model also predicts attributes (134 of
    them, by the assert). Attribute targets are binarised around ``self.xi``,
    and the attribute loss is added to the identity loss only when
    ``epoch > 15``.

    Returns:
        ``(correct.avg, losses.avg)``. Per-batch accuracy is recorded as a
        percentage with weight ``batch_size / 100``.
    """
    self.model.eval()
    losses = AverageMeter()
    correct = AverageMeter()
    prec1, prec2, prec5 = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        for imgs, labels, orig_attrs in self.val_loader:
            imgs = imgs.cuda()
            labels = labels.cuda()
            batch = labels.size(0)

            if not self.with_attribute:
                pred_id = self.model(imgs, orig_attrs)
                loss = self.criterion(pred_id, labels)
                assert pred_id.shape[-1] == self.num_classes
            else:
                orig_attrs = orig_attrs.cuda()
                # Binarise: values > xi become 1, then values <= xi become 0.
                attrs = orig_attrs.detach().clone()
                attrs[attrs > self.xi] = 1.
                attrs[attrs <= self.xi] = 0.
                pred_id, pred_attrs = self.model(imgs, orig_attrs)
                assert pred_attrs.shape[-1] == 134
                loss = self.criterion[0](pred_id, labels)
                loss_attrs = self.criterion[1](pred_attrs.float(), attrs.float())
                if epoch > 15:
                    loss += loss_attrs

            losses.update(loss.item(), batch)
            topk = accuracy(pred_id.data, labels.data, topk=(1, 2, 5), is_multilabel=False)
            prec1.update(topk[0].item(), batch)
            prec2.update(topk[1].item(), batch)
            prec5.update(topk[2].item(), batch)

            hits = (pred_id.argmax(dim=1) == labels).sum().item()
            correct.update(hits / batch * 100, batch / 100.)

    print('Val: [{}] '
          'Loss {:.2f} ({:.2f})\t'
          'Acc {:.2f} ({:.2f})\t'
          'Prec1 {:.2%} ({:.2%})\t'
          'Prec2 {:.2%} ({:.2%})\t'
          'Prec5 {:.2%} ({:.2%})\t'
          .format(epoch, losses.val, losses.avg,
                  correct.val, correct.avg,
                  prec1.val, prec1.avg,
                  prec2.val, prec2.avg,
                  prec5.val, prec5.avg))
    return correct.avg, losses.avg
def forward(self, data_loader, num_steps=None, training=False, duplicates=1):
    """Run one pass over ``data_loader`` through ``self._step``; return metrics.

    Each batch is moved to ``self.device`` (inputs also cast to ``self.dtype``)
    and handed to ``self._step``. Loss, Prec@1 and Prec@5 are tracked in
    ``AverageMeter`` objects; the gradient value is also tracked when
    ``self._step`` returns one.

    Args:
        data_loader: iterable yielding ``(inputs, target)`` batches.
        num_steps: if given, stop once the batch index reaches this value.
        training: forwarded to ``self._step``. Also enables:
            - the ``self.delay_hist`` reset;
            - ``self._schedule_worker`` calls;
            - tensorboard iteration logging when
              ``tb.tboard.res_iterations`` is set.
        duplicates: when > 1, ``inputs`` carries several versions of each
            sample along dim 1. They are flattened into the batch and each
            target is repeated to match.

    Returns:
        dict mapping each meter name to its average, plus ``error1`` and
        ``error5`` (``100 - prec1`` / ``100 - prec5``).
    """
    meters = {
        name: AverageMeter()
        for name in ['step', 'data', 'loss', 'prec1', 'prec5']
    }
    # The 'grad' meter only exists for training with grad_clip > 0.
    # NOTE(review): assumes self._step returns grad=None whenever this meter is
    # absent, otherwise meters['grad'] below raises KeyError.
    if training and self.grad_clip > 0:
        meters['grad'] = AverageMeter()

    def meter_results(meters):
        # Collapse meters to their averages and derive top-1 / top-5 error.
        results = {name: meter.avg for name, meter in meters.items()}
        results['error1'] = 100. - results['prec1']
        results['error5'] = 100. - results['prec5']
        return results

    end = time.time()
    # Fresh delay histogram for each training pass (presumably populated by
    # _schedule_worker -- confirm).
    if training:
        self.delay_hist = defaultdict(int)
    for i, (inputs, target) in enumerate(data_loader):
        # Global iteration index used for worker scheduling and tensorboard.
        if training:
            self._schedule_worker(self.epoch * len(data_loader) + i)
        if training and tb.tboard.res_iterations:
            tb.tboard.update_step(self.epoch * len(data_loader) + i)

        # measure data loading time
        meters['data'].update(time.time() - end)
        target = target.to(self.device)
        inputs = inputs.to(self.device, dtype=self.dtype)

        if duplicates > 1:  # multiple versions for each sample (dim 1)
            # Repeat each target once per version, then flatten the
            # (batch, version) dims of both tensors into one batch dim.
            target = target.view(-1, 1).expand(-1, inputs.size(1))
            inputs = inputs.flatten(0, 1)
            target = target.flatten(0, 1)

        output, loss, grad = self._step(inputs, target, training=training)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        meters['loss'].update(float(loss), inputs.size(0))
        meters['prec1'].update(float(prec1), inputs.size(0))
        meters['prec5'].update(float(prec5), inputs.size(0))
        if grad is not None:
            meters['grad'].update(float(grad), inputs.size(0))

        # measure elapsed time
        meters['step'].update(time.time() - end)
        # Per-iteration tensorboard log: training loss and top-1 error.
        if training and tb.tboard.res_iterations:
            tb.tboard.log_results(training_loss_iter=float(loss),
                                  training_error1_iter=100 - float(prec1),
                                  iterations=self.epoch * len(data_loader) + i)
        end = time.time()

        if i % self.print_freq == 0:
            # The report shows errors (100 - precision), not precisions.
            errors = {
                'error1_val': 100 - meters['prec1'].val,
                'error5_val': 100 - meters['prec5'].val,
                'error1_avg': 100 - meters['prec1'].avg,
                'error5_avg': 100 - meters['prec5'].avg
            }
            report = str(
                '{phase} - Epoch: [{0}][{1}/{2}]\t'
                'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                'Loss {meters[loss].val:.4f} ({meters[loss].avg:.4f})\t'
                'Error@1 {errors[error1_val]:.3f} ({errors[error1_avg]:.3f})\t'
                'Error@5 {errors[error5_val]:.3f} ({errors[error5_avg]:.3f})\t'
                .format(
                    self.epoch,
                    i, len(data_loader),
                    phase='TRAINING' if training else 'EVALUATING',
                    meters=meters, errors=errors))
            if 'grad' in meters.keys():
                report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})' \
                    .format(meters=meters)
            logging.info(report)

        if num_steps is not None and i >= num_steps:
            break

    return meter_results(meters)
def train(self, epoch):
    """Train for one epoch over ``self.train_loader``.

    With ``self.with_attribute``, the model also predicts attributes (134 of
    them, by the assert). Attribute targets are binarised around ``self.xi``,
    and the attribute loss is added to the identity loss only when
    ``epoch > 15``.

    Running loss, accuracy and top-1/2/5 precision go to
    ``self.summary_writer`` when one is set. A progress line is printed every
    10 steps.

    Returns:
        ``(correct.avg, losses.avg)``. Per-batch accuracy is recorded as a
        percentage with weight ``batch_size / 100``.
    """
    self.model.train()
    correct = AverageMeter()
    losses = AverageMeter()
    prec1 = AverageMeter()
    prec2 = AverageMeter()
    prec5 = AverageMeter()

    for step, (imgs, labels, orig_attrs) in enumerate(self.train_loader):
        imgs, labels = imgs.cuda(), labels.cuda()

        if self.with_attribute:
            orig_attrs = orig_attrs.cuda()
            # Binarise: values > xi become 1, then values <= xi become 0.
            attrs = orig_attrs.detach().clone()
            attrs[attrs > self.xi] = 1.
            attrs[attrs <= self.xi] = 0.
            pred_id, pred_attrs = self.model(imgs, orig_attrs)
            assert pred_attrs.shape[-1] == 134
            loss = self.criterion[0](pred_id, labels)
            loss_attrs = self.criterion[1](pred_attrs.float(), attrs.float())
            # The attribute loss only contributes after epoch 15.
            if epoch > 15:
                loss += loss_attrs
        else:
            pred_id = self.model(imgs, orig_attrs)
            assert pred_id.shape[-1] == self.num_classes
            loss = self.criterion(pred_id, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        prec = accuracy(pred_id.data, labels.data, topk=(1, 2, 5))
        losses.update(loss.item(), labels.size(0))
        prec1.update(prec[0].item(), labels.size(0))
        prec2.update(prec[1].item(), labels.size(0))
        prec5.update(prec[2].item(), labels.size(0))

        y_pred = pred_id.argmax(dim=1)
        acc = (y_pred == labels).sum().item() / labels.size(0) * 100
        correct.update(acc, labels.size(0) / 100.)

        # tensorboard
        if self.summary_writer is not None:
            global_step = epoch * len(self.train_loader) + step
            self.summary_writer.add_scalar('train_loss', loss.item(), global_step)
            self.summary_writer.add_scalar('train_acc', 1. * correct.avg, global_step)
            self.summary_writer.add_scalar('prec1', prec1.avg, global_step)
            self.summary_writer.add_scalar('prec2', prec2.avg, global_step)
            self.summary_writer.add_scalar('prec5', prec5.avg, global_step)

        if (step + 1) % 10 == 0:
            # Previously prec2.avg was printed in Prec1's average slot.
            print('[{}] '
                  'Loss {:.3f} ({:.3f})\t'
                  'Acc {:.2f} ({:.2f})\t'
                  'Prec1 {:.2%} ({:.2%})\t'
                  'Prec2 {:.2%} ({:.2%})\t'
                  'Prec5 {:.2%} ({:.2%})\t'
                  .format(step + 1, losses.val, losses.avg,
                          correct.val, correct.avg,
                          prec1.val, prec1.avg,
                          prec2.val, prec2.avg,
                          prec5.val, prec5.avg))

    return correct.avg, losses.avg
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one pass over ``data_loader``; optimise when ``training`` is true.

    If the model has a ``regularization`` attribute, its value for the
    (possibly DataParallel-wrapped) model is added to the loss. With several
    ``args.device_ids`` the model is wrapped in ``torch.nn.DataParallel`` for
    this pass.

    Returns:
        ``(average loss, average Prec@1, average Prec@5)``.
    """
    # Read the regulariser from the bare model, before any wrapping.
    regularizer = getattr(model, 'regularization', None)
    if args.device_ids and len(args.device_ids) > 1:
        model = torch.nn.DataParallel(model, args.device_ids)

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    n_batches = len(data_loader)

    last = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        data_time.update(time.time() - last)
        target = target.to(args.device)
        inputs = inputs.to(args.device, dtype=dtype)

        output = model(inputs)
        loss = criterion(output, target)
        if regularizer is not None:
            loss += regularizer(model)
        # Multi-output models: precision is measured on the first output only.
        if type(output) is list:
            output = output[0]

        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        batch_size = inputs.size(0)
        losses.update(float(loss), batch_size)
        top1.update(float(prec1), batch_size)
        top5.update(float(prec5), batch_size)

        if training:
            optimizer.update(epoch, epoch * n_batches + i)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_time.update(time.time() - last)
        last = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, n_batches,
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses,
                             top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
class Trainer(BaseTrainer):
    """Trainer for a network that outputs a mask and a label for each image.

    Loss and accuracy are tracked per epoch with writer-backed ``AverageMeter``
    metrics. After every epoch the model is validated, and the checkpoint is
    overwritten whenever validation accuracy improves.
    """

    def __init__(self, cfg, network, optimizer, loss, lr_scheduler, device,
                 trainloader, testloader, writer):
        super(Trainer, self).__init__(cfg, network, optimizer, loss, lr_scheduler,
                                      device, trainloader, testloader, writer)
        self.network = self.network.to(device)
        self.train_loss_metric = AverageMeter(writer=writer, name='Loss/train',
                                              length=len(self.trainloader))
        self.train_acc_metric = AverageMeter(writer=writer, name='Accuracy/train',
                                             length=len(self.trainloader))
        self.val_loss_metric = AverageMeter(writer=writer, name='Loss/val',
                                            length=len(self.testloader))
        self.val_acc_metric = AverageMeter(writer=writer, name='Accuracy/val',
                                           length=len(self.testloader))
        self.best_val_acc = 0

    def _checkpoint_path(self):
        """Return ``<output_dir>/<model base>_<dataset name>.pth``."""
        return os.path.join(
            self.cfg['output_dir'],
            '{}_{}.pth'.format(self.cfg['model']['base'],
                               self.cfg['dataset']['name']))

    def load_model(self):
        """Restore network and optimizer state from the file ``save_model`` writes."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on this trainer's device.
        state = torch.load(self._checkpoint_path(), map_location=self.device)
        self.optimizer.load_state_dict(state['optimizer'])
        self.network.load_state_dict(state['state_dict'])

    def save_model(self, epoch):
        """Write epoch, network and optimizer state to the checkpoint file."""
        os.makedirs(self.cfg['output_dir'], exist_ok=True)
        state = {
            'epoch': epoch,
            'state_dict': self.network.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        torch.save(state, self._checkpoint_path())

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``self.trainloader``."""
        self.network.train()
        self.train_loss_metric.reset(epoch)
        self.train_acc_metric.reset(epoch)
        score_type = self.cfg['test']['score_type']

        for i, (img, mask, label) in enumerate(self.trainloader):
            img, mask, label = img.to(self.device), mask.to(
                self.device), label.to(self.device)
            net_mask, net_label = self.network(img)
            self.optimizer.zero_grad()
            loss = self.loss(net_mask, net_label, mask, label)
            loss.backward()
            self.optimizer.step()

            # Calculate predictions
            preds = predict(net_mask, net_label, score_type=score_type)
            targets = predict(mask, label, score_type=score_type)
            acc = calc_acc(preds, targets)

            # Update metrics
            self.train_loss_metric.update(loss.item())
            self.train_acc_metric.update(acc)

            print('Epoch: {}, iter: {}, loss: {}, acc: {}'.format(
                epoch, epoch * len(self.trainloader) + i,
                self.train_loss_metric.avg, self.train_acc_metric.avg))

    def train(self):
        """Train for ``cfg['train']['num_epochs']`` epochs, keeping the best model."""
        for epoch in range(self.cfg['train']['num_epochs']):
            self.train_one_epoch(epoch)
            epoch_acc = self.validate(epoch)
            if epoch_acc > self.best_val_acc:
                self.best_val_acc = epoch_acc
                self.save_model(epoch)

    def validate(self, epoch):
        """Evaluate on ``self.testloader`` and return the average accuracy.

        One randomly chosen batch is sent to tensorboard via ``add_images_tb``.
        """
        self.network.eval()
        self.val_loss_metric.reset(epoch)
        self.val_acc_metric.reset(epoch)
        # max(..., 0) keeps randint valid when the test loader is empty.
        seed = randint(0, max(len(self.testloader) - 1, 0))
        score_type = self.cfg['test']['score_type']

        # No autograd graph is needed for validation.
        with torch.no_grad():
            for i, (img, mask, label) in enumerate(self.testloader):
                img, mask, label = img.to(self.device), mask.to(
                    self.device), label.to(self.device)
                net_mask, net_label = self.network(img)
                loss = self.loss(net_mask, net_label, mask, label)

                # Calculate predictions
                preds = predict(net_mask, net_label, score_type=score_type)
                targets = predict(mask, label, score_type=score_type)
                acc = calc_acc(preds, targets)

                # Update metrics
                self.val_loss_metric.update(loss.item())
                self.val_acc_metric.update(acc)

                if i == seed:
                    add_images_tb(self.cfg, epoch, img, preds, targets, self.writer)

        return self.val_acc_metric.avg
def train(epoch, train_loader, model, criterion, optimizers, summary_writer):
    """Train ``model`` for one epoch with ``criterion`` and several optimizers.

    After backprop, the gradients of ``center_criterion``'s parameters are
    multiplied by ``1 / cfg.TRAIN.CENTER_LOSS_WEIGHT``. Presumably this makes
    the centers' update independent of the loss weight; confirm.

    The loss is logged to ``summary_writer`` when one is given. Progress is
    logged every ``cfg.TRAIN.PRINT_FREQ`` batches, and an epoch summary
    (including the first optimizer's learning rate) is logged at the end.
    """
    global center_criterion

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    snapshot_dir = cfg.TRAIN.SNAPSHOT_DIR
    if not os.path.exists(snapshot_dir):
        os.makedirs(snapshot_dir)

    model.train()
    n_batches = len(train_loader)
    tic = time.time()
    for ii, (img, bag_id, cam_id) in enumerate(train_loader):
        data_time.update(time.time() - tic)
        if cfg.CUDA:
            img, bag_id = img.cuda(), bag_id.cuda()

        triplet_features, softmax_features = model(img)
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss = criterion(softmax_features, triplet_features, bag_id)
        loss.backward()
        for param in center_criterion.parameters():
            param.grad.data *= (1. / cfg.TRAIN.CENTER_LOSS_WEIGHT)
        for optimizer in optimizers:
            optimizer.step()

        batch_time.update(time.time() - tic)
        loss_value = loss.item()
        losses.update(loss_value)
        if summary_writer:
            summary_writer.add_scalar('loss', loss_value, epoch * n_batches + ii)
        tic = time.time()

        if (ii + 1) % cfg.TRAIN.PRINT_FREQ == 0:
            logger.info('Epoch: [{}][{}/{}]\t'
                        'Batch Time {:.3f} ({:.3f})\t'
                        'Data Time {:.3f} ({:.3f})\t'
                        'Loss {:.3f} ({:.3f}) \t'.format(
                            epoch + 1, ii + 1, n_batches,
                            batch_time.val, batch_time.mean,
                            data_time.val, data_time.mean,
                            losses.val, losses.mean))

    first_lr = optimizers[0].param_groups[0]['lr']
    logger.info('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t'
                'Adam Lr {:.2e} \t '.format(epoch + 1, batch_time.sum, losses.mean, first_lr))
class ReidSystem():
    """Person re-identification training loop.

    The model is trained with cross-entropy plus triplet loss. Evaluation uses
    rank-1 / mAP over query and gallery features, and the best-mAP checkpoint
    is kept as ``model_best.pth``.
    """

    def __init__(self, cfg, logger, writer):
        self.cfg, self.logger, self.writer = cfg, logger, writer
        # Define dataloader
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(
            cfg)
        # networks
        self.model = build_model(cfg, self.num_classes)
        # loss function
        self.ce_loss = nn.CrossEntropyLoss()
        self.triplet = TripletLoss(cfg.SOLVER.MARGIN)
        # optimizer and scheduler
        self.opt = make_optimizer(self.cfg, self.model)
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

        self._construct()

    def _construct(self):
        """Initialise counters and read schedule settings from ``cfg.SOLVER``."""
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        self.use_dp = False
        self.use_ddp = False

    def loss_fns(self, outputs, labels):
        """Return cross-entropy on ``outputs[0]`` and triplet loss on ``outputs[1]``."""
        ce_loss = self.ce_loss(outputs[0], labels)
        triplet_loss = self.triplet(outputs[1], labels)[0]
        return {'ce_loss': ce_loss, 'triplet': triplet_loss}

    def on_train_begin(self):
        """Set up checkpoint directory, optional DataParallel and CUDA placement."""
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.DATASETS.TEST_NAMES,
                                    self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        os.makedirs(self.model_save_dir, exist_ok=True)

        # CUDA_VISIBLE_DEVICES may be unset (all devices visible) or empty
        # (no devices).
        # Previously an unset variable raised KeyError, and '' produced [''],
        # whose length is 1, so DataParallel was enabled with no GPUs.
        visible = os.environ.get('CUDA_VISIBLE_DEVICES')
        if visible is None:
            self.gpus = [str(idx) for idx in range(torch.cuda.device_count())]
        else:
            self.gpus = [g for g in visible.split(',') if g.strip()]
        self.use_dp = (len(self.gpus) > 0) and (self.cfg.MODEL.DIST_BACKEND == 'dp')

        if self.use_dp:
            self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.model.train()

    def on_epoch_begin(self):
        """Advance the epoch counter and start a fresh training prefetcher."""
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        self.running_loss.reset()
        self.tng_prefetcher = data_prefetcher(self.tng_dataloader)

    def training_step(self, batch):
        """Run a forward/backward/optimizer step on one ``(inputs, labels, _)`` batch."""
        inputs, labels, _ = batch
        outputs = self.model(inputs, labels)
        loss_dict = self.loss_fns(outputs, labels)

        total_loss = 0
        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_dataloader)} '
        for loss_name, loss_value in loss_dict.items():
            total_loss += loss_value
            print_str += (loss_name + f': {loss_value.item():.3f} ')
        loss_dict['total_loss'] = total_loss.item()
        print_str += f'Total loss: {total_loss.item():.3f} '
        print(print_str, end=' ')

        if (self.global_step + 1) % self.log_interval == 0:
            self.writer.add_scalar('cross_entropy_loss', loss_dict['ce_loss'], self.global_step)
            self.writer.add_scalar('triplet_loss', loss_dict['triplet'], self.global_step)
            self.writer.add_scalar('total_loss', loss_dict['total_loss'], self.global_step)

        self.running_loss.update(total_loss.item())

        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        """Log epoch loss, learning rate and duration, then step the LR scheduler."""
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        print('')
        self.logger.info(
            f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} '
            f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s'
        )
        # update learning rate
        self.lr_sched.step()

    def test(self):
        """Extract features for the validation set and report mAP / CMC.

        The first ``self.num_query`` samples are queries and the rest are
        gallery. The query-gallery distance is the negated dot product of the
        features, which are L2-normalised first when ``cfg.TEST.NORM`` is set.
        """
        # convert to eval mode
        self.model.eval()
        feats, pids, camids = [], [], []
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = self.model(img)
            feats.append(feat)
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))
            batch = val_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(pids[:self.num_query])
        q_camids = np.asarray(camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        # Similarity matrix; negated below so larger similarity = smaller distance.
        distmat = torch.mm(qf, gf.t()).cpu().numpy()
        cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
        self.writer.add_scalar('rank1', cmc[0], self.global_step)
        self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def train(self):
        """Full training loop; evaluate and checkpoint every ``eval_period`` epochs."""
        self.on_train_begin()
        for epoch in range(self.max_epochs):
            self.on_epoch_begin()
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            if (epoch + 1) % self.eval_period == 0:
                metric_dict = self.test()
                if metric_dict['mAP'] > self.best_mAP:
                    is_best = True
                    self.best_mAP = metric_dict['mAP']
                else:
                    is_best = False
                self.save_checkpoints(is_best)
            torch.cuda.empty_cache()

    def save_checkpoints(self, is_best):
        """Save the (unwrapped) model weights and copy them to ``model_best.pth`` if best."""
        if self.use_dp:
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        # TODO: add optimizer state dict and lr scheduler
        filepath = os.path.join(self.model_save_dir, f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)
        if is_best:
            best_filepath = os.path.join(self.model_save_dir, 'model_best.pth')
            shutil.copyfile(filepath, best_filepath)
def train(train_loader, model, criterion, optimizer, epoch, device):
    """Train ``model`` for one epoch, tracking loss and top-1..top-4 precision.

    Targets are cast to float. Inputs with more than 4 dims have dims 0 and 1
    merged into the batch dim, and targets are reshaped to match. Loss and
    precision are computed on CPU copies of the output and target.
    """
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    topk = [AverageMeter() for _ in range(4)]

    model.train()
    last = time.time()
    for i, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - last)

        labels = labels.float()
        if images.dim() > 4:
            # Fold the group dimension (dim 1) into the batch.
            images = images.reshape(images.shape[0] * images.shape[1],
                                    *images.shape[2:5])
            labels = labels.reshape(labels.shape[0] * labels.shape[1],
                                    labels.shape[2])

        images = images.to(device)
        labels = labels.to(device)

        output = model(images).cpu()
        labels = labels.cpu()
        loss = criterion(output, labels)

        prec = accuracy(output, labels, topk=4)
        n = images.size(0)
        for k, meter in enumerate(topk):
            meter.update(prec[k], n)
        losses.update(loss.item(), n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - last)
        last = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\n'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@2 {top2.val:.3f} ({top2.avg:.3f})\t'
                  'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
                  'Prec@4 {top4.val:.3f} ({top4.avg:.3f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses,
                      top1=topk[0], top2=topk[1], top3=topk[2], top4=topk[3]))
def forward(self, data_loader, num_steps=None, training=False, average_output=False, chunk_batch=1):
    """Run one pass over ``data_loader`` through ``self._step``; return metrics.

    Duplicates are detected per batch from the input rank. When
    ``inputs.dim() > 4``, the input is treated as B x D x C x H x W and
    flattened with ``_flatten_duplicates``.

    Progress is logged every ``self.print_freq`` batches and on the last
    batch. Each report also calls ``self.observe`` and ``self.stream_meters``,
    plus ``self.write_stream('lr', ...)`` while training.

    Returns:
        dict of meter averages plus ``error1`` / ``error5`` (100 - precision).
    """
    meters = {
        name: AverageMeter()
        for name in ['step', 'data', 'loss', 'prec1', 'prec5']
    }
    # The 'grad' meter only exists for training with grad_clip > 0.
    if training and self.grad_clip > 0:
        meters['grad'] = AverageMeter()

    # Layout flag passed to _flatten_duplicates: False when training a
    # DataParallel model or when chunk_batch > 1.
    batch_first = True
    if training and isinstance(self.model, nn.DataParallel) or chunk_batch > 1:
        batch_first = False

    def meter_results(meters):
        # Collapse meters to their averages and derive top-1/top-5 error.
        results = {name: meter.avg for name, meter in meters.items()}
        results['error1'] = 100. - results['prec1']
        results['error5'] = 100. - results['prec5']
        return results

    end = time.time()

    for i, (inputs, target) in enumerate(data_loader):
        duplicates = inputs.dim() > 4  # B x D x C x H x W
        # Every adapt_grad_norm training batches, set self.grad_scale to:
        #   mean of _grad_norm over the individual duplicates
        #   divided by _grad_norm of the flattened batch.
        if training and duplicates and self.adapt_grad_norm is not None \
                and i % self.adapt_grad_norm == 0:
            grad_mean = 0
            num = inputs.size(1)
            for j in range(num):
                grad_mean += float(
                    self._grad_norm(inputs.select(1, j), target))
            grad_mean /= num
            grad_all = float(
                self._grad_norm(
                    *_flatten_duplicates(inputs, target, batch_first)))
            self.grad_scale = grad_mean / grad_all
            logging.info('New loss scale: %s', self.grad_scale)

        # measure data loading time
        meters['data'].update(time.time() - end)
        if duplicates:  # multiple versions for each sample (dim 1)
            inputs, target = _flatten_duplicates(
                inputs, target, batch_first, expand_target=not average_output)

        output, loss, grad = self._step(inputs, target,
                                        training=training,
                                        average_output=average_output,
                                        chunk_batch=chunk_batch)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        meters['loss'].update(float(loss), inputs.size(0))
        meters['prec1'].update(float(prec1), inputs.size(0))
        meters['prec5'].update(float(prec5), inputs.size(0))
        if grad is not None:
            meters['grad'].update(float(grad), inputs.size(0))

        # measure elapsed time
        meters['step'].update(time.time() - end)
        end = time.time()

        # Report every print_freq batches and always on the final batch.
        if i % self.print_freq == 0 or i == len(data_loader) - 1:
            report = str(
                '{phase} - Epoch: [{0}][{1}/{2}]\t'
                'Time {meters[step].val:.3f} ({meters[step].avg:.3f})\t'
                'Data {meters[data].val:.3f} ({meters[data].avg:.3f})\t'
                'Loss {meters[loss].val:.4f} ({meters[loss].avg:.4f})\t'
                'Prec@1 {meters[prec1].val:.3f} ({meters[prec1].avg:.3f})\t'
                'Prec@5 {meters[prec5].val:.3f} ({meters[prec5].avg:.3f})\t'
                .format(
                    self.epoch,
                    i, len(data_loader),
                    phase='TRAINING' if training else 'EVALUATING',
                    meters=meters))
            if 'grad' in meters.keys():
                report += 'Grad {meters[grad].val:.3f} ({meters[grad].avg:.3f})'\
                    .format(meters=meters)
            logging.info(report)
            self.observe(trainer=self,
                         model=self._model,
                         optimizer=self.optimizer,
                         data=(inputs, target))
            self.stream_meters(meters,
                               prefix='train' if training else 'eval')
            if training:
                self.write_stream(
                    'lr', (self.training_steps, self.optimizer.get_lr()[0]))

        if num_steps is not None and i >= num_steps:
            break

    return meter_results(meters)
def train(self, epoch, data_loader, optimizer1, optimizer2):
    """Run one training epoch that drives two optimizers with a single loss.

    ``self._forward`` returns (loss, OIM precision, score precision). A
    progress line is printed every 50 batches.
    """
    self.model.train()
    meters = {key: AverageMeter()
              for key in ('batch_time', 'data_time', 'loss', 'prec_oim', 'prec_score')}
    optimizers = (optimizer1, optimizer2)
    print_freq = 50

    last = time.time()
    for i, batch in enumerate(data_loader):
        meters['data_time'].update(time.time() - last)

        inputs, targets = self._parse_data(batch)
        loss, prec_oim, prec_score = self._forward(inputs, targets)
        n = targets.size(0)
        meters['loss'].update(loss.item(), n)
        meters['prec_oim'].update(prec_oim, n)
        meters['prec_score'].update(prec_score, n)

        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()

        meters['batch_time'].update(time.time() - last)
        last = time.time()

        if (i + 1) % print_freq == 0:
            loss_m, oim_m, score_m = meters['loss'], meters['prec_oim'], meters['prec_score']
            print('Epoch: [{}][{}/{}]\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'prec_oim {:.2%} ({:.2%})\t'
                  'prec_score {:.2%} ({:.2%})\t'
                  .format(epoch, i + 1, len(data_loader),
                          loss_m.val, loss_m.avg,
                          oim_m.val, oim_m.avg,
                          score_m.val, score_m.avg))
def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
    """Score every query tracklet against every gallery tracklet, then rank.

    Query features come from ``self.extract_feature(query_loader)``. Gallery
    samples are streamed from ``gallery_loader`` and their ``cnn_model``
    outputs are buffered. As soon as ``galleryinfo.tranum[k]`` rows are
    available, tracklet ``k`` is pooled by ``self.att_model.selfpooling_model``
    and scored against all query rows by ``self.classifier_model``. The scores
    are then reduced to one entry per query tracklet in the distance matrix.

    Args:
        query_loader: loader yielding ``(imgs, flows, _, _)`` for the queries.
        gallery_loader: loader yielding ``(imgs, flows, _, _)`` for the
            gallery. Its ``batch_size`` is used to pad a short batch.
        queryinfo: object exposing ``pid``, ``camid`` and ``tranum`` for the
            queries. ``tranum`` gives the number of rows per query tracklet.
        galleryinfo: the same for the gallery.

    Returns:
        The result of ``evaluate_seq`` on the
        ``(len(queryinfo.pid), len(galleryinfo.pid))`` distance matrix.
    """
    self.cnn_model.eval()
    self.att_model.eval()
    self.classifier_model.eval()

    querypid = queryinfo.pid
    querycamid = queryinfo.camid
    querytranum = queryinfo.tranum
    gallerypid = galleryinfo.pid
    gallerycamid = galleryinfo.camid
    gallerytranum = galleryinfo.tranum

    # Only the pooled query features are used below.
    pooled_probe, _ = self.extract_feature(query_loader)
    querylen = len(querypid)
    gallerylen = len(gallerypid)

    # online gallery extraction
    single_distmat = np.zeros((querylen, gallerylen))
    gallery_resize = 0  # buffered rows not yet assigned to a tracklet
    gallery_popindex = 0  # index of the next gallery tracklet to score
    gallery_popsize = gallerytranum[gallery_popindex]  # rows in that tracklet
    gallery_resfeatures = None  # buffered cnn_model outputs; None = empty
    gallery_resraw = None
    preimgs = None
    preflows = None
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Evaluation needs no autograd graph, including the pooling and the
    # classifier scoring.
    with torch.no_grad():
        for i, (imgs, flows, _, _) in enumerate(gallery_loader):
            imgs = to_torch(imgs).to(device)
            flows = to_torch(flows).to(device)
            seqnum = imgs.size(0)
            if i == 0:
                # The first batch doubles as padding for a later short batch.
                preimgs = imgs
                preflows = flows

            if gallery_resfeatures is None:
                # NOTE(review): a short batch that arrives while the buffer is
                # empty is passed to cnn_model unpadded; confirm the model
                # accepts batches smaller than batch_size here.
                gallery_resfeatures, gallery_resraw = self.cnn_model(imgs, flows, self.mode)
            elif seqnum < gallery_loader.batch_size:
                # Pad up to batch_size with samples of the first batch, then
                # keep only the outputs of the real samples.
                cat_batchsize = gallery_loader.batch_size - seqnum
                padded_imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                padded_flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
                out_feat, out_raw = self.cnn_model(padded_imgs, padded_flows, self.mode)
                gallery_resfeatures = torch.cat((gallery_resfeatures, out_feat[0:seqnum]), 0)
                gallery_resraw = torch.cat((gallery_resraw, out_raw[0:seqnum]), 0)
            else:
                out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
                gallery_resfeatures = torch.cat((gallery_resfeatures, out_feat), 0)
                gallery_resraw = torch.cat((gallery_resraw, out_raw), 0)
            gallery_resize += seqnum

            # Score every gallery tracklet whose rows are now fully buffered.
            while gallery_popsize <= gallery_resize:
                if (gallery_popindex + 1) % 50 == 0:
                    print('gallery--{:04d}'.format(gallery_popindex))
                gallery_popfeatures = gallery_resfeatures[0:gallery_popsize, :]
                gallery_popraw = gallery_resraw[0:gallery_popsize, :]
                if gallery_popsize < gallery_resize:
                    gallery_resfeatures = gallery_resfeatures[gallery_popsize:gallery_resize, :]
                    gallery_resraw = gallery_resraw[gallery_popsize:gallery_resize, :]
                else:
                    gallery_resfeatures = None
                    gallery_resraw = None
                gallery_resize -= gallery_popsize

                pooled_gallery, _ = self.att_model.selfpooling_model(
                    gallery_popfeatures, gallery_popraw)
                gallerysize = pooled_gallery.size()
                # Add a leading probe axis and broadcast the gallery over it.
                pooled_gallery = pooled_gallery.unsqueeze(0).expand(
                    pooled_probe.size(0), gallerysize[0], gallerysize[1])
                encode_scores = self.classifier_model(pooled_probe, pooled_gallery)
                encode_size = encode_scores.size()
                # Softmax over each pair of logits (explicit dim; this is the
                # same axis the implicit default picks for a 2-D input).
                encodemat = F.softmax(encode_scores.view(-1, 2), dim=1)
                encodemat = encodemat.view(encode_size[0], encode_size[1], 2)
                # Channel 0 is used as the distance. Copy it to the host once
                # per gallery tracklet instead of once per query.
                distmat_qall_g = encodemat[:, :, 0].detach().cpu().numpy()

                q_start = 0
                for qind, qnum in enumerate(querytranum):
                    distmat_qg = distmat_qall_g[q_start:q_start + qnum, :]
                    percile = np.percentile(distmat_qg, 20)
                    lowest = distmat_qg[distmat_qg <= percile]
                    # Mean of the entries at or below the 20th percentile.
                    # Fall back to the full mean if the mask selects nothing,
                    # which can only happen with NaN scores.
                    if lowest.size > 0:
                        distmean = np.mean(lowest)
                    else:
                        distmean = np.mean(distmat_qg)
                    single_distmat[qind, gallery_popindex] = distmean
                    q_start += qnum

                gallery_popindex += 1
                if gallery_popindex < gallerylen:
                    gallery_popsize = gallerytranum[gallery_popindex]

    return evaluate_seq(single_distmat, querypid, querycamid, gallerypid, gallerycamid)