def main():
    model = AFB_URR(device, update_bank=True, load_imagenet_params=False)
    model = model.to(device)
    model.eval()

    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            end_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'], strict=False)
            train_loss = checkpoint['loss']
            seed = checkpoint['seed']
            print(myutils.gct(),
                  f'Loaded checkpoint {args.resume}. '
                  f'(end_epoch: {end_epoch}, train_loss: {train_loss}, seed: {seed})')
        else:
            print(myutils.gct(), f'No checkpoint found at {args.resume}')
            raise IOError(f'No checkpoint found at {args.resume}')

    # Pick the benchmark by level: 1 = DAVIS17-val, 2 = YouTube-VOS, 3 = long video.
    if args.level == 1:
        model_name = 'AFB-URR_DAVIS_17val'
        dataset = DAVIS_Test_DS(args.dataset, '2017/val.txt')
    elif args.level == 2:
        model_name = 'AFB-URR_YoutubeVOS'
        dataset = YouTube_Test_DS(args.dataset)
    elif args.level == 3:
        model_name = 'AFB-URR_LongVideo'
        dataset = LongVideo_Test_DS(args.dataset, 'val.txt')
    else:
        raise ValueError(f'Level {args.level} is unknown.')

    if args.prefix:
        model_name += f'_{args.prefix}'

    dataloader = utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    print(myutils.gct(), f'Model name: {model_name}')

    if args.level == 1:
        eval_DAVIS(model, model_name, dataloader)
    elif args.level == 2:
        eval_YouTube(model, model_name, dataloader)
    elif args.level == 3:
        eval_LongVideo(model, model_name, dataloader)
def __init__(self, root, output_size, dataset_file='./assets/pretrain.txt', clip_n=3, max_obj_n=11):
    self.root = root
    self.clip_n = clip_n
    self.output_size = output_size
    self.max_obj_n = max_obj_n

    self.img_list = list()
    self.mask_list = list()
    dataset_list = list()

    # Collect (image, mask) pairs from every dataset listed in the index file;
    # skip datasets that are missing or whose image/annotation counts disagree.
    with open(dataset_file, 'r') as lines:
        for line in lines:
            dataset_name = line.strip()
            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)

            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))

            if len(img_list) > 0:
                if len(img_list) == len(mask_list):
                    dataset_list.append(dataset_name)
                    self.img_list += img_list
                    self.mask_list += mask_list
                    print(f'\t{dataset_name}: {len(img_list)} imgs.')
                else:
                    print(f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and '
                          f'{len(mask_list)} annots. Counts do not match; skipped.')
            else:
                print(f'\tPreTrain dataset {dataset_name} does not exist; skipped.')

    print(myutils.gct(), f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.')

    # Data augmentation pipeline, applied per clip frame in __getitem__.
    self.random_horizontal_flip = mytrans.RandomHorizontalFlip(0.3)
    self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
    self.random_affine = mytrans.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)
    self.random_resize_crop = mytrans.RandomResizedCrop(output_size, (0.8, 1))
    self.to_tensor = TF.ToTensor()
    self.to_onehot = mytrans.ToOnehot(max_obj_n, shuffle=True)
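# For reference, a minimal sketch of the __getitem__ that the transforms above
# imply: load one still image and its mask, then synthesize a clip_n-frame clip
# by re-sampling the random augmentations for each frame. The helper name
# load_image_in_PIL, the joint (image, mask) call signature of the mytrans
# transforms, and the exact tensor shapes are assumptions, not the repository's
# actual code.
def __getitem__(self, idx):
    img = myutils.load_image_in_PIL(self.img_list[idx], 'RGB')  # assumed helper
    mask = myutils.load_image_in_PIL(self.mask_list[idx], 'P')

    frames = torch.zeros((self.clip_n, 3, self.output_size, self.output_size))
    masks = torch.zeros((self.clip_n, self.max_obj_n, self.output_size, self.output_size))

    for i in range(self.clip_n):
        # Each pass through the augmentations yields a new pseudo-frame.
        img_i, mask_i = self.random_horizontal_flip(img, mask)
        img_i = self.color_jitter(img_i)
        img_i, mask_i = self.random_affine(img_i, mask_i)
        img_i, mask_i = self.random_resize_crop(img_i, mask_i)

        mask_i, obj_n = self.to_onehot(mask_i)  # assumed to return (one-hot tensor, object count)
        frames[i] = self.to_tensor(img_i)
        masks[i] = mask_i

    return frames, masks, obj_n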
def eval_DAVIS(model, model_name, dataloader):
    fps = myutils.FrameSecondMeter()

    for seq_idx, V in enumerate(dataloader):
        frames, masks, obj_n, info = V
        seq_name = info['name'][0]
        obj_n = obj_n.item()

        seg_dir = os.path.join('./output', model_name, seq_name)
        os.makedirs(seg_dir, exist_ok=True)
        if args.viz:
            overlay_dir = os.path.join('./overlay', model_name, seq_name)
            os.makedirs(overlay_dir, exist_ok=True)

        frames, masks = frames[0].to(device), masks[0].to(device)
        frame_n = info['num_frames'][0].item()

        # Frame 0: the ground-truth mask seeds both the output and the feature bank.
        pred_mask = masks[0:1]
        pred = torch.argmax(pred_mask[0], dim=0).cpu().numpy().astype(np.uint8)
        seg_path = os.path.join(seg_dir, '00000.png')
        myutils.save_seg_mask(pred, seg_path, palette)
        if args.viz:
            overlay_path = os.path.join(overlay_dir, '00000.png')
            myutils.save_overlay(frames[0], pred, overlay_path, palette)

        fb = FeatureBank(obj_n, args.budget, device,
                         update_rate=args.update_rate, thres_close=args.merge_thres)
        k4_list, v4_list = model.memorize(frames[0:1], pred_mask)
        fb.init_bank(k4_list, v4_list)

        for t in tqdm(range(1, frame_n), desc=f'{seq_idx} {seq_name}'):
            score, _ = model.segment(frames[t:t + 1], fb)
            pred_mask = F.softmax(score, dim=1)
            pred = torch.argmax(pred_mask[0], dim=0).cpu().numpy().astype(np.uint8)
            seg_path = os.path.join(seg_dir, f'{t:05d}.png')
            myutils.save_seg_mask(pred, seg_path, palette)

            # Memorize every frame except the last, updating the adaptive bank.
            if t < frame_n - 1:
                k4_list, v4_list = model.memorize(frames[t:t + 1], pred_mask)
                fb.update(k4_list, v4_list, t)

            if args.viz:
                overlay_path = os.path.join(overlay_dir, f'{t:05d}.png')
                myutils.save_overlay(frames[t], pred, overlay_path, palette)

        fps.add_frame_n(frame_n)

    fps.end()
    print(myutils.gct(), 'fps:', fps.fps)
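# FrameSecondMeter is not shown in this section; given how it is used above
# (add_frame_n, end, .fps), a minimal sketch could be as simple as the
# following. The actual implementation in myutils may differ.
import time

class FrameSecondMeter:
    def __init__(self):
        self.st = time.time()  # start of the timed span
        self.frame_n = 0
        self.fps = None

    def add_frame_n(self, frame_n):
        self.frame_n += frame_n  # accumulate frames processed per sequence

    def end(self):
        # Frames per second over the whole evaluation run.
        self.fps = self.frame_n / (time.time() - self.st)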
if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), 'Args =', args)

    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        raise ValueError('CUDA is required. --gpu must be >= 0.')

    palette = Image.open('./assets/mask_palette.png').getpalette()

    main()
    print(myutils.gct(), 'Evaluation done.')
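# get_args is defined elsewhere in the repository. A minimal argparse sketch
# covering only the flags this evaluation script reads; every default below is
# an illustrative assumption.
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='Evaluate AFB-URR')
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device id, must be >= 0')
    parser.add_argument('--level', type=int, default=1,
                        help='1: DAVIS17-val, 2: YouTube-VOS, 3: long video')
    parser.add_argument('--resume', type=str, help='checkpoint path to load')
    parser.add_argument('--dataset', type=str, help='dataset root directory')
    parser.add_argument('--prefix', type=str, default='', help='suffix appended to the model name')
    parser.add_argument('--viz', action='store_true', help='also save overlay images')
    parser.add_argument('--budget', type=int, default=300000, help='feature bank size budget')
    parser.add_argument('--update_rate', type=float, default=0.1, help='feature bank update rate')
    parser.add_argument('--merge_thres', type=float, default=0.95, help='threshold for merging close keys')
    return parser.parse_args()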
def main():
    # torch.autograd.set_detect_anomaly(True)

    if args.level == 0:
        dataset = PreTrain_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
        desc = 'Pre Train'
    elif args.level == 1:
        dataset = DAVIS_Train_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
        desc = 'Train DAVIS17'
    elif args.level == 2:
        dataset = YouTube_Train_DS(args.dataset, output_size=400, clip_n=args.clip_n, max_obj_n=args.obj_n)
        desc = 'Train YV18'
    else:
        raise ValueError(f'Level {args.level} is unknown.')

    dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
    print(myutils.gct(), f'Load level {args.level} dataset: {len(dataset)} training cases.')

    model = AFB_URR(device, update_bank=False, load_imagenet_params=True)
    model = model.to(device)
    model.train()
    model.apply(myutils.set_bn_eval)  # freeze BatchNorm layers during training

    params = model.parameters()
    optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, params), args.lr)

    start_epoch = 0
    best_loss = float('inf')
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model'], strict=False)
            seed = checkpoint['seed']
            if not args.new:
                # Continue the interrupted run: restore epoch, optimizer state, and best loss.
                start_epoch = checkpoint['epoch'] + 1
                optimizer.load_state_dict(checkpoint['optimizer'])
                best_loss = checkpoint['loss']
                print(myutils.gct(),
                      f'Loaded checkpoint {args.resume} (epoch: {start_epoch - 1}, best loss: {best_loss}).')
            else:
                # Keep only the weights and start a fresh run.
                if args.seed < 0:
                    seed = int(time.time())
                else:
                    seed = args.seed
                print(myutils.gct(), f'Loaded checkpoint {args.resume}. Training from the beginning.')
        else:
            print(myutils.gct(), f'No checkpoint found at {args.resume}')
            raise IOError(f'No checkpoint found at {args.resume}')
    else:
        if args.seed < 0:
            seed = int(time.time())
        else:
            seed = args.seed

    print(myutils.gct(), 'Random seed:', seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.5,
                                                last_epoch=start_epoch - 1)

    for epoch in range(start_epoch, args.total_epochs):
        lr = scheduler.get_last_lr()[0]
        print('')
        print(myutils.gct(), f'Epoch: {epoch} lr: {lr}')

        loss = train_model(model, dataloader, criterion, optimizer, desc)

        if args.log:
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss,
                'seed': seed,
            }

            # Always overwrite the latest checkpoint; additionally snapshot new bests.
            checkpoint_path = f'{model_path}/final.pth'
            torch.save(checkpoint, checkpoint_path)

            if best_loss > loss:
                best_loss = loss
                checkpoint_path = f'{model_path}/epoch_{epoch:03d}_loss_{loss:.03f}.pth'
                torch.save(checkpoint, checkpoint_path)
                checkpoint_path = f'{model_path}/best.pth'
                torch.save(checkpoint, checkpoint_path)
                print('Best model updated.')

        scheduler.step()
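# train_model is called above but not defined in this section. A hypothetical
# sketch of one epoch, mirroring the evaluation loop: memorize the ground-truth
# first frame, segment the remaining clip frames against the feature bank, and
# average the cross-entropy over the clip. The repository's real version may
# differ, e.g. by adding an uncertainty term from model.segment's second output.
def train_model(model, dataloader, criterion, optimizer, desc):
    epoch_loss = 0
    for frames, masks, obj_n in tqdm(dataloader, desc=desc):
        frames, masks = frames[0].to(device), masks[0].to(device)
        obj_n = obj_n.item()

        # Seed a fresh, fixed feature bank with the ground-truth first frame.
        fb = FeatureBank(obj_n, args.budget, device,
                         update_rate=args.update_rate, thres_close=args.merge_thres)
        k4_list, v4_list = model.memorize(frames[0:1], masks[0:1])
        fb.init_bank(k4_list, v4_list)

        loss = 0
        for t in range(1, frames.shape[0]):
            score, _ = model.segment(frames[t:t + 1], fb)
            label = torch.argmax(masks[t], dim=0, keepdim=True)  # one-hot -> class indices
            loss = loss + criterion(score, label)
        loss = loss / (frames.shape[0] - 1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(dataloader)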
if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), f'Args = {args}')

    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        raise ValueError('CUDA is required. --gpu must be >= 0.')

    if args.log:
        if not os.path.exists('logs'):
            os.makedirs('logs')
        prefix = f'level{args.level}'
        log_dir = 'logs/{}'.format(time.strftime(prefix + '_%Y%m%d-%H%M%S'))
        log_path = os.path.join(log_dir, 'log')
        model_path = os.path.join(log_dir, 'model')
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(model_path):
            os.makedirs(model_path)

    main()
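# set_bn_eval, applied in main() above, is the standard idiom for freezing
# BatchNorm when fine-tuning with batch size 1. A sketch of what the myutils
# helper presumably does; the exact implementation may differ.
import torch.nn as nn

def set_bn_eval(module):
    # Switch every BatchNorm layer to eval mode so its running statistics
    # are not corrupted by the tiny training batches.
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.eval()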