def main(_run, _config, _log):
    """Train a few-shot segmentation model (sacred entry point).

    Flow: snapshot sources into the observer dir -> build model/data/optimizer
    -> training loop with periodic loss logging and periodic evaluation
    (best checkpoint kept by meanIoU) -> final 5-run testing on ``best.pth``.

    Args:
        _run: sacred Run object (observers, experiment_info, log_scalar).
        _config: sacred config dict (dataset, paths, task, optim, ...).
        _log: sacred logger.
    """
    logdir = f'{_run.observers[0].dir}/'
    print(logdir)
    # NOTE(review): `category` is defined but never read in this function.
    category = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    # Copy the experiment's source files into the observer directory, then
    # remove sacred's default `_sources` folder.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    data_name = _config['dataset']
    # 20 classes for VOC, otherwise (COCO) 80.
    max_label = 20 if data_name == 'VOC' else 80

    # Reproducibility / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    print(_config['ckpt_dir'])
    tbwriter = SummaryWriter(osp.join(_config['ckpt_dir']))
    # Tensorboard tag maps. NOTE(review): `training_tags` is only used by the
    # commented-out add_scalar loop below; only `infer_tags` is active.
    training_tags = {
        'loss': "ATraining/total_loss",
        "query_loss": "ATraining/query_loss",
        'aligned_loss': "ATraining/aligned_loss",
        'base_loss': "ATraining/base_loss",
    }
    infer_tags = {
        'mean_iou': "MeanIoU/mean_iou",
        "mean_iou_binary": "MeanIoU/mean_iou_binary",
    }

    _log.info('###### Create model ######')
    # Part-based variant vs. plain ResNet variant, selected by config.
    if _config['model']['part']:
        model = FewshotSegPartResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegPartResnet')
    else:
        model = FewshotSegResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegResnet')
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = voc_fewshot
    # Training uses the classes of the configured label split.
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    if _config['fix']:
        # 'fix' mode: only optimize encoder layer3/layer4 (earlier layers are
        # frozen via .eval() inside the training loop below).
        print('Optimizer: fix')
        optimizer = torch.optim.SGD(
            params=[
                {
                    "params": model.module.encoder.layer3.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
                {
                    "params": model.module.encoder.layer4.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
            ],
            momentum=_config['optim']['momentum'],
        )
    else:
        print('Optimizer: Not fix')
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    # scheduler.step() is called once per iteration, so lr_milestones are
    # expressed in iterations, not epochs.
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    log_loss = {'loss': 0, 'align_loss': 0, 'base_loss': 0}
    _log.info('###### Training ######')
    highest_iou = 0
    metrics = {}
    for i_iter, sample_batched in enumerate(trainloader):
        if _config['fix']:
            # Keep BN statistics of the frozen early layers fixed.
            model.module.encoder.conv1.eval()
            model.module.encoder.bn1.eval()
            model.module.encoder.layer1.eval()
            model.module.encoder.layer2.eval()
        if _config['eval']:
            # Eval-only mode: skip training entirely.
            if i_iter == 0:
                break
        # Prepare input: nested lists [way][shot] of CUDA tensors.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)  # 1*417*417
        # base_loss is a constant zero here; kept so the total-loss formula
        # below stays uniform with variants that produce a real base loss.
        base_loss = torch.zeros(1).to(torch.device('cuda'))

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, _, align_loss = model(support_images, support_fg_mask,
                                          support_bg_mask, query_images)
        query_loss = criterion(query_pred, query_labels)  # 1*3*417*417, 1*417*417
        loss = query_loss + align_loss * _config[
            'align_loss_scaler'] + base_loss * _config['base_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (running sums; printed below as averages since iter 0).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy()
        base_loss = base_loss.detach().data.cpu().numpy()
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['base_loss'] += base_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Averages over ALL iterations so far, not just this interval.
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            base_loss = log_loss['base_loss'] / (i_iter + 1)
            print(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            _log.info(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            metrics['loss'] = loss
            metrics['query_loss'] = query_loss
            metrics['aligned_loss'] = align_loss
            metrics['base_loss'] = base_loss
            # for k, v in metrics.items():
            #     tbwriter.add_scalar(training_tags[k], v, i_iter)

        if (i_iter + 1) % _config['evaluate_interval'] == 0:
            _log.info('###### Evaluation begins ######')
            print(_config['ckpt_dir'])
            model.eval()
            # Evaluate on the classes NOT used for training (held-out split).
            labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
                _config['label_sets']]
            transforms = [Resize(size=_config['input_size'])]
            transforms = Compose(transforms)
            metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
            with torch.no_grad():
                # Single evaluation run during training (final test uses 5).
                for run in range(1):
                    _log.info(f'### Run {run + 1} ###')
                    set_seed(_config['seed'] + run)
                    _log.info(f'### Load data ###')
                    dataset = make_data(
                        base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['infer_max_iters'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
                    testloader = DataLoader(dataset,
                                            batch_size=_config['batch_size'],
                                            shuffle=False,
                                            num_workers=_config['num_workers'],
                                            pin_memory=True,
                                            drop_last=False)
                    _log.info(f"Total # of Data: {len(dataset)}")
                    for sample_batched in tqdm.tqdm(testloader):
                        label_ids = list(sample_batched['class_ids'])
                        support_images = [[
                            shot.cuda() for shot in way
                        ] for way in sample_batched['support_images']]
                        suffix = 'mask'
                        support_fg_mask = [[
                            shot[f'fg_{suffix}'].float().cuda()
                            for shot in way
                        ] for way in sample_batched['support_mask']]
                        support_bg_mask = [[
                            shot[f'bg_{suffix}'].float().cuda()
                            for shot in way
                        ] for way in sample_batched['support_mask']]
                        query_images = [
                            query_image.cuda()
                            for query_image in sample_batched['query_images']
                        ]
                        query_labels = torch.cat([
                            query_label.cuda()
                            for query_label in sample_batched['query_labels']
                        ], dim=0)
                        query_pred, _, _ = model(support_images,
                                                 support_fg_mask,
                                                 support_bg_mask,
                                                 query_images)
                        curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                                 query_labels[0],
                                                 labels=label_ids,
                                                 n_run=run)
                    classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                        n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)
                    _run.log_scalar('classIoU', classIoU.tolist())
                    _run.log_scalar('meanIoU', meanIoU.tolist())
                    _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
                    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
                    _log.info(f'classIoU: {classIoU}')
                    _log.info(f'meanIoU: {meanIoU}')
                    _log.info(f'classIoU_binary: {classIoU_binary}')
                    _log.info(f'meanIoU_binary: {meanIoU_binary}')
                    print(
                        f'meanIoU: {meanIoU}, meanIoU_binary: {meanIoU_binary}'
                    )
                    metrics = {}
                    metrics['mean_iou'] = meanIoU
                    metrics['mean_iou_binary'] = meanIoU_binary
                    for k, v in metrics.items():
                        tbwriter.add_scalar(infer_tags[k], v, i_iter)
            # Track the best multi-class meanIoU and keep that checkpoint.
            if meanIoU > highest_iou:
                print(
                    f'The highest iou is in iter: {i_iter} : {meanIoU}, save: {_config["ckpt_dir"]}/best.pth'
                )
                highest_iou = meanIoU
                torch.save(
                    model.state_dict(),
                    os.path.join(f'{_config["ckpt_dir"]}/best.pth'))
            else:
                print(
                    f'The highest iou is in iter: {i_iter} : {meanIoU}'
                )
            # Also snapshot the current weights at every evaluation interval.
            torch.save(model.state_dict(),
                       os.path.join(f'{_config["ckpt_dir"]}/{i_iter + 1}.pth'))
            model.train()

    print(_config['ckpt_dir'])
    _log.info(' --------- Testing begins ---------')
    # Final testing: reload the best checkpoint and average over 5 seeded runs.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    transforms = Compose(transforms)
    ckpt = os.path.join(f'{_config["ckpt_dir"]}/best.pth')
    print(f'{_config["ckpt_dir"]}/best.pth')
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()
    metric = Metric(max_label=max_label, n_runs=5)
    with torch.no_grad():
        for run in range(5):
            n_iter = 0
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)
            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['infer_max_iters'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'],
                n_unlabel=_config['task']['n_unlabels'],
                cfg=_config)
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=_config['num_workers'],
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")
            for sample_batched in tqdm.tqdm(testloader):
                label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'mask'
                support_fg_mask = [[
                    shot[f'fg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]
                support_bg_mask = [[
                    shot[f'bg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]
                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)
                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images)
                curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                         query_labels[0],
                                         labels=label_ids,
                                         n_run=run)
                n_iter += 1
            # Per-run metrics.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std over the 5 runs (no n_run argument -> all runs).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )
    _run.log_scalar('meanIoU', meanIoU.tolist())
    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info('----- Final Result -----')
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
    _log.info("## ------------------------------------------ ##")
    _log.info(f'###### Setting: {_run.observers[0].dir} ######')
    _log.info(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    _log.info(f"Current setting is {_run.observers[0].dir}")
    print(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    print(f"Current setting is {_run.observers[0].dir}")
    print(_config['ckpt_dir'])
    print(logdir)
def main(_run, _config, _log):
    """Test a trained few-shot segmentation encoder (sacred entry point).

    Loads a snapshot (unless ``notrain``), builds the VOC/COCO few-shot test
    set, runs ``n_runs`` seeded evaluation passes, and logs per-run and
    aggregated mean/std IoU metrics.

    Args:
        _run: sacred Run object (observers, experiment_info, log_scalar).
        _config: sacred config dict.
        _log: sacred logger.
    """
    # Archive source files into the observer dir; drop sacred's `_sources`.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    # NOTE(review): despite the name, `notrain` False means "load snapshot".
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Test on the held-out classes (all minus the training label split).
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # Each run is seeded differently so episodes differ across runs.
            set_seed(_config['seed'] + run)
            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                label_sets=_config['label_sets'],
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                # Map COCO category ids to contiguous label ids (see below).
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")
            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                # Flatten [way][shot] lists into batched tensors.
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_images = torch.cat(
                    [torch.cat(way, dim=0) for way in support_images], dim=0)
                suffix = 'scribble' if _config['scribble'] else 'mask'
                if _config['bbox']:
                    # Bounding-box supervision: derive fg/bg masks from
                    # instance boxes instead of dense masks.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    # NOTE(review): `suffix` is unused here — the key is the
                    # literal 'fg_mask', and no bg mask is built in this branch.
                    support_fg_mask = [[
                        shot[f'fg_mask'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_fg_mask = torch.cat(
                        [torch.cat(way, dim=0) for way in support_fg_mask],
                        dim=0)
                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_images = torch.cat(query_images, dim=0)
                query_labels = torch.cat([
                    query_label.long().cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)
                query_pred = model(support_images, query_images,
                                   support_fg_mask)
                # Upsample prediction to the query image resolution.
                query_pred = F.interpolate(query_pred,
                                           size=query_images.shape[-2:],
                                           mode='bilinear')
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)
            # Per-run metrics.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )
    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the class-agnostic objectness network (encoder + decoder).

    Builds the objectness encoder/decoder from cfg, trains on few-shot
    episodes with NLL loss, periodically checkpoints and evaluates; the best
    model (by binary meanIoU) is saved as the 'best' checkpoint.

    Args:
        cfg: config node with MODEL/DATASET/TASK/TRAIN/VAL sections.
        gpus: list of GPU ids; gpus[0] is made current.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    # Binary decoder: 2 classes (object vs. background).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    # NLL loss over log-probabilities; 255 is the ignore/void label.
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on the complementary classes,
    # which are also excluded from training episodes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # ImageNet mean/std scaled to [0, 255] pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    train_transform = [
        transforms.ToNumpy(),
        transforms.RandScale([0.9, 1.1]),
        transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop(
            [cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
            crop_type='rand',
            padding=mean,
            ignore_label=0)
    ]
    train_transform = Compose(train_transform)
    val_transform = Compose([
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ])

    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=train_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers (one per net; crit gets none — see step loop below).
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # train(not fix_bn): BN layers stay in eval mode when fix_bn is set.
    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()
    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label'])
        acc = pixel_acc(pred, feed_dict['seg_label'])
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:  # skip entries with no optimizer (e.g. the loss)
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))
            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            # Periodic evaluation on the held-out classes.
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                # Decoder emits probabilities (not log-probs) during eval.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)
                    print(f'### Load validation data ###')
                    dataset_val = make_data(
                        base_dir=cfg.DATASET.data_dir,
                        split='val',
                        transforms=val_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels_val,
                        max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.VAL.permute_labels,
                        exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # NOTE(review): prints len(dataset) (the TRAIN set), not
                    # len(dataset_val) — display-only, but likely unintended.
                    print(f"Total # of validation Data: {len(dataset)}")
                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        feat = net_objectness(feed_dict['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)
                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)
                # Aggregate mean/std over all validation runs.
                classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                    labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                )
                print('----- Evaluation Result -----')
                print(f'best meanIoU_binary: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')
            # Model selection by BINARY meanIoU (objectness is binary).
            if meanIoU_binary > best_iou:
                best_iou = meanIoU_binary
                checkpoint(nets, history, cfg, 'best')
            # Restore training mode and log-prob output for the next iters.
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

    print('Training Done!')
def main(cfg, gpus):
    """Test the objectness network and optionally visualize predictions.

    Builds encoder/decoder in inference mode (``use_softmax=True``), runs
    ``cfg.VAL.n_runs`` seeded evaluation passes over VOC/COCO few-shot
    episodes, records IoU metrics, and prints aggregated mean/std results.

    Args:
        cfg: config node with MODEL/DATASET/TASK/TRAIN/VAL sections.
        gpus: list of GPU ids; gpus[0] is made current.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)
    # NOTE(review): `crit` is constructed but never used in this function.
    crit = nn.NLLLoss(ignore_index=255)

    net_objectness.cuda()
    net_objectness.eval()
    net_decoder.cuda()
    net_decoder.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        # `test_with_classes` selects the class-aware loader variant.
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import voc_fewshot
        else:
            from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import coco_fewshot
        else:
            from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # COCO API is needed later to resolve image file names for
        # visualization.
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')
    # Evaluate on the classes held out of fold `fold_idx`.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    #labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    #transforms = [Resize_test(size=cfg.DATASET.input_size)]
    val_transforms = [
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ]
    # ImageNet mean/std scaled to [0, 255]. NOTE(review): only referenced by
    # the commented-out transform block below.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    '''val_transforms = [
        transforms.ToNumpy(),
        #transforms.RandScale([0.9, 1.1]),
        #transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        #transforms.RandomGaussianBlur(),
        #transforms.RandomHorizontalFlip(),
        transforms.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=0)]'''
    val_transforms = Compose(val_transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            # Different seed per run so the sampled episodes differ.
            set_seed(cfg.VAL.seed + run)
            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=val_transforms,
                to_tensor=transforms.ToTensorNormalize_noresize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")
            count = 0
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map COCO category ids to contiguous 1-based label ids.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                feat = net_objectness(feed_dict['img_data'],
                                      return_feature_maps=True)
                # NOTE(review): output size is hard-coded to (473, 473) here,
                # unlike the training script which uses cfg.DATASET.input_size.
                query_pred = net_decoder(feat, segSize=(473, 473))
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(feed_dict['seg_label'][0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    # Save a side-by-side visualization of the first query.
                    #print(as_numpy(feed_dict['seg_label'][0].cpu()).shape)
                    #print(as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())).shape)
                    #print(feed_dict['img_data'].cpu().shape)
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, 'JPEGImages',
                                         query_name + '.jpg'))
                    else:
                        # COCO: resolve the file name through the COCO API.
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, split,
                                         img_meta['file_name']))
                    #img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label'][0].cpu()),
                         '%05d' % (count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result'))
                count += 1
            # Per-run metrics.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            '''_run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')'''

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Evaluate the attention/memory few-shot segmentation model.

    Builds every sub-network from ``cfg``, assembles them into a
    ``SegmentationAttentionSeparateModule`` wrapped in ``nn.DataParallel``,
    then runs ``cfg.VAL.n_runs`` evaluation passes over VOC/COCO few-shot
    episodes and prints per-run and aggregated mean-IoU (multi-class and
    binary).  Supports optional multi-scale testing, attention-voting
    evaluation, debug tensor dumps, and qualitative visualization.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/VAL sections).
        gpus: list of CUDA device ids; gpus[0] becomes the current device.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders: query/memory encoders, the two attention heads,
    # the projection head and the decoder are all constructed from
    # config-specified architectures and optional pretrained weight files.
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways+1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # use_softmax=True: decoder emits probabilities for inference
    # (the training entry point builds it without this flag).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways+1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)

    # Optional frozen objectness branch — only built when BOTH weight
    # files are configured.
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # Freeze the objectness branch entirely.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)  # 255 marks ignored pixels

    # debug=... also forces the module to return its internal tensors
    # (qread/qval/p/...) needed by the attention-voting path below.
    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        debug=cfg.is_debug or cfg.eval_att_voting,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()
    segmentation_module.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # COCO ground truth/annotation handles are needed later for
        # category-id remapping and visualization.
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')
    #labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    # NOTE(review): evaluates on the fold's OWN label set (the commented
    # line above would instead evaluate on the complement) — confirm this
    # matches the intended protocol.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    transforms = [Resize_test(size=cfg.DATASET.input_size)]
    transforms = Compose(transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            # Reseed per run so each run draws a reproducible episode set.
            set_seed(cfg.VAL.seed + run)

            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
                exclude_labels=[]
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            count = 0
            if cfg.multi_scale_test:
                scales = [224, 328, 424]
            else:
                scales = [328]
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map raw COCO category ids to contiguous 1-based ids.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # Multi-scale testing: average probability maps over scales.
                for q, scale in enumerate(scales):
                    if len(scales) > 1:
                        # NOTE(review): this resizes feed_dict['img_data'] in
                        # place, so the 2nd/3rd scale re-interpolates the
                        # already-resized tensor rather than the original —
                        # confirm this is intended.
                        feed_dict['img_data'] = nn.functional.interpolate(
                            feed_dict['img_data'].cuda(),
                            size=(scale, scale),
                            mode='bilinear')
                    if cfg.eval_att_voting or cfg.is_debug:
                        # Debug/att-voting path: model returns its internal
                        # attention tensors alongside the prediction.
                        query_pred, qread, qval, qk_b, mk_b, mv_b, p, feature_enc, feature_memory = segmentation_module(
                            feed_dict,
                            segSize=(feed_dict['seg_label_noresize'].shape[1],
                                     feed_dict['seg_label_noresize'].shape[2]))
                        if cfg.eval_att_voting:
                            # Re-derive the prediction by letting support
                            # masks vote through the attention matrix p
                            # (rows = query positions flattened to H*W).
                            height, width = qread.shape[-2], qread.shape[-1]
                            assert p.shape[0] == height*width
                            img_refs_mask_resize = nn.functional.interpolate(
                                feed_dict['img_refs_mask'][0].cuda(),
                                size=(height, width),
                                mode='nearest')
                            img_refs_mask_resize_flat = img_refs_mask_resize[:,0,:,:].view(
                                img_refs_mask_resize.shape[0], -1)
                            mask_voting_flat = torch.mm(img_refs_mask_resize_flat, p)
                            mask_voting = mask_voting_flat.view(
                                mask_voting_flat.shape[0], height, width)
                            mask_voting = torch.unsqueeze(mask_voting, 0)
                            # Drop the last (background) channel before upsampling.
                            query_pred = nn.functional.interpolate(
                                mask_voting[:,0:-1],
                                size=cfg.DATASET.input_size,
                                mode='bilinear',
                                align_corners=False)
                            if cfg.is_debug:
                                np.save('debug/img_refs_mask-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]),
                                        img_refs_mask_resize.detach().cpu().float().numpy())
                                np.save('debug/query_pred-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]),
                                        query_pred.detach().cpu().float().numpy())
                        if cfg.is_debug:
                            np.save('debug/qread-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]),
                                    qread.detach().cpu().float().numpy())
                            np.save('debug/qval-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]),
                                    qval.detach().cpu().float().numpy())
                            #np.save('debug/qk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qk_b.detach().cpu().float().numpy())
                            #np.save('debug/mk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mk_b.detach().cpu().float().numpy())
                            #np.save('debug/mv_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mv_b.detach().cpu().float().numpy())
                            #np.save('debug/p-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), p.detach().cpu().float().numpy())
                            #np.save('debug/feature_enc-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_enc[-1].detach().cpu().float().numpy())
                            #np.save('debug/feature_memory-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_memory[-1].detach().cpu().float().numpy())
                    else:
                        #query_pred = segmentation_module(feed_dict, segSize=cfg.DATASET.input_size)
                        # Predict at the query's original label resolution.
                        query_pred = segmentation_module(
                            feed_dict,
                            segSize=(feed_dict['seg_label_noresize'].shape[1],
                                     feed_dict['seg_label_noresize'].shape[2]))
                    # Running average of predictions across scales.
                    if q == 0:
                        query_pred_final = query_pred/len(scales)
                    else:
                        query_pred_final += query_pred/len(scales)
                query_pred = query_pred_final
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(feed_dict['seg_label_noresize'][0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    #print(as_numpy(feed_dict['seg_label'][0].cpu()).shape)
                    #print(as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())).shape)
                    #print(feed_dict['img_data'].cpu().shape)
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(os.path.join(cfg.DATASET.data_dir,
                                                  'JPEGImages',
                                                  query_name+'.jpg'))
                    else:
                        # COCO query ids are numeric; resolve the file name
                        # through the COCO API.
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(os.path.join(cfg.DATASET.data_dir, split,
                                                  img_meta['file_name']))
                    #img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label_noresize'][0].cpu()),
                         '%05d'%(count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result')
                    )
                count += 1
            # Per-run summaries (values retained inside `metric` for the
            # final aggregation below).
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            '''_run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')'''

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the attention/memory few-shot segmentation model.

    Builds the full network stack from ``cfg``, trains it on few-shot
    episodes from the fold's training labels, and periodically checkpoints
    and evaluates on the held-out validation labels, keeping the best
    checkpoint by validation mean-IoU.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/TRAIN/VAL).
        gpus: list of CUDA device ids; gpus[0] becomes the current device.

    Fix vs. original: the validation-set size report printed
    ``len(dataset)`` (the training set); it now prints ``len(dataset_val)``.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # No use_softmax here: the decoder emits logits during training; the
    # eval phase below toggles net_decoder.use_softmax explicitly.
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    # Optional frozen objectness branch — only when BOTH weight files are set.
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # Objectness branch stays frozen throughout training.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)  # 255 marks ignored pixels

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's labels; validate on the complement.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    if cfg.DATASET.exclude_labels:
        # Keep validation classes entirely out of the training episodes.
        exclude_labels = labels_val
    else:
        exclude_labels = []
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    segmentation_module.cuda()

    # Set up optimizers (one per sub-net; entries may be None).
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if net_objectness and net_objectness_decoder:
        # Frozen branch must stay in eval mode (BN/dropout fixed).
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            # Periodic validation on the held-out label set.
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                net_decoder.use_softmax = True  # probabilities for inference
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # BUGFIX: report the validation set size, not the
                    # training set size.
                    print(f"Total # of validation Data: {len(dataset_val)}")
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            # Map raw COCO category ids to contiguous ids.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        query_pred = segmentation_module(
                            feed_dict, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)
                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                # Aggregate over all validation runs.
                classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                    labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                )
                print('----- Evaluation Result -----')
                print(f'best meanIoU mean: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

                checkpoint(nets, history, cfg, 'latest')
                if meanIoU > best_iou:
                    best_iou = meanIoU
                    checkpoint(nets, history, cfg, 'best')
                # Restore training mode (frozen branch stays in eval).
                segmentation_module.train(not cfg.TRAIN.fix_bn)
                if net_objectness and net_objectness_decoder:
                    net_objectness.eval()
                    net_objectness_decoder.eval()
                net_decoder.use_softmax = False

    print('Training Done!')
def main(_run, _config, _log):
    """Evaluate a FewShotSeg snapshot and export support features.

    Sacred entry point: runs ``_config['n_runs']`` evaluation passes over
    VOC/COCO few-shot episodes, records mean-IoU, and for every episode
    collects the model's support feature vectors into per-run CSV files,
    which are finally merged and visualized with UMAP and t-SNE.

    Args:
        _run: sacred Run object (observers provide the output directory).
        _config: sacred config dict.
        _log: sacred logger.
    """
    # Archive the experiment's source files next to the run outputs.
    os.makedirs(f'{_run.observers[0].dir}/features', exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'],
                       task=_config['task'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    # NOTE(review): the snapshot is loaded when 'notrain' is False — i.e.
    # the flag gates loading of trained weights; confirm the intended
    # semantics of 'notrain' against the config definition.
    if not _config['notrain']:
        model.load_state_dict(torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Test on the classes held out from training (set difference).
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][_config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # Reseed per run for reproducible episode sampling.
            set_seed(_config['seed'] + run)

            features_dfs = []  # one DataFrame of support features per class per batch

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries']
            )
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map raw COCO category ids to contiguous 1-based ids.
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])
                # Flat shot list -> [way][shot] nesting of support ids.
                support_ids = [[sample_batched['support_ids'][way*_config['task']['n_shots'] + shot][0]
                                for shot in range(_config['task']['n_shots'])]
                               for way in range(_config['task']['n_ways'])]
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]
                support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]

                query_images = [query_image.cuda()
                                for query_image in sample_batched['query_images']]
                query_labels = torch.cat(
                    [query_label.cuda() for query_label in sample_batched['query_labels']],
                    dim=0)

                # supp_fts: per-way support feature tensors returned by the model.
                query_pred, _, supp_fts = model(support_images, support_fg_mask,
                                                support_bg_mask, query_images)

                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

                # Save features row
                for i, label_id in enumerate(label_ids):
                    lbl_df = pd.DataFrame(torch.cat(supp_fts[i]).cpu().numpy())
                    lbl_df['label'] = label_id.item()
                    lbl_df['id'] = pd.Series(support_ids[i])
                    features_dfs.append(lbl_df)

            # Per-run IoU summaries, logged to sacred.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            # Keep one feature row per support image id.
            features_df = features_df.drop_duplicates(subset=['id'])
            # Reorder columns so 'id' and 'label' come first.
            cols = list(features_df)
            cols = [cols[-1], cols[-2]] + cols[:-2]
            features_df = features_df[cols]
            features_df.to_csv(
                f'{_run.observers[0].dir}/features/features_run_{run+1}.csv',
                index=False)

    # Aggregate IoU mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    # Merge all per-run CSVs and de-duplicate support ids across runs.
    all_fts = pd.concat(
        [pd.read_csv(f'{_run.observers[0].dir}/features/features_run_{run+1}.csv')
         for run in range(_config['n_runs'])])
    all_fts = all_fts.drop_duplicates(subset=['id'])
    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{_run.observers[0].dir}/features/Umap_fts.png')
    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{_run.observers[0].dir}/features/TSNE_fts.png')

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(_run, _config, _log):
    """Validate a reloaded FewShotSeg model on medical-image scans.

    Sacred entry point: reloads a trained model, iterates over the unseen
    test labels of the configured superpixel dataset (SABS/CHAOST2/C0),
    assembles slice-wise predictions into whole-scan volumes, writes them
    out as NIfTI files, and logs Dice/precision/recall per class.

    Args:
        _run: sacred Run object (observers provide the output directory).
        _config: sacred config dict.
        _log: sacred logger.

    Returns:
        1 on completion.

    Fix vs. original: the ``_config['dataset'] == 'C0'`` branch assigned to
    an undefined name ``lb_buffer`` (NameError); it now writes into
    ``_lb_buffer`` like the other branch.
    """
    if _run.observers:
        # Archive run outputs and the experiment's source files.
        os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    model = FewShotSeg(pretrained_path=_config['reload_model_path'],
                       cfg=_config['model'])
    model = model.cuda()
    model.eval()

    _log.info('###### Load data ######')
    ### Training set
    data_name = _config['dataset']
    if data_name == 'SABS_Superpix':
        baseset_name = 'SABS'
        max_label = 13
    elif data_name == 'C0_Superpix':
        raise NotImplementedError
        baseset_name = 'C0'  # unreachable until the raise above is removed
        max_label = 3
    elif data_name == 'CHAOST2_Superpix':
        baseset_name = 'CHAOST2'
        max_label = 4
    else:
        raise ValueError(f'Dataset: {data_name} not found')

    # Unseen labels = all labels minus the training label group.
    test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP'][
        'pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][
            _config["label_sets"]]

    ### Transforms for data augmentation
    te_transforms = None

    # NOTE: assert is stripped under `python -O`; kept for parity with the
    # original behavior.
    assert _config[
        'scan_per_load'] < 0  # by default we load the entire dataset directly

    _log.info(
        f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######'
    )
    _log.info(
        f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######'
    )

    if baseset_name == 'SABS':  # for CT we need to know statistics of
        tr_parent = SuperpixelDataset(  # base dataset
            which_dataset=baseset_name,
            base_dir=_config['path'][data_name]['data_dir'],
            idx_split=_config['eval_fold'],
            mode='train',
            min_fg=str(
                _config["min_fg_data"]),  # dummy entry for superpixel dataset
            transforms=None,
            nsup=_config['task']['n_shots'],
            scan_per_load=_config['scan_per_load'],
            exclude_list=_config["exclude_cls_list"],
            superpix_scale=_config["superpix_scale"],
            fix_length=_config["max_iters_per_load"] if
            (data_name == 'C0_Superpix') or
            (data_name == 'CHAOST2_Superpix') else None)
        norm_func = tr_parent.norm_func
    else:
        norm_func = get_normalize_op(modality='MR', fids=None)

    te_dataset, te_parent = med_fewshot_val(
        dataset_name=baseset_name,
        base_dir=_config['path'][baseset_name]['data_dir'],
        idx_split=_config['eval_fold'],
        scan_per_load=_config['scan_per_load'],
        act_labels=test_labels,
        npart=_config['task']['npart'],
        nsup=_config['task']['n_shots'],
        extern_normalize_func=norm_func)

    ### dataloaders
    testloader = DataLoader(te_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=False,
                            drop_last=False)

    _log.info('###### Set validation nodes ######')
    mar_val_metric_node = Metric(
        max_label=max_label,
        n_scans=len(te_dataset.dataset.pid_curr_load) -
        _config['task']['n_shots'])

    _log.info('###### Starting validation ######')
    model.eval()
    mar_val_metric_node.reset()

    with torch.no_grad():
        save_pred_buffer = {}  # indexed by class

        for curr_lb in test_labels:
            te_dataset.set_curr_cls(curr_lb)
            support_batched = te_parent.get_support(
                curr_class=curr_lb,
                class_idx=[curr_lb],
                scan_idx=_config["support_idx"],
                npart=_config['task']['npart'])

            # way(1 for now) x part x shot x 3 x H x W] #
            support_images = [[shot.cuda() for shot in way]
                              for way in support_batched['support_images']
                              ]  # way x part x [shot x C x H x W]
            suffix = 'mask'
            support_fg_mask = [[
                shot[f'fg_{suffix}'].float().cuda() for shot in way
            ] for way in support_batched['support_mask']]
            support_bg_mask = [[
                shot[f'bg_{suffix}'].float().cuda() for shot in way
            ] for way in support_batched['support_mask']]

            curr_scan_count = -1  # counting for current scan
            _lb_buffer = {}  # indexed by scan

            last_qpart = 0  # used as indicator for adding result to buffer

            for sample_batched in testloader:
                _scan_id = sample_batched["scan_id"][
                    0]  # we assume batch size for query is 1
                # skip the support scan, don't include that to query
                if _scan_id in te_parent.potential_support_sid:
                    continue
                if sample_batched["is_start"]:
                    # First slice of a new scan: allocate the output volume.
                    ii = 0
                    curr_scan_count += 1
                    _scan_id = sample_batched["scan_id"][0]
                    outsize = te_dataset.dataset.info_by_scan[_scan_id][
                        "array_size"]
                    outsize = (
                        256, 256, outsize[0]
                    )  # original image read by itk: Z, H, W, in prediction we use H, W, Z
                    _pred = np.zeros(outsize)
                    _pred.fill(np.nan)

                q_part = sample_batched[
                    "part_assign"]  # the chunck of query, for assignment with support
                query_images = [sample_batched['image'].cuda()]
                query_labels = torch.cat([sample_batched['label'].cuda()],
                                         dim=0)

                # [way, [part, [shot x C x H x W]]] ->
                sup_img_part = [[
                    shot_tensor.unsqueeze(0)
                    for shot_tensor in support_images[0][q_part]
                ]]  # way(1) x shot x [B(1) x C x H x W]
                sup_fgm_part = [[
                    shot_tensor.unsqueeze(0)
                    for shot_tensor in support_fg_mask[0][q_part]
                ]]
                sup_bgm_part = [[
                    shot_tensor.unsqueeze(0)
                    for shot_tensor in support_bg_mask[0][q_part]
                ]]

                query_pred, _, _, assign_mats = model(
                    sup_img_part,
                    sup_fgm_part,
                    sup_bgm_part,
                    query_images,
                    isval=True,
                    val_wsize=_config["val_wsize"])

                query_pred = np.array(query_pred.argmax(dim=1)[0].cpu())
                _pred[..., ii] = query_pred.copy()

                # Only score slices within z_margin of the labeled z-range.
                if (sample_batched["z_id"] - sample_batched["z_max"] <=
                        _config['z_margin']) and (
                            sample_batched["z_id"] - sample_batched["z_min"] >=
                            -1 * _config['z_margin']):
                    mar_val_metric_node.record(query_pred,
                                               np.array(
                                                   query_labels[0].cpu()),
                                               labels=[curr_lb],
                                               n_scan=curr_scan_count)

                ii += 1
                # now check data format
                if sample_batched["is_end"]:
                    if _config['dataset'] != 'C0':
                        _lb_buffer[_scan_id] = _pred.transpose(
                            2, 0, 1)  # H, W, Z -> to Z H W
                    else:
                        # BUGFIX: was `lb_buffer[...]` (undefined name).
                        _lb_buffer[_scan_id] = _pred

            save_pred_buffer[str(curr_lb)] = _lb_buffer

        ### save results
        for curr_lb, _preds in save_pred_buffer.items():
            for _scan_id, _pred in _preds.items():
                _pred *= float(curr_lb)  # encode the class id in the voxels
                itk_pred = convert_to_sitk(
                    _pred, te_dataset.dataset.info_by_scan[_scan_id])
                fid = os.path.join(f'{_run.observers[0].dir}/interm_preds',
                                   f'scan_{_scan_id}_label_{curr_lb}.nii.gz')
                sitk.WriteImage(itk_pred, fid, True)
                _log.info(f'###### {fid} has been saved ######')

        del save_pred_buffer

    del sample_batched, support_images, support_bg_mask, query_images, query_labels, query_pred

    # compute dice scores by scan
    m_classDice, _, m_meanDice, _, m_rawDice = mar_val_metric_node.get_mDice(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    m_classPrec, _, m_meanPrec, _, m_classRec, _, m_meanRec, _, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    mar_val_metric_node.reset()  # reset this calculation node

    # write validation result to log file
    _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist())
    _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist())
    _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist())
    _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist())
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist())
    _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist())
    _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist())
    # NOTE(review): the two keys below use 'mar_val_al_batches_' unlike the
    # 'mar_val_batches_' keys above — likely a typo, but kept unchanged in
    # case downstream tooling already reads these names.
    _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec.tolist())
    _run.log_scalar('mar_val_al_batches_rawRec', m_rawRec.tolist())

    _log.info(f'mar_val batches classDice: {m_classDice}')
    _log.info(f'mar_val batches meanDice: {m_meanDice}')
    _log.info(f'mar_val batches classPrec: {m_classPrec}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')
    _log.info(f'mar_val batches classRec: {m_classRec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')

    print("============ ============")

    _log.info(f'End of validation')
    return 1
def main(_run, _config, _log):
    """Evaluate a trained FewShotSeg model over several randomized test runs.

    Archives the experiment sources into the sacred observer directory,
    builds the model (optionally loading a snapshot), prepares the VOC/COCO
    few-shot test data, and logs per-run and aggregate (mean/std) IoU
    statistics through ``metric``.

    Args:
        _run: sacred Run object (observers, experiment_info, log_scalar).
        _config: sacred config dict (dataset, task, path, snapshot, ...).
        _log: logger used for progress and result reporting.
    """
    # Archive the experiment's source files under the observer directory,
    # then drop sacred's default ``_sources`` copy.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'], ])
    if not _config['notrain']:
        # NOTE(review): despite the flag name, the snapshot is loaded when
        # ``notrain`` is False — confirm this matches the config semantics.
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
        print("Snapshotttt")
        print(_config['snapshot'])
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Test on the classes held out from training (all minus the train fold).
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # Each run re-seeds so the episode sampling differs per run but
            # is reproducible across invocations.
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                # Map raw COCO category ids to contiguous 1-based label ids.
                if _config['dataset'] == 'COCO':
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'scribble' if _config['scribble'] else 'mask'

                if _config['bbox']:
                    # Replace dense support masks by bbox-derived ones.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                # FIX: removed the dead U-2-Net preamble that built
                # ``inputs_v``/``labels_v`` with the deprecated
                # torch.autograd.Variable every iteration — the network that
                # consumed them was commented out, so it only wasted an extra
                # host->GPU copy per batch.
                pred = []
                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)

                # Only the first query of the batch is scored — presumably
                # batch_size is 1 at test time; verify against the config.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            # Per-run summaries.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std across all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, \
        meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary',
                    meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the objectness encoder/decoder pair with periodic validation.

    Builds the two networks, the VOC/COCO train and val dataloaders, and
    runs SGD-style training (via ``create_optimizers``), printing progress
    every ``cfg.TRAIN.disp_iter`` iterations, checkpointing every
    ``cfg.TRAIN.save_freq``, and evaluating binary mIoU every
    ``cfg.TRAIN.eval_freq`` (best model saved as 'best').

    Args:
        cfg: experiment config node (MODEL / DATASET / TASK / TRAIN / VAL).
        gpus: list of GPU ids; only gpus[0] is used (no DataParallel here).
    """
    # Network builders.
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,  # binary objectness: foreground vs background
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    # NLLLoss => the decoder is expected to emit log-probabilities in
    # training mode; 255 marks ignored pixels.
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        max_label = 20
    elif data_name == 'COCO':
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # Normalization constants scaled to 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = [
        transform.RandScale([0.9, 1.1]),
        transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop(
            [cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
            crop_type='rand', padding=mean, ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=cfg.TASK.fold_idx,
                                 shot=cfg.TASK.n_shots,
                                 data_root=cfg.DATASET.data_dir,
                                 data_list=cfg.DATASET.train_list,
                                 transform=train_transform,
                                 mode='train',
                                 use_coco=False,
                                 use_split_coco=False)
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.TRAIN.n_batch,
        shuffle=(train_sampler is None),
        num_workers=cfg.TRAIN.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    val_transform = transform.Compose([
        transform.Resize(size=cfg.DATASET.input_size[0]),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    val_data = dataset.SemData(split=cfg.TASK.fold_idx,
                               shot=cfg.TASK.n_shots,
                               data_root=cfg.DATASET.data_dir,
                               data_list=cfg.DATASET.val_list,
                               transform=val_transform,
                               mode='val',
                               use_coco=False,
                               use_split_coco=False)
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=cfg.VAL.n_batch,
                                             shuffle=False,
                                             num_workers=cfg.TRAIN.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers (one per net; ``crit`` gets no optimizer).
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # Main loop.
    tic = time.time()
    i_iter = -1
    print('###### Training ######')
    # NOTE(review): the epoch count is hard-coded; total work is really
    # bounded by cfg.TRAIN.n_iters-style settings via save/eval freq.
    for epoch in range(0, 200):
        for _, (input, target) in enumerate(train_loader):
            # Prepare input.
            i_iter += 1
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            data_time.update(time.time() - tic)

            net_objectness.zero_grad()
            net_decoder.zero_grad()

            # Adjust learning rate.
            adjust_learning_rate(optimizers, i_iter, cfg)

            # Forward pass.
            feat = net_objectness(input, return_feature_maps=True)
            pred = net_decoder(feat, segSize=cfg.DATASET.input_size)
            loss = crit(pred, target)
            acc = pixel_acc(pred, target)
            loss = loss.mean()
            acc = acc.mean()

            # Backward.
            loss.backward()
            for optimizer in optimizers:
                if optimizer:
                    optimizer.step()

            # Measure elapsed time.
            batch_time.update(time.time() - tic)
            tic = time.time()

            # Update average loss and acc.
            ave_total_loss.update(loss.data.item())
            ave_acc.update(acc.data.item() * 100)

            # Display progress.
            if i_iter % cfg.TRAIN.disp_iter == 0:
                print(
                    'Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                    'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                    'Ave_Accuracy: {:4.2f}, Accuracy:{:4.2f}, Ave_Loss: {:.6f}, Loss: {:.6f}'
                    .format(i_iter, i_iter, cfg.TRAIN.n_iters,
                            batch_time.average(), data_time.average(),
                            cfg.TRAIN.running_lr_encoder,
                            cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                            acc.data.item() * 100, ave_total_loss.average(),
                            loss.data.item()))
                history['train']['iter'].append(i_iter)
                history['train']['loss'].append(loss.data.item())
                history['train']['acc'].append(acc.data.item())

            if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
                checkpoint(nets, history, cfg, i_iter + 1)

            if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
                metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
                with torch.no_grad():
                    print('----Evaluation----')
                    net_objectness.eval()
                    net_decoder.eval()
                    net_decoder.use_softmax = True
                    # FIX: iterate cfg.VAL.n_runs runs instead of the
                    # hard-coded 3 — Metric above is sized with
                    # n_runs=cfg.VAL.n_runs, so the old literal could index
                    # past the allocated runs (or silently under-evaluate).
                    for run in range(cfg.VAL.n_runs):
                        print(f'### Run {run + 1} ###')
                        set_seed(cfg.VAL.seed + run)
                        print(f'### Load validation data ###')
                        for (input, target, _) in val_loader:
                            input = input.cuda(non_blocking=True)
                            target = target.cuda(non_blocking=True)
                            feat = net_objectness(input,
                                                  return_feature_maps=True)
                            query_pred = net_decoder(
                                feat, segSize=cfg.DATASET.input_size)
                            # Only the first sample of the batch is scored —
                            # presumably cfg.VAL.n_batch is 1; verify.
                            metric.record(
                                np.array(query_pred.argmax(dim=1)[0].cpu()),
                                np.array(target[0].cpu()),
                                labels=None,
                                n_run=run)

                    classIoU_binary, classIoU_std_binary, meanIoU_binary, \
                        meanIoU_std_binary = metric.get_mIoU_binary()
                    print('----- Evaluation Result -----')
                    print(f'best meanIoU_binary: {best_iou}')
                    print(f'meanIoU_binary mean: {meanIoU_binary}')
                    print(f'meanIoU_binary std: {meanIoU_std_binary}')
                    if meanIoU_binary > best_iou:
                        best_iou = meanIoU_binary
                        checkpoint(nets, history, cfg, 'best')

                # Restore training mode after validation.
                net_objectness.train(not cfg.TRAIN.fix_bn)
                net_decoder.train(not cfg.TRAIN.fix_bn)
                net_decoder.use_softmax = False

    print('Training Done!')