def main(cfg, gpus):
    """Train the class-agnostic objectness model (encoder + binary decoder).

    Trains `net_objectness`/`net_decoder` on few-shot episodes drawn from the
    training fold, periodically evaluating on the held-out fold classes and
    checkpointing the best model by binary meanIoU.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/TRAIN/VAL).
        gpus: list of GPU ids; only gpus[0] is used (no DataParallel here).
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    # num_class=2: binary objectness (foreground vs background).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    # 255 is the "ignore" label in the segmentation masks.
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on the remaining classes,
    # which are also excluded from training episodes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # ImageNet mean/std scaled to 0-255 pixel range (used as pad values).
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = [
        transforms.ToNumpy(),
        transforms.RandScale([0.9, 1.1]),
        transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                        crop_type='rand', padding=mean, ignore_label=0)
    ]
    train_transform = Compose(train_transform)
    val_transform = Compose([
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ])

    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=train_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # train(False) freezes BatchNorm statistics when fix_bn is set.
    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label'])
        acc = pixel_acc(pred, feed_dict['seg_label'])
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters,
                      batch_time.average(), data_time.average(),
                      cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,
                      ave_acc.average(), ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        # Periodic evaluation on the held-out fold classes.
        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                # Decoder emits probabilities (softmax) during eval.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(
                        base_dir=cfg.DATASET.data_dir,
                        split='val',
                        transforms=val_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels_val,
                        max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.VAL.permute_labels,
                        exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # NOTE(review): prints len(dataset) (the *training* set),
                    # not len(dataset_val) — message may be misleading.
                    print(f"Total # of validation Data: {len(dataset)}")
                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            # COCO category ids are sparse; map to 1-based
                            # contiguous ids used by the metric.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        feat = net_objectness(feed_dict['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    # Per-run metrics (values reassigned by the aggregate below).
                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                # Aggregate mean/std over all runs.
                classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                    labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                )

                print('----- Evaluation Result -----')
                print(f'best meanIoU_binary: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

            # Keep the checkpoint with the best binary meanIoU.
            if meanIoU_binary > best_iou:
                best_iou = meanIoU_binary
                checkpoint(nets, history, cfg, 'best')
            # Restore training mode after evaluation.
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

    print('Training Done!')
def main(cfg, gpus):
    """Evaluate the attention-based few-shot segmentation module.

    Builds query/memory encoders, attention heads, projection and decoder,
    wraps them in SegmentationAttentionSeparateModule, and runs multi-run
    episodic evaluation with optional multi-scale testing, attention-voting
    evaluation, debug tensor dumps, and result visualization.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/VAL + flags
            `is_debug`, `eval_att_voting`, `multi_scale_test`).
        gpus: GPU ids; the module is wrapped in DataParallel over all of them.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways+1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # use_softmax=True: decoder outputs probabilities (inference mode).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways+1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)

    # Optional frozen objectness prior (only if both weight files are given).
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # The objectness branch is a frozen prior — no gradients.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        # debug mode also returns intermediate tensors (qread, p, ...).
        debug=cfg.is_debug or cfg.eval_att_voting,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()
    segmentation_module.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # COCO API is only needed for visualization (image file lookup).
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')

    #labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    # NOTE(review): local name shadows the (presumed) `transforms` module
    # import for the rest of this function.
    transforms = [Resize_test(size=cfg.DATASET.input_size)]
    transforms = Compose(transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            set_seed(cfg.VAL.seed + run)

            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
                exclude_labels=[]
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            count = 0
            if cfg.multi_scale_test:
                scales = [224, 328, 424]
            else:
                scales = [328]
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map sparse COCO category ids to contiguous 1-based ids.
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                for q, scale in enumerate(scales):
                    if len(scales) > 1:
                        # NOTE(review): resizes the already-resized tensor from
                        # the previous scale iteration (in-place feed_dict
                        # mutation) — confirm this compounding is intended.
                        feed_dict['img_data'] = nn.functional.interpolate(feed_dict['img_data'].cuda(), size=(scale, scale), mode='bilinear')
                    if cfg.eval_att_voting or cfg.is_debug:
                        # Debug forward also returns attention internals.
                        query_pred, qread, qval, qk_b, mk_b, mv_b, p, feature_enc, feature_memory = segmentation_module(feed_dict, segSize=(feed_dict['seg_label_noresize'].shape[1], feed_dict['seg_label_noresize'].shape[2]))
                        if cfg.eval_att_voting:
                            # Predict by letting support masks vote through the
                            # attention matrix p instead of using the decoder.
                            height, width = qread.shape[-2], qread.shape[-1]
                            assert p.shape[0] == height*width
                            img_refs_mask_resize = nn.functional.interpolate(feed_dict['img_refs_mask'][0].cuda(), size=(height, width), mode='nearest')
                            img_refs_mask_resize_flat = img_refs_mask_resize[:,0,:,:].view(img_refs_mask_resize.shape[0], -1)
                            mask_voting_flat = torch.mm(img_refs_mask_resize_flat, p)
                            mask_voting = mask_voting_flat.view(mask_voting_flat.shape[0], height, width)
                            mask_voting = torch.unsqueeze(mask_voting, 0)
                            # Drop the last (background) channel before upsampling.
                            query_pred = nn.functional.interpolate(mask_voting[:,0:-1], size=cfg.DATASET.input_size, mode='bilinear', align_corners=False)
                            if cfg.is_debug:
                                np.save('debug/img_refs_mask-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), img_refs_mask_resize.detach().cpu().float().numpy())
                                np.save('debug/query_pred-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), query_pred.detach().cpu().float().numpy())
                        if cfg.is_debug:
                            np.save('debug/qread-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qread.detach().cpu().float().numpy())
                            np.save('debug/qval-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qval.detach().cpu().float().numpy())
                            #np.save('debug/qk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qk_b.detach().cpu().float().numpy())
                            #np.save('debug/mk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mk_b.detach().cpu().float().numpy())
                            #np.save('debug/mv_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mv_b.detach().cpu().float().numpy())
                            #np.save('debug/p-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), p.detach().cpu().float().numpy())
                            #np.save('debug/feature_enc-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_enc[-1].detach().cpu().float().numpy())
                            #np.save('debug/feature_memory-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_memory[-1].detach().cpu().float().numpy())
                    else:
                        #query_pred = segmentation_module(feed_dict, segSize=cfg.DATASET.input_size)
                        # Predict at the original (unresized) label resolution.
                        query_pred = segmentation_module(feed_dict, segSize=(feed_dict['seg_label_noresize'].shape[1], feed_dict['seg_label_noresize'].shape[2]))
                    # Average predictions uniformly over scales.
                    if q == 0:
                        query_pred_final = query_pred/len(scales)
                    else:
                        query_pred_final += query_pred/len(scales)
                query_pred = query_pred_final

                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(feed_dict['seg_label_noresize'][0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    #print(as_numpy(feed_dict['seg_label'][0].cpu()).shape)
                    #print(as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())).shape)
                    #print(feed_dict['img_data'].cpu().shape)
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(os.path.join(cfg.DATASET.data_dir, 'JPEGImages', query_name+'.jpg'))
                    else:
                        # COCO ids are ints; resolve the file name via the API.
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(os.path.join(cfg.DATASET.data_dir, split, img_meta['file_name']))
                    #img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label_noresize'][0].cpu()), '%05d'%(count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result')
                    )
                count += 1

            # Per-run metrics (logged implicitly; superseded by the aggregate).
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            '''_run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')'''

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the attention-based few-shot segmentation module.

    Builds the full SegmentationAttentionSeparateModule (query/memory
    encoders, attention heads, projection, decoder, optional frozen
    objectness prior), trains it episodically, and periodically evaluates on
    the held-out fold classes, checkpointing 'latest' and the best-by-meanIoU
    model.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/TRAIN/VAL).
        gpus: GPU ids; only gpus[0] is used (DataParallel is commented out).
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # No use_softmax here: training uses log-prob/NLL path; eval toggles
    # net_decoder.use_softmax below.
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    # Optional frozen objectness prior (only if both weight files are given).
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # The objectness branch stays frozen during training.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.
        objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.
        linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on the remaining classes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    if cfg.DATASET.exclude_labels:
        exclude_labels = labels_val
    else:
        exclude_labels = []

    # NOTE(review): local name shadows the (presumed) `transforms` module
    # import for the rest of this function; also reused for validation below,
    # so validation episodes get RandomMirror too — confirm intended.
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    # Keep the frozen objectness prior in eval mode even while training.
    if net_objectness and net_objectness_decoder:
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        #print(batch_data)
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters,
                      batch_time.average(), data_time.average(),
                      cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,
                      ave_acc.average(), ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        # Periodic evaluation on the held-out fold classes.
        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                # Decoder emits probabilities (softmax) during eval.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # NOTE(review): prints len(dataset) (the *training* set),
                    # not len(dataset_val) — message may be misleading.
                    print(f"Total # of validation Data: {len(dataset)}")
                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            # Map sparse COCO ids to contiguous 1-based ids.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        query_pred = segmentation_module(
                            feed_dict, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    # Per-run metrics (reassigned by the aggregate below).
                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                # Aggregate mean/std over all runs.
                classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                    labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                )

                print('----- Evaluation Result -----')
                print(f'best meanIoU mean: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

            # Always refresh 'latest'; keep 'best' by multi-class meanIoU.
            checkpoint(nets, history, cfg, 'latest')
            if meanIoU > best_iou:
                best_iou = meanIoU
                checkpoint(nets, history, cfg, 'best')
            # Restore training mode after evaluation.
            segmentation_module.train(not cfg.TRAIN.fix_bn)
            if net_objectness and net_objectness_decoder:
                net_objectness.eval()
                net_objectness_decoder.eval()
            net_decoder.use_softmax = False

    print('Training Done!')
def main(cfg, gpus):
    """Evaluate the trained objectness model (encoder + binary decoder).

    Runs multi-run episodic evaluation of `net_objectness`/`net_decoder` on
    the held-out fold classes, records class/binary IoU, optionally saves
    visualizations, and prints aggregate mean/std metrics.

    Args:
        cfg: experiment configuration node (MODEL/DATASET/TASK/TRAIN/VAL).
        gpus: list of GPU ids; only gpus[0] is used.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    # num_class=2: binary objectness; use_softmax=True for inference output.
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)
    # (removed: an unused NLLLoss criterion was constructed here; this
    # script is evaluation-only and never computes a loss)

    net_objectness.cuda()
    net_objectness.eval()
    net_decoder.cuda()
    net_decoder.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        # test_with_classes selects the class-aware loader variant.
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import voc_fewshot
        else:
            from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import coco_fewshot
        else:
            from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # COCO API is only needed for visualization (image file lookup).
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')

    # Evaluate on the classes held out from the training fold.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    #labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]

    #transforms = [Resize_test(size=cfg.DATASET.input_size)]
    val_transforms = [
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ]
    # ImageNet mean/std scaled to 0-255; only referenced by the commented-out
    # augmentation pipeline below, kept for easy re-enabling.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    '''val_transforms = [
        transforms.ToNumpy(),
        #transforms.RandScale([0.9, 1.1]),
        #transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        #transforms.RandomGaussianBlur(),
        #transforms.RandomHorizontalFlip(),
        transforms.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                        crop_type='rand', padding=mean, ignore_label=0)]'''
    val_transforms = Compose(val_transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            set_seed(cfg.VAL.seed + run)

            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=val_transforms,
                to_tensor=transforms.ToTensorNormalize_noresize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            count = 0
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map sparse COCO category ids to contiguous 1-based ids.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                feat = net_objectness(feed_dict['img_data'],
                                      return_feature_maps=True)
                # FIX: segSize was hard-coded to (473, 473); use the configured
                # input size so predictions match the Resize_pad'ed labels for
                # any cfg.DATASET.input_size (consistent with the sibling
                # training/eval scripts in this file).
                query_pred = net_decoder(feat, segSize=cfg.DATASET.input_size)
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(feed_dict['seg_label'][0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    #print(as_numpy(feed_dict['seg_label'][0].cpu()).shape)
                    #print(as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())).shape)
                    #print(feed_dict['img_data'].cpu().shape)
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, 'JPEGImages',
                                         query_name + '.jpg'))
                    else:
                        # COCO ids are ints; resolve the file name via the API.
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, split,
                                         img_meta['file_name']))
                    #img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label'][0].cpu()),
                         '%05d' % (count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result'))
                count += 1

            # Per-run metrics (reassigned by the aggregate below).
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            '''_run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')'''

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the objectness encoder/decoder pair with standard (non-episodic)
    semantic-segmentation batches, periodically checkpointing and evaluating
    binary (fg/bg) mIoU on the validation split.

    Args:
        cfg:  experiment config node (MODEL / DATASET / TASK / TRAIN / VAL).
        gpus: list of GPU ids; only ``gpus[0]`` is used (no DataParallel here).
    """
    torch.cuda.set_device(gpus[0])

    # ---- Network builders -------------------------------------------------
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,  # binary objectness: foreground vs background
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    # 255 is the conventional "ignore" label for segmentation targets.
    crit = nn.NLLLoss(ignore_index=255)

    # ---- Data -------------------------------------------------------------
    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        max_label = 20
    elif data_name == 'COCO':
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # NOTE(review): these three are unused in this training entry point; kept
    # for parity with the episodic few-shot variant of main() in this file.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # Normalization stats are expressed in 0-255 pixel scale.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([0.9, 1.1]),
        transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop(
            [cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
            crop_type='rand', padding=mean, ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    train_data = dataset.SemData(
        split=cfg.TASK.fold_idx, shot=cfg.TASK.n_shots,
        data_root=cfg.DATASET.data_dir, data_list=cfg.DATASET.train_list,
        transform=train_transform, mode='train',
        use_coco=False, use_split_coco=False)
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.TRAIN.n_batch,
        shuffle=(train_sampler is None), num_workers=cfg.TRAIN.workers,
        pin_memory=True, sampler=train_sampler, drop_last=True)

    val_transform = transform.Compose([
        transform.Resize(size=cfg.DATASET.input_size[0]),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std),
    ])
    val_data = dataset.SemData(
        split=cfg.TASK.fold_idx, shot=cfg.TASK.n_shots,
        data_root=cfg.DATASET.data_dir, data_list=cfg.DATASET.val_list,
        transform=val_transform, mode='val',
        use_coco=False, use_split_coco=False)
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=cfg.VAL.n_batch, shuffle=False,
        num_workers=cfg.TRAIN.workers, pin_memory=True, sampler=val_sampler)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # ---- Optimizers and meters -------------------------------------------
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # train(False) freezes BatchNorm statistics when fix_bn is set.
    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)
    best_iou = 0

    # ---- Main loop --------------------------------------------------------
    tic = time.time()
    i_iter = -1
    print('###### Training ######')
    # NOTE(review): epoch count is hard-coded; training does not stop at
    # cfg.TRAIN.n_iters (that value is only used for LR schedule/display).
    for _epoch in range(200):
        for img, target in train_loader:  # avoid shadowing builtin `input`
            i_iter += 1
            img = img.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            data_time.update(time.time() - tic)

            net_objectness.zero_grad()
            net_decoder.zero_grad()

            # adjust learning rate
            adjust_learning_rate(optimizers, i_iter, cfg)

            # forward pass
            feat = net_objectness(img, return_feature_maps=True)
            pred = net_decoder(feat, segSize=cfg.DATASET.input_size)
            loss = crit(pred, target)
            acc = pixel_acc(pred, target)
            loss = loss.mean()
            acc = acc.mean()

            # Backward
            loss.backward()
            for optimizer in optimizers:
                if optimizer:
                    optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - tic)
            tic = time.time()

            # update average loss and acc
            ave_total_loss.update(loss.data.item())
            ave_acc.update(acc.data.item() * 100)

            # calculate accuracy, and display
            if i_iter % cfg.TRAIN.disp_iter == 0:
                print(
                    'Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                    'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                    'Ave_Accuracy: {:4.2f}, Accuracy:{:4.2f}, Ave_Loss: {:.6f}, Loss: {:.6f}'
                    .format(i_iter, i_iter, cfg.TRAIN.n_iters,
                            batch_time.average(), data_time.average(),
                            cfg.TRAIN.running_lr_encoder,
                            cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                            acc.data.item() * 100, ave_total_loss.average(),
                            loss.data.item()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

            if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
                checkpoint(nets, history, cfg, i_iter + 1)

            if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
                # NOTE(review): Metric is sized for cfg.VAL.n_runs but only 3
                # runs are recorded here (quick eval) — confirm n_runs >= 3.
                metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
                with torch.no_grad():
                    print('----Evaluation----')
                    net_objectness.eval()
                    net_decoder.eval()
                    net_decoder.use_softmax = True
                    #for run in range(cfg.VAL.n_runs):
                    for run in range(3):
                        print(f'### Run {run + 1} ###')
                        set_seed(cfg.VAL.seed + run)
                        print('### Load validation data ###')
                        for (img, target, _) in val_loader:
                            img = img.cuda(non_blocking=True)
                            target = target.cuda(non_blocking=True)
                            feat = net_objectness(
                                img, return_feature_maps=True)
                            query_pred = net_decoder(
                                feat, segSize=cfg.DATASET.input_size)
                            # Only the first sample of the batch is scored.
                            metric.record(
                                np.array(query_pred.argmax(dim=1)[0].cpu()),
                                np.array(target[0].cpu()),
                                labels=None, n_run=run)
                    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                    )
                    print('----- Evaluation Result -----')
                    print(f'best meanIoU_binary: {best_iou}')
                    print(f'meanIoU_binary mean: {meanIoU_binary}')
                    print(f'meanIoU_binary std: {meanIoU_std_binary}')
                    if meanIoU_binary > best_iou:
                        best_iou = meanIoU_binary
                        checkpoint(nets, history, cfg, 'best')
                # restore training mode after evaluation
                net_objectness.train(not cfg.TRAIN.fix_bn)
                net_decoder.train(not cfg.TRAIN.fix_bn)
                net_decoder.use_softmax = False

    print('Training Done!')