def main(): """Training for softmax classifier only. """ # Retreve experiment configurations. args = parse_args('Training for softmax classifier only.') # Retrieve GPU informations. device_ids = [int(i) for i in config.gpus.split(',')] gpu_ids = [torch.device('cuda', i) for i in device_ids] num_gpus = len(gpu_ids) # Create logger and tensorboard writer. summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir) color_map = vis_utils.load_color_map(config.dataset.color_map_path) model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth') optimizer_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.state.pth') # Create data loaders. train_dataset = ListTagClassifierDataset( data_dir=args.data_dir, data_list=args.data_list, img_mean=config.network.pixel_means, img_std=config.network.pixel_stds, size=config.train.crop_size, random_crop=config.train.random_crop, random_scale=config.train.random_scale, random_mirror=config.train.random_mirror, random_grayscale=True, random_blur=True, training=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * config.num_threads, drop_last=False, collate_fn=train_dataset.collate_fn) # Create models. if config.network.backbone_types == 'panoptic_pspnet_101': embedding_model = resnet_101_pspnet(config).cuda() elif config.network.backbone_types == 'panoptic_deeplab_101': embedding_model = resnet_101_deeplab(config).cuda() else: raise ValueError('Not support ' + config.network.backbone_types) if config.network.prediction_types == 'softmax_classifier': prediction_model = softmax_classifier(config).cuda() else: raise ValueError('Not support ' + config.network.prediction_types) # Use customized optimizer and pass lr=1 to support different lr for # different weights. optimizer = SGD(embedding_model.get_params_lr() + prediction_model.get_params_lr(), lr=1, momentum=config.train.momentum, weight_decay=config.train.weight_decay) optimizer.zero_grad() # Load pre-trained weights. curr_iter = config.train.begin_iteration if config.network.pretrained: print('Loading pre-trained model: {:s}'.format( config.network.pretrained)) embedding_model.load_state_dict(torch.load( config.network.pretrained)['embedding_model'], resume=True) else: raise ValueError('Pre-trained model is required.') # Distribute model weights to multi-gpus. embedding_model = DataParallel(embedding_model, device_ids=device_ids, gather_output=False) prediction_model = DataParallel(prediction_model, device_ids=device_ids, gather_output=False) embedding_model.eval() prediction_model.train() print(embedding_model) print(prediction_model) # Create memory bank. memory_banks = {} # start training train_iterator = train_loader.__iter__() iterator_index = 0 pbar = tqdm(range(curr_iter, config.train.max_iteration)) for curr_iter in pbar: # Check if the rest of datas is enough to iterate through; # otherwise, re-initiate the data iterator. if iterator_index + num_gpus >= len(train_loader): train_iterator = train_loader.__iter__() iterator_index = 0 # Feed-forward. image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu( train_iterator, gpu_ids) iterator_index += num_gpus # Generate embeddings, clustering and prototypes. with torch.no_grad(): embeddings = embedding_model(*zip(image_batch, label_batch)) # Compute loss. 
    outputs = prediction_model(*zip(embeddings, label_batch))
    outputs = scatter_gather.gather(outputs, gpu_ids[0])
    losses = []
    for k in ['sem_ann_loss']:
      loss = outputs.get(k, None)
      if loss is not None:
        outputs[k] = loss.mean()
        losses.append(outputs[k])
    loss = sum(losses)
    acc = outputs['accuracy'].mean()

    # Backward propagation.
    if config.train.lr_policy == 'step':
      lr = train_utils.lr_step(config.train.base_lr,
                               curr_iter,
                               config.train.decay_iterations,
                               config.train.warmup_iteration)
    else:
      lr = train_utils.lr_poly(config.train.base_lr,
                               curr_iter,
                               config.train.max_iteration,
                               config.train.warmup_iteration)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step(lr)

    # Snapshot the trained model.
    if ((curr_iter + 1) % config.train.snapshot_step == 0
        or curr_iter == config.train.max_iteration - 1):
      model_state_dict = {
          'embedding_model': embedding_model.module.state_dict(),
          'prediction_model': prediction_model.module.state_dict()
      }
      torch.save(model_state_dict, model_path_template.format(curr_iter))
      torch.save(optimizer.state_dict(),
                 optimizer_path_template.format(curr_iter))

    # Print loss in the progress bar.
    line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
        loss.item(), acc.item(), lr)
    pbar.set_description(line)
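
# The training loop delegates the learning-rate schedule to
# train_utils.lr_step / train_utils.lr_poly. As a reference, below is a
# minimal sketch of a poly schedule with linear warmup; the decay power
# and exact warmup behavior of train_utils may differ.
def _lr_poly_sketch(base_lr, curr_iter, max_iter, warmup_iter, power=0.9):
  """Hypothetical poly learning-rate schedule with linear warmup."""
  if warmup_iter > 0 and curr_iter < warmup_iter:
    # Linearly ramp up from 0 to base_lr during warmup.
    return base_lr * curr_iter / warmup_iter
  # Polynomially decay from base_lr toward 0 over the remaining iterations.
  return base_lr * (1 - curr_iter / max_iter) ** power

# E.g. halfway through warmup the rate is half of base_lr:
# _lr_poly_sketch(1e-3, 500, 10000, 1000) == 5e-4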
def main(): """Inference for semantic segmentation. """ # Retreve experiment configurations. args = parse_args('Inference for semantic segmentation.') # Create directories to save results. semantic_dir = os.path.join(args.save_dir, 'semantic_gray') semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color') os.makedirs(semantic_dir, exist_ok=True) os.makedirs(semantic_rgb_dir, exist_ok=True) # Create color map. color_map = vis_utils.load_color_map(config.dataset.color_map_path) color_map = color_map.numpy() # Create data loaders. test_dataset = ListDataset(data_dir=args.data_dir, data_list=args.data_list, img_mean=config.network.pixel_means, img_std=config.network.pixel_stds, size=None, random_crop=False, random_scale=False, random_mirror=False, training=False) test_image_paths = test_dataset.image_paths # Create models. if config.network.backbone_types == 'panoptic_pspnet_101': embedding_model = resnet_101_pspnet(config).cuda() elif config.network.backbone_types == 'panoptic_deeplab_101': embedding_model = resnet_101_deeplab(config).cuda() else: raise ValueError('Not support ' + config.network.backbone_types) prediction_model = softmax_classifier(config).cuda() embedding_model.eval() prediction_model.eval() # Load trained weights. model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth') save_iter = config.train.max_iteration - 1 embedding_model.load_state_dict(torch.load( model_path_template.format(save_iter))['embedding_model'], resume=True) prediction_model.load_state_dict( torch.load(model_path_template.format(save_iter))['prediction_model']) # Start inferencing. with torch.no_grad(): for data_index in tqdm(range(len(test_dataset))): # Image path. image_path = test_image_paths[data_index] base_name = os.path.basename(image_path).replace('.jpg', '.png') # Image resolution. image_batch, label_batch, _ = test_dataset[data_index] image_h, image_w = image_batch['image'].shape[-2:] batches = other_utils.create_image_pyramid( image_batch, label_batch, scales=[0.5, 0.75, 1, 1.25, 1.5], is_flip=True) semantic_logits = [] for image_batch, label_batch, data_info in batches: resize_image_h, resize_image_w = image_batch['image'].shape[ -2:] # Crop and Pad the input image. image_batch['image'] = transforms.resize_with_pad( image_batch['image'].transpose(1, 2, 0), config.test.crop_size, image_pad_value=0).transpose(2, 0, 1) image_batch['image'] = torch.FloatTensor( image_batch['image'][np.newaxis, ...]).cuda() pad_image_h, pad_image_w = image_batch['image'].shape[-2:] # Create the ending index of each patch. stride_h, stride_w = config.test.stride crop_h, crop_w = config.test.crop_size npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1 npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1 patch_ind_h = np.linspace(crop_h, pad_image_h, npatches_h, dtype=np.int32) patch_ind_w = np.linspace(crop_w, pad_image_w, npatches_w, dtype=np.int32) # Create place holder for full-resolution embeddings. semantic_logit = torch.FloatTensor( 1, config.dataset.num_classes, pad_image_h, pad_image_w).zero_().to("cuda:0") counts = torch.FloatTensor(1, 1, pad_image_h, pad_image_w).zero_().to("cuda:0") for ind_h in patch_ind_h: for ind_w in patch_ind_w: sh, eh = ind_h - crop_h, ind_h sw, ew = ind_w - crop_w, ind_w crop_image_batch = { k: v[:, :, sh:eh, sw:ew] for k, v in image_batch.items() } # Feed-forward. 
            crop_embeddings = embedding_model(crop_image_batch,
                                              resize_as_input=True)
            crop_outputs = prediction_model(crop_embeddings)
            semantic_logit[..., sh:eh, sw:ew] += (
                crop_outputs['semantic_logit'].to("cuda:0"))
            counts[..., sh:eh, sw:ew] += 1

        # Average overlapping patch logits and undo padding/resizing.
        semantic_logit /= counts
        semantic_logit = semantic_logit[..., :resize_image_h, :resize_image_w]
        semantic_logit = F.interpolate(semantic_logit,
                                       size=(image_h, image_w),
                                       mode='bilinear')
        semantic_logit = F.softmax(semantic_logit, dim=1)
        semantic_logit = semantic_logit.data.cpu().numpy().astype(np.float32)
        if data_info['is_flip']:
          semantic_logit = semantic_logit[..., ::-1]
        semantic_logits.append(semantic_logit)

      # Save semantic predictions.
      semantic_logits = np.concatenate(semantic_logits, axis=0)
      semantic_logits = np.sum(semantic_logits, axis=0)
      if semantic_logits is not None:
        semantic_pred = np.argmax(semantic_logits, axis=0).astype(np.uint8)

        semantic_pred_name = os.path.join(semantic_dir, base_name)
        Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

        semantic_pred_rgb = color_map[semantic_pred]
        semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
        Image.fromarray(semantic_pred_rgb, mode='RGB').save(
            semantic_pred_rgb_name)

      # Clean GPU memory cache to save more space.
      outputs = {}
      crop_embeddings = {}
      crop_outputs = {}
      torch.cuda.empty_cache()
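
# The sliding-window loop above tiles the padded image with overlapping
# crops and averages their logits. Below is a self-contained sketch of the
# same tiling arithmetic with illustrative values (not part of the original
# script), showing how the patch end indices are spaced.
import math
import numpy as np

def _patch_end_indices(pad_size, crop_size, stride):
  """Evenly spaced patch end indices covering [0, pad_size]."""
  npatches = math.ceil(1.0 * (pad_size - crop_size) / stride) + 1
  return np.linspace(crop_size, pad_size, npatches, dtype=np.int32)

# E.g. a 1000-pixel padded side with 512-pixel crops and stride 256 gives
# end indices [512, 756, 1000]; each patch spans [end - 512, end).
print(_patch_end_indices(1000, 512, 256))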
def main(): """Generate pseudo labels by softmax classifier. """ # Retreve experiment configurations. args = parse_args('Generate pseudo labels by softmax classifier.') # Create directories to save results. semantic_dir = os.path.join(args.save_dir, 'semantic_gray') semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color') # Create color map. color_map = vis_utils.load_color_map(config.dataset.color_map_path) color_map = color_map.numpy() # Create data loaders. test_dataset = ListDataset(data_dir=args.data_dir, data_list=args.data_list, img_mean=config.network.pixel_means, img_std=config.network.pixel_stds, size=None, random_crop=False, random_scale=False, random_mirror=False, training=False) test_image_paths = test_dataset.image_paths # Define CRF. postprocessor = DenseCRF( iter_max=args.crf_iter_max, pos_xy_std=args.crf_pos_xy_std, pos_w=args.crf_pos_w, bi_xy_std=args.crf_bi_xy_std, bi_rgb_std=args.crf_bi_rgb_std, bi_w=args.crf_bi_w, ) # Create models. if config.network.backbone_types == 'panoptic_pspnet_101': embedding_model = resnet_101_pspnet(config).cuda() elif config.network.backbone_types == 'panoptic_deeplab_101': embedding_model = resnet_101_deeplab(config).cuda() else: raise ValueError('Not support ' + config.network.backbone_types) prediction_model = softmax_classifier(config).cuda() embedding_model.eval() prediction_model.eval() # Load trained weights. model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth') save_iter = config.train.max_iteration - 1 embedding_model.load_state_dict(torch.load( model_path_template.format(save_iter))['embedding_model'], resume=True) prediction_model.load_state_dict( torch.load(model_path_template.format(save_iter))['prediction_model']) # Start inferencing. with torch.no_grad(): for data_index in tqdm(range(len(test_dataset))): # Image path. image_path = test_image_paths[data_index] base_name = os.path.basename(image_path).replace('.jpg', '.png') # Image resolution. original_image_batch, original_label_batch, _ = test_dataset[ data_index] image_h, image_w = original_image_batch['image'].shape[-2:] lab_tags = np.unique(original_label_batch['semantic_label']) lab_tags = lab_tags[lab_tags < config.dataset.num_classes] label_tags = np.zeros((config.dataset.num_classes, ), dtype=np.bool) label_tags[lab_tags] = True label_tags = torch.from_numpy(label_tags).cuda() # Image resolution. batches = other_utils.create_image_pyramid(original_image_batch, original_label_batch, scales=[0.75, 1], is_flip=True) affs = [] semantic_probs = [] for image_batch, label_batch, data_info in batches: resize_image_h, resize_image_w = image_batch['image'].shape[ -2:] # Crop and Pad the input image. 
        image_batch['image'] = transforms.resize_with_pad(
            image_batch['image'].transpose(1, 2, 0),
            config.test.crop_size,
            image_pad_value=0).transpose(2, 0, 1)
        image_batch['image'] = torch.FloatTensor(
            image_batch['image'][np.newaxis, ...]).cuda()
        pad_image_h, pad_image_w = image_batch['image'].shape[-2:]

        embeddings = embedding_model(image_batch, resize_as_input=True)
        outputs = prediction_model(embeddings)
        embs = embeddings['embedding'][:, :, :resize_image_h, :resize_image_w]
        semantic_logit = outputs['semantic_logit'][
            ..., :resize_image_h, :resize_image_w]
        if data_info['is_flip']:
          embs = torch.flip(embs, dims=[3])
          semantic_logit = torch.flip(semantic_logit, dims=[3])

        # Build pixel-pairwise affinities from normalized embeddings.
        embs = F.interpolate(embs,
                             size=(image_h // 8, image_w // 8),
                             mode='bilinear')
        embs = embs / torch.norm(embs, dim=1)
        embs_flat = embs.view(embs.shape[1], -1)
        aff = torch.matmul(embs_flat.t(), embs_flat).mul_(5).add_(-5).exp_()
        affs.append(aff)

        semantic_logit = F.interpolate(semantic_logit,
                                       size=(image_h // 8, image_w // 8),
                                       mode='bilinear')
        # (Disabled) per-scale softmax before merging:
        # semantic_prob = F.softmax(semantic_logit, dim=1)
        # semantic_probs.append(semantic_prob)
        semantic_probs.append(semantic_logit)

      # Merge logits over the pyramid and convert to probabilities.
      cat_semantic_probs = torch.cat(semantic_probs, dim=0)
      # (Disabled) max/min merging alternatives:
      # semantic_probs, _ = torch.max(cat_semantic_probs, dim=0)
      # semantic_probs[0] = torch.min(cat_semantic_probs[:, 0, :, :], dim=0)[0]
      semantic_probs = torch.mean(cat_semantic_probs, dim=0)
      semantic_probs = F.softmax(semantic_probs, dim=0)

      # Normalize the CAM per class by its maximum response.
      # NOTE: the number of classes (21) is hard-coded below.
      max_prob = torch.max(semantic_probs.view(21, -1), dim=1)[0]
      cam_full_arr = semantic_probs / max_prob.view(21, 1, 1)
      cam_shape = cam_full_arr.shape[-2:]
      # Zero out classes absent from the image-level tags.
      label_tags = (~label_tags).view(-1, 1, 1).expand(-1, cam_shape[0],
                                                       cam_shape[1])
      cam_full_arr = cam_full_arr.masked_fill(label_tags, 0)
      if TH is not None:
        cam_full_arr[0] = TH

      aff = torch.mean(torch.stack(affs, dim=0), dim=0)

      # Start random walk.
      aff_mat = aff**20
      trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
      for _ in range(WALK_STEPS):
        trans_mat = torch.matmul(trans_mat, trans_mat)

      cam_vec = cam_full_arr.view(21, -1)
      cam_rw = torch.matmul(cam_vec, trans_mat)
      cam_rw = cam_rw.view(21, cam_shape[0], cam_shape[1])

      cam_rw = cam_rw.data.cpu().numpy()
      cam_rw = cv2.resize(cam_rw.transpose(1, 2, 0),
                          dsize=(image_w, image_h),
                          interpolation=cv2.INTER_LINEAR)
      cam_rw_pred = np.argmax(cam_rw, axis=-1).astype(np.uint8)

      # (Disabled) DenseCRF refinement of the random-walk scores:
      # image = image_batch['image'].data.cpu().numpy().astype(np.float32)
      # image = image[0, :, :image_h, :image_w].transpose(1, 2, 0)
      # image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
      # image += np.reshape(config.network.pixel_means, (1, 1, 3))
      # image = image * 255
      # image = image.astype(np.uint8)
      # cam_rw = postprocessor(image, cam_rw.transpose(2, 0, 1))
      # cam_rw_pred = np.argmax(cam_rw, axis=0).astype(np.uint8)

      # Save semantic predictions.
      semantic_pred = cam_rw_pred
      semantic_pred_name = os.path.join(semantic_dir, base_name)
      if not os.path.isdir(os.path.dirname(semantic_pred_name)):
        os.makedirs(os.path.dirname(semantic_pred_name))
      Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

      semantic_pred_rgb = color_map[semantic_pred]
      semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
      if not os.path.isdir(os.path.dirname(semantic_pred_rgb_name)):
        os.makedirs(os.path.dirname(semantic_pred_rgb_name))
      Image.fromarray(semantic_pred_rgb, mode='RGB').save(
          semantic_pred_rgb_name)
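
# The random walk above propagates class activation scores through an
# affinity matrix derived from the learned embeddings. Below is a toy,
# self-contained sketch of the same propagation with synthetic sizes
# (not the original script) to make the matrix shapes concrete.
import torch

def _random_walk_sketch(cam, aff, power=20, walk_steps=4):
  """Propagate CAM scores (C x N) with a column-normalized affinity (N x N)."""
  aff_mat = aff ** power                                   # Sharpen affinities.
  trans_mat = aff_mat / aff_mat.sum(dim=0, keepdim=True)   # Column-stochastic.
  for _ in range(walk_steps):
    trans_mat = trans_mat @ trans_mat  # Squaring = 2**walk_steps walk steps.
  return cam @ trans_mat

cam = torch.rand(21, 64)        # 21 classes, 8x8 = 64 spatial positions.
aff = torch.rand(64, 64).exp()  # Positive synthetic affinities.
print(_random_walk_sketch(cam, aff).shape)  # torch.Size([21, 64])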
def main(): """Inference for semantic segmentation. """ # Retreve experiment configurations. args = parse_args('Inference for semantic segmentation.') # Create directories to save results. semantic_dir = os.path.join(args.save_dir, 'semantic_gray') semantic_rgb_dir = os.path.join(args.save_dir, 'semantic_color') if not os.path.isdir(semantic_dir): os.makedirs(semantic_dir) if not os.path.isdir(semantic_rgb_dir): os.makedirs(semantic_rgb_dir) # Create color map. color_map = vis_utils.load_color_map(config.dataset.color_map_path) color_map = color_map.numpy() # Create data loaders. test_dataset = ListDataset(data_dir=args.data_dir, data_list=args.data_list, img_mean=config.network.pixel_means, img_std=config.network.pixel_stds, size=None, random_crop=False, random_scale=False, random_mirror=False, training=False) test_image_paths = test_dataset.image_paths # Create models. if config.network.backbone_types == 'panoptic_pspnet_101': embedding_model = resnet_101_pspnet(config).cuda() elif config.network.backbone_types == 'panoptic_deeplab_101': embedding_model = resnet_101_deeplab(config).cuda() else: raise ValueError('Not support ' + config.network.backbone_types) prediction_model = softmax_classifier(config).cuda() embedding_model.eval() prediction_model.eval() # Load trained weights. model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth') save_iter = config.train.max_iteration - 1 embedding_model.load_state_dict(torch.load( model_path_template.format(save_iter))['embedding_model'], resume=True) prediction_model.load_state_dict( torch.load(model_path_template.format(save_iter))['prediction_model']) # Define CRF. postprocessor = DenseCRF( iter_max=args.crf_iter_max, pos_xy_std=args.crf_pos_xy_std, pos_w=args.crf_pos_w, bi_xy_std=args.crf_bi_xy_std, bi_rgb_std=args.crf_bi_rgb_std, bi_w=args.crf_bi_w, ) # Start inferencing. for data_index in range(len(test_dataset)): # Image path. image_path = test_image_paths[data_index] base_name = os.path.basename(image_path).replace('.jpg', '.png') # Image resolution. image_batch, _, _ = test_dataset[data_index] image_h, image_w = image_batch['image'].shape[-2:] # Resize the input image. if config.test.image_size > 0: image_batch['image'] = transforms.resize_with_interpolation( image_batch['image'].transpose(1, 2, 0), config.test.image_size, method='bilinear').transpose(2, 0, 1) resize_image_h, resize_image_w = image_batch['image'].shape[-2:] # Crop and Pad the input image. image_batch['image'] = transforms.resize_with_pad( image_batch['image'].transpose(1, 2, 0), config.test.crop_size, image_pad_value=0).transpose(2, 0, 1) image_batch['image'] = torch.FloatTensor( image_batch['image'][np.newaxis, ...]).cuda() pad_image_h, pad_image_w = image_batch['image'].shape[-2:] # Create the ending index of each patch. stride_h, stride_w = config.test.stride crop_h, crop_w = config.test.crop_size npatches_h = math.ceil(1.0 * (pad_image_h - crop_h) / stride_h) + 1 npatches_w = math.ceil(1.0 * (pad_image_w - crop_w) / stride_w) + 1 patch_ind_h = np.linspace(crop_h, pad_image_h, npatches_h, dtype=np.int32) patch_ind_w = np.linspace(crop_w, pad_image_w, npatches_w, dtype=np.int32) # Create place holder for full-resolution embeddings. outputs = {} with torch.no_grad(): for ind_h in patch_ind_h: for ind_w in patch_ind_w: sh, eh = ind_h - crop_h, ind_h sw, ew = ind_w - crop_w, ind_w crop_image_batch = { k: v[:, :, sh:eh, sw:ew] for k, v in image_batch.items() } # Feed-forward. 
          crop_embeddings = embedding_model(crop_image_batch,
                                            resize_as_input=True)
          crop_outputs = prediction_model(crop_embeddings)
          for name, crop_out in crop_outputs.items():
            if crop_out is not None:
              if name not in outputs.keys():
                output_shape = list(crop_out.shape)
                output_shape[-2:] = pad_image_h, pad_image_w
                outputs[name] = torch.zeros(output_shape,
                                            dtype=crop_out.dtype).cuda()
              outputs[name][..., sh:eh, sw:ew] += crop_out

    # Save semantic predictions.
    semantic_logits = outputs.get('semantic_logit', None)
    if semantic_logits is not None:
      semantic_prob = F.softmax(semantic_logits, dim=1)
      semantic_prob = semantic_prob[0, :, :resize_image_h, :resize_image_w]
      semantic_prob = semantic_prob.data.cpu().numpy().astype(np.float32)

      # DenseCRF post-processing: recover the uint8 RGB image by
      # undoing the normalization.
      image = image_batch['image'][0].data.cpu().numpy().astype(np.float32)
      image = image.transpose(1, 2, 0)
      image *= np.reshape(config.network.pixel_stds, (1, 1, 3))
      image += np.reshape(config.network.pixel_means, (1, 1, 3))
      image = image * 255
      image = image.astype(np.uint8)
      image = image[:resize_image_h, :resize_image_w, :]

      semantic_prob = postprocessor(image, semantic_prob)

      semantic_pred = np.argmax(semantic_prob, axis=0).astype(np.uint8)
      semantic_pred = cv2.resize(semantic_pred,
                                 (image_w, image_h),
                                 interpolation=cv2.INTER_NEAREST)

      semantic_pred_name = os.path.join(semantic_dir, base_name)
      Image.fromarray(semantic_pred, mode='L').save(semantic_pred_name)

      semantic_pred_rgb = color_map[semantic_pred]
      semantic_pred_rgb_name = os.path.join(semantic_rgb_dir, base_name)
      Image.fromarray(semantic_pred_rgb, mode='RGB').save(
          semantic_pred_rgb_name)

    # Clean GPU memory cache to save more space.
    outputs = {}
    crop_embeddings = {}
    crop_outputs = {}
    torch.cuda.empty_cache()
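
# The DenseCRF wrapper constructed above belongs to this repo; its
# constructor arguments suggest an implementation on top of pydensecrf.
# Below is a plausible sketch under that assumption (the repo's actual
# DenseCRF class may differ); default parameter values are illustrative.
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def densecrf_sketch(image, probs, iter_max=10,
                    pos_xy_std=1, pos_w=3,
                    bi_xy_std=67, bi_rgb_std=3, bi_w=4):
  """Refine class probabilities (C, H, W) with a dense CRF over `image`."""
  c, h, w = probs.shape
  crf = dcrf.DenseCRF2D(w, h, c)
  crf.setUnaryEnergy(unary_from_softmax(probs))
  # Smoothness kernel over pixel positions.
  crf.addPairwiseGaussian(sxy=pos_xy_std, compat=pos_w)
  # Appearance kernel over positions and RGB values.
  crf.addPairwiseBilateral(sxy=bi_xy_std, srgb=bi_rgb_std,
                           rgbim=np.ascontiguousarray(image), compat=bi_w)
  q = crf.inference(iter_max)
  return np.array(q).reshape((c, h, w))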
def main(): """Training for pixel-wise embeddings by pixel-segment contrastive learning loss. """ # Retreve experiment configurations. args = parse_args('Training for pixel-wise embeddings.') # Retrieve GPU informations. device_ids = [int(i) for i in config.gpus.split(',')] gpu_ids = [torch.device('cuda', i) for i in device_ids] num_gpus = len(gpu_ids) # Create logger and tensorboard writer. summary_writer = tensorboardX.SummaryWriter(logdir=args.snapshot_dir) color_map = vis_utils.load_color_map(config.dataset.color_map_path) model_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.pth') optimizer_path_template = os.path.join(args.snapshot_dir, 'model-{:d}.state.pth') # Create data loaders. train_dataset = ListTagDataset(data_dir=args.data_dir, data_list=args.data_list, img_mean=config.network.pixel_means, img_std=config.network.pixel_stds, size=config.train.crop_size, random_crop=config.train.random_crop, random_scale=config.train.random_scale, random_mirror=config.train.random_mirror, training=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * config.num_threads, drop_last=False, collate_fn=train_dataset.collate_fn) # Create models. if config.network.backbone_types == 'panoptic_pspnet_101': embedding_model = resnet_101_pspnet(config).cuda() elif config.network.backbone_types == 'panoptic_deeplab_101': embedding_model = resnet_101_deeplab(config).cuda() else: raise ValueError('Not support ' + config.network.backbone_types) if config.network.prediction_types == 'segsort': prediction_model = segsort(config).cuda() elif config.network.prediction_types == 'softmax_classifier': prediction_model = softmax_classifier(config).cuda() else: raise ValueError('Not support ' + config.network.prediction_types) # Use synchronize batchnorm. if config.network.use_syncbn: embedding_model = convert_model(embedding_model).cuda() prediction_model = convert_model(prediction_model).cuda() # Use customized optimizer and pass lr=1 to support different lr for # different weights. optimizer = SGD(embedding_model.get_params_lr() + prediction_model.get_params_lr(), lr=1, momentum=config.train.momentum, weight_decay=config.train.weight_decay) optimizer.zero_grad() # Load pre-trained weights. curr_iter = config.train.begin_iteration if config.train.resume: model_path = model_path_template.fromat(curr_iter) print('Resume training from {:s}'.format(model_path)) embedding_model.load_state_dict( torch.load(model_path)['embedding_model'], resume=True) prediction_model.load_state_dict( torch.load(model_path)['prediction_model'], resume=True) optimizer.load_state_dict( torch.load(optimizer_path_template.format(curr_iter))) elif config.network.pretrained: print('Loading pre-trained model: {:s}'.format( config.network.pretrained)) embedding_model.load_state_dict(torch.load(config.network.pretrained)) else: print('Training from scratch') # Distribute model weights to multi-gpus. embedding_model = DataParallel(embedding_model, device_ids=device_ids, gather_output=False) prediction_model = DataParallel(prediction_model, device_ids=device_ids, gather_output=False) if config.network.use_syncbn: patch_replication_callback(embedding_model) patch_replication_callback(prediction_model) for module in embedding_model.modules(): if isinstance(module, _BatchNorm) or isinstance(module, _ConvNd): print(module.training, module) print(embedding_model) print(prediction_model) # Create memory bank. 
  memory_banks = {}

  # Start training.
  train_iterator = train_loader.__iter__()
  iterator_index = 0
  pbar = tqdm(range(curr_iter, config.train.max_iteration))
  for curr_iter in pbar:
    # Check if the remaining data is enough for one more step;
    # otherwise, re-initiate the data iterator.
    if iterator_index + num_gpus >= len(train_loader):
      train_iterator = train_loader.__iter__()
      iterator_index = 0

    # Feed-forward.
    image_batch, label_batch = other_utils.prepare_datas_and_labels_mgpu(
        train_iterator, gpu_ids)
    iterator_index += num_gpus

    # Generate embeddings, clustering and prototypes.
    embeddings = embedding_model(*zip(image_batch, label_batch))

    # Synchronize cluster indices and compute prototypes.
    c_inds = [emb['cluster_index'] for emb in embeddings]
    cb_inds = [emb['cluster_batch_index'] for emb in embeddings]
    cs_labs = [emb['cluster_semantic_label'] for emb in embeddings]
    ci_labs = [emb['cluster_instance_label'] for emb in embeddings]
    c_embs = [emb['cluster_embedding'] for emb in embeddings]
    c_embs_with_loc = [emb['cluster_embedding_with_loc']
                       for emb in embeddings]
    (prototypes, prototypes_with_loc,
     prototype_semantic_labels, prototype_instance_labels,
     prototype_batch_indices, cluster_indices) = (
         model_utils.gather_clustering_and_update_prototypes(
             c_embs, c_embs_with_loc,
             c_inds, cb_inds,
             cs_labs, ci_labs,
             'cuda:{:d}'.format(num_gpus - 1)))

    for i in range(len(label_batch)):
      label_batch[i]['prototype'] = prototypes[i]
      label_batch[i]['prototype_with_loc'] = prototypes_with_loc[i]
      label_batch[i]['prototype_semantic_label'] = (
          prototype_semantic_labels[i])
      label_batch[i]['prototype_instance_label'] = (
          prototype_instance_labels[i])
      label_batch[i]['prototype_batch_index'] = prototype_batch_indices[i]
      embeddings[i]['cluster_index'] = cluster_indices[i]

    semantic_tags = model_utils.gather_and_update_datas(
        [lab['semantic_tag'] for lab in label_batch],
        'cuda:{:d}'.format(num_gpus - 1))
    for i in range(len(label_batch)):
      label_batch[i]['semantic_tag'] = semantic_tags[i]
      label_batch[i]['prototype_semantic_tag'] = torch.index_select(
          semantic_tags[i], 0, label_batch[i]['prototype_batch_index'])

    # Add memory banks to label batch.
    for k in memory_banks.keys():
      for i in range(len(label_batch)):
        assert (label_batch[i].get(k, None) is None)
        label_batch[i][k] = [m.to(gpu_ids[i]) for m in memory_banks[k]]

    # Compute loss.
    outputs = prediction_model(*zip(embeddings, label_batch))
    outputs = scatter_gather.gather(outputs, gpu_ids[0])
    losses = []
    for k in ['sem_ann_loss', 'sem_occ_loss',
              'img_sim_loss', 'feat_aff_loss']:
      loss = outputs.get(k, None)
      if loss is not None:
        outputs[k] = loss.mean()
        losses.append(outputs[k])
    loss = sum(losses)
    acc = outputs['accuracy'].mean()

    # Write to tensorboard summary.
    writer = (summary_writer
              if curr_iter % config.train.tensorboard_step == 0 else None)
    if writer is not None:
      summary_vis = []
      summary_val = {}
      # Gather labels to cpu.
      cpu_label_batch = scatter_gather.gather(label_batch, -1)
      summary_vis.append(vis_utils.convert_label_to_color(
          cpu_label_batch['semantic_label'], color_map))
      summary_vis.append(vis_utils.convert_label_to_color(
          cpu_label_batch['instance_label'], color_map))

      # Gather outputs to cpu.
      vis_names = ['embedding']
      cpu_embeddings = scatter_gather.gather(
          [{k: emb.get(k, None) for k in vis_names} for emb in embeddings],
          -1)
      for vis_name in vis_names:
        if cpu_embeddings.get(vis_name, None) is not None:
          summary_vis.append(
              vis_utils.embedding_to_rgb(cpu_embeddings[vis_name], 'pca'))

      val_names = ['sem_ann_loss', 'sem_occ_loss', 'img_sim_loss',
                   'feat_aff_loss', 'accuracy']
      for val_name in val_names:
        if outputs.get(val_name, None) is not None:
          summary_val[val_name] = outputs[val_name].mean().to('cpu')

      vis_utils.write_image_to_tensorboard(summary_writer,
                                           summary_vis,
                                           summary_vis[-1].shape[-2:],
                                           curr_iter)
      vis_utils.write_scalars_to_tensorboard(summary_writer,
                                             summary_val,
                                             curr_iter)

    # Backward propagation.
    if config.train.lr_policy == 'step':
      lr = train_utils.lr_step(config.train.base_lr,
                               curr_iter,
                               config.train.decay_iterations,
                               config.train.warmup_iteration)
    else:
      lr = train_utils.lr_poly(config.train.base_lr,
                               curr_iter,
                               config.train.max_iteration,
                               config.train.warmup_iteration)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step(lr)

    # Update memory banks.
    with torch.no_grad():
      for k in label_batch[0].keys():
        if 'prototype' in k and 'memory' not in k:
          memory = label_batch[0][k].clone().detach()
          memory_key = 'memory_' + k
          if memory_key not in memory_banks.keys():
            memory_banks[memory_key] = []
          memory_banks[memory_key].append(memory)
          if len(memory_banks[memory_key]) > config.train.memory_bank_size:
            memory_banks[memory_key] = memory_banks[memory_key][1:]

      # Update batch indices of memorized prototypes.
      for k in ['memory_prototype_batch_index']:
        memory_labels = memory_banks.get(k, None)
        if memory_labels is not None:
          for i, memory_label in enumerate(memory_labels):
            memory_labels[i] += config.train.batch_size * num_gpus

    # Snapshot the trained model.
    if ((curr_iter + 1) % config.train.snapshot_step == 0
        or curr_iter == config.train.max_iteration - 1):
      model_state_dict = {
          'embedding_model': embedding_model.module.state_dict(),
          'prediction_model': prediction_model.module.state_dict()
      }
      torch.save(model_state_dict, model_path_template.format(curr_iter))
      torch.save(optimizer.state_dict(),
                 optimizer_path_template.format(curr_iter))

    # Print loss in the progress bar.
    line = 'loss = {:.3f}, acc = {:.3f}, lr = {:.6f}'.format(
        loss.item(), acc.item(), lr)
    pbar.set_description(line)
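
# The memory-bank update above keeps a bounded FIFO list of recent prototype
# tensors per key. A minimal standalone sketch of the same bookkeeping
# (illustrative key and sizes; not part of the original script):
import torch

def _update_memory_bank(memory_banks, key, tensor, bank_size):
  """Append `tensor` to the FIFO list at `key`, evicting the oldest entry."""
  bank = memory_banks.setdefault('memory_' + key, [])
  bank.append(tensor.clone().detach())
  if len(bank) > bank_size:
    bank.pop(0)  # Drop the oldest snapshot.

banks = {}
for step in range(5):
  _update_memory_bank(banks, 'prototype', torch.randn(10, 32), bank_size=3)
print(len(banks['memory_prototype']))  # 3: only the newest entries remain.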