def __init__(self, data_path, download=False, **kwargs):
    """Build the WildTrackCrop train sample list from *data_path*.

    Layout: one sub-folder per person id, each image file named
    ``<cam_id>_...`` (.png or .jpg).

    Args:
        data_path: root directory of the cropped dataset.
        download: when True and *data_path* is missing, fetch the
            dataset from ``self.dataset_url`` before loading.

    Raises:
        FileNotFoundError: *data_path* does not exist and *download*
            is False.
    """
    # If the dataset path is missing, download it or fail loudly.
    # (The original `return '...'` was a bug: returning a non-None
    # value from __init__ raises TypeError at runtime, and the
    # message was never .format()-ed.)
    if not os.path.exists(data_path):
        if download:
            print('dataset path {} is not existed, start download dataset'.format(data_path))
            self.download_dataset(data_path, self.dataset_url)
        else:
            raise FileNotFoundError(
                'dataset path {} is not existed'.format(data_path))

    # Load samples: folder name is the person id, file-name prefix
    # (before the first '_') is the camera id.
    root1, folders, _ = os_walk(data_path)
    samples = []
    for folder in folders:
        pid = int(folder)
        root2, _, files = os_walk(osp.join(root1, folder))
        files = [file for file in files
                 if '.png' in file or '.jpg' in file]
        for file in files:
            cam_id = int(file.split('_')[0])
            img_path = osp.join(root2, file)
            samples.append([img_path, pid, cam_id])

    # WildTrack crops are train-only: no query/gallery split.
    train = samples
    query = None
    gallery = None
    super(WildTrackCrop, self).__init__(train, query, gallery)
def _load_querygallery_images_path(self, path, idstartfrom=0): txt_path = join(path, 'list.txt') samples = [] if os.path.exists(txt_path): txt_file = open(txt_path, 'r') lines = txt_file.readlines() for line in lines: img_path, pid, cid = line.split(',') pid = int(pid) cid = int(cid) samples.append( [img_path, pid + idstartfrom, cid + idstartfrom]) txt_file.close() else: txt_file = open(txt_path, 'w') _, _, files = os_walk(path) for file in files: if '.jpg' in file: pid = int(file.split('_')[0]) cid = int(file.split('_')[1][5:7]) txt_file.writelines('{path},{pid},{cid}\n'.format( path=join(path, file), pid=pid, cid=cid)) samples.append([ join(path, file), pid + idstartfrom, cid + idstartfrom ]) txt_file.close() return samples
def save_model(self, save_epoch):
    """Persist checkpoints for epoch *save_epoch*.

    Writes the state_dict to results_dir/model_{epoch}.pth, prunes all
    but the newest epoch checkpoint, then dumps the full model
    (architecture + weights) to results_dir/final_model.pth.tar so it
    may later serve as a teacher.
    """
    # 1) state_dict-only checkpoint for this epoch.
    torch.save(self.model.state_dict(),
               os.path.join(self.results_dir,
                            'model_{}.pth'.format(save_epoch)))

    # 2) keep only the most recent epoch checkpoint on disk.
    root, _, files = os_walk(self.results_dir)
    checkpoints = [name for name in files
                   if '.pth' in name and name != 'final_model.pth.tar']
    if len(checkpoints) > 1:
        oldest = min(int(name.replace('.pth', '').split('_')[1])
                     for name in checkpoints)
        os.remove(os.path.join(root, 'model_{}.pth'.format(oldest)))

    # 3) full model dump (architecture and state_dict together).
    torch.save(self.model,
               os.path.join(self.results_dir, 'final_model.pth.tar'))
def resume_latest_model(self):
    """Restore weights from the newest checkpoint in self.results_dir.

    Returns:
        The epoch number of the restored model_{epoch}.pth checkpoint,
        or None when no such checkpoint exists.
    """
    root, _, files = os_walk(self.results_dir)
    epochs = [int(name.replace('.pth', '').split('_')[1])
              for name in files
              if '.pth' in name and name != 'final_model.pth.tar']
    if not epochs:
        return None
    latest = max(epochs)
    model_path = os.path.join(root, 'model_{}.pth'.format(latest))
    self.model.load_state_dict(torch.load(model_path), strict=True)
    self.logging(time_now(), 'restore from {}'.format(model_path))
    return latest
def _load_images_path(self, path, idstartfrom=0): txt_path = join(path, 'list.txt') samples = [] if os.path.exists(txt_path): txt_file = open(txt_path, 'r') lines = txt_file.readlines() for line in lines: img_path, pid, cid = line.split(',') pid = int(pid) cid = int(cid) samples.append( [img_path, pid + idstartfrom, cid + idstartfrom]) txt_file.close() else: txt_file = open(txt_path, 'w') _, folders, _ = os_walk(path) samples = [] for folder in folders: _, _, files = os_walk(os.path.join(path, folder)) for file in files: if '.jpg' in file or '.png' in file: pid = int(folder) try: cid = int(file[5:7]) except: continue txt_file.writelines('{path},{pid},{cid}\n'.format( path=join(join(path, folder), file), pid=pid, cid=cid)) samples.append([ os.path.join(os.path.join(path, folder), file), pid + idstartfrom, cid + idstartfrom ]) txt_file.close() return samples