def test_region_grounding_model(config):
    """Smoke-test RegionGroundingModel on one batch: forward pass, final
    loss, and retrieval evaluation, printing shapes/values along the way."""
    db = vg(config, 'test')
    loaddb = region_loader(db)
    # Batch size is 3x the configured one; presumably to mirror the layout
    # used elsewhere where extra samples act as source/negative slices --
    # TODO confirm against the other test functions in this file.
    loader = DataLoader(loaddb,
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)
    net = RegionGroundingModel(config)
    if config.pretrained is not None:
        pretrained_path = osp.join(config.data_dir,
                                   'caches/region_grounding_ckpts',
                                   config.pretrained + '.pkl')
        # map_location keeps the load on CPU regardless of where the
        # checkpoint was saved.
        states = torch.load(pretrained_path,
                            map_location=lambda storage, loc: storage)
        # strict=False: tolerate missing/unexpected keys in the checkpoint.
        net.load_state_dict(states['state_dict'], strict=False)
    net.train()
    for name, param in net.named_parameters():
        print(name, param.size())
    for cnt, batched in enumerate(loader):
        scene_inds = batched['scene_inds'].long()
        sent_inds = batched['sent_inds'].long()
        sent_msks = batched['sent_msks'].long()
        region_feats = batched['region_feats'].float()
        region_clses = batched['region_clses'].long()
        region_masks = batched['region_masks'].float()
        # The three None arguments are the source-region inputs, unused by
        # the grounding model in this test.
        img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
            net(scene_inds, sent_inds, sent_msks,
                None, None, None,
                region_feats, region_clses, region_masks,
                config.explore_mode)
        if config.instance_dim > 1:
            print(sample_indices[0])
        # print('sample_logits', sample_logits.size())
        # print('sample_indices', sample_indices.size())
        # All text positions are treated as valid when computing the loss.
        txt_masks = txt_feats.new_ones(txt_feats.size(0), txt_feats.size(1))
        losses = net.final_loss(img_feats, masked_feats, region_masks,
                                txt_feats, txt_masks,
                                sample_logits, sample_indices)
        print('losses', losses.size(), torch.mean(losses))
        if config.subspace_alignment_mode > 0:
            # Evaluate on subspace-masked features when alignment is enabled.
            metrics, cache_results = net.evaluate(masked_feats, region_masks,
                                                  txt_feats)
        else:
            metrics, cache_results = net.evaluate(img_feats, region_masks,
                                                  txt_feats)
        print('metrics', metrics)
        print('txt_feats', txt_feats.size())
        print('img_feats', img_feats.size())
        break  # one batch is enough for a smoke test
def test_region_model(config):
    """Smoke-test RegionModel on one batch and print output shapes/timing."""
    db = vg(config, 'test')
    loaddb = region_loader(db)
    # Load a 3x-sized batch so it can be sliced into target and source parts.
    loader = DataLoader(loaddb,
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)
    net = RegionModel(config)
    net.train()
    for name, param in net.named_parameters():
        print(name, param.size())
    for cnt, batched in enumerate(loader):
        start = time()
        # First batch_size entries: target examples.
        scene_inds = batched['scene_inds'].long()[:config.batch_size]
        sent_inds = batched['sent_inds'].long()[:config.batch_size]
        sent_msks = batched['sent_msks'].long()[:config.batch_size]
        region_feats = batched['region_feats'].float()[:config.batch_size]
        region_clses = batched['region_clses'].long()[:config.batch_size]
        region_masks = batched['region_masks'].float()[:config.batch_size]
        # Second batch_size entries: source regions for the reference stream.
        src_region_feats = batched['region_feats'].float(
        )[config.batch_size:2 * config.batch_size]
        src_region_clses = batched['region_clses'].long(
        )[config.batch_size:2 * config.batch_size]
        src_region_masks = batched['region_masks'].float(
        )[config.batch_size:2 * config.batch_size]
        img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
            net(scene_inds, sent_inds, sent_msks,
                src_region_feats, src_region_clses, src_region_masks,
                region_feats, region_clses, region_masks,
                config.explore_mode)
        print('img_feats', img_feats.size())
        print('txt_feats', txt_feats.size())
        if config.subspace_alignment_mode > 0:
            # Subspace outputs only exist when alignment is enabled.
            print('masked_feats', masked_feats.size())
            print('subspace_masks', subspace_masks.size())
        if config.instance_dim > 1:
            print('sample_logits', sample_logits.size())
            print('sample_indices', sample_indices.size())
        print('time:', time() - start)
        break  # one batch is enough for a smoke test
def check_region_clses(config):
    """Scan the training set and report the min/max region class index.

    Useful for sanity-checking that all class indices fit within the
    class-embedding table before training.
    """
    db = vg(config, 'train')
    loaddb = region_loader(db)
    loader = DataLoader(loaddb,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)
    min_index = 1000000
    max_index = -1
    for cnt, batched in enumerate(loader):
        region_clses = batched['region_clses'].long()
        min_index = min(min_index, torch.min(region_clses).item())
        max_index = max(max_index, torch.max(region_clses).item())
        # BUG FIX: the original `if cnt % 1000:` printed on every iteration
        # EXCEPT multiples of 1000; log progress once every 1000 steps.
        if cnt % 1000 == 0:
            print('iter:', cnt)
    print('min_index', min_index)
    print('max_index', max_index)
def test_grounding_loss(config):
    """Smoke-test GroundingLoss: compare the two batch-similarity
    implementations (their 'diff' should be ~0) and run forward_loss once."""
    db = vg(config, 'test')
    loaddb = region_loader(db)
    # 3x-sized batch so it can be sliced into target and source parts.
    loader = DataLoader(loaddb,
                        batch_size=3 * config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)
    net = RegionModel(config)
    criterion = GroundingLoss(config)
    for cnt, batched in enumerate(loader):
        # Target slice.
        scene_inds = batched['scene_inds'].long()[:config.batch_size]
        sent_inds = batched['sent_inds'].long()[:config.batch_size]
        sent_msks = batched['sent_msks'].long()[:config.batch_size]
        region_feats = batched['region_feats'].float()[:config.batch_size]
        region_clses = batched['region_clses'].long()[:config.batch_size]
        region_masks = batched['region_masks'].float()[:config.batch_size]
        # Source slice for the reference-image stream.
        src_region_feats = batched['region_feats'].float(
        )[config.batch_size:2 * config.batch_size]
        src_region_clses = batched['region_clses'].long(
        )[config.batch_size:2 * config.batch_size]
        src_region_masks = batched['region_masks'].float(
        )[config.batch_size:2 * config.batch_size]
        img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
            net(scene_inds, sent_inds, sent_msks,
                src_region_feats, src_region_clses, src_region_masks,
                region_feats, region_clses, region_masks,
                config.explore_mode)
        # NOTE: deliberately overrides the network's masked output so the
        # loss is tested on the raw image features.
        masked_feats = img_feats
        sim1 = criterion.compute_batch_mutual_similarity(
            masked_feats, region_masks, txt_feats)
        sim2 = criterion.debug_compute_batch_mutual_similarity(
            masked_feats, region_masks, txt_feats)
        print('sim1', sim1.size())
        print('sim2', sim2.size())
        # Both implementations should agree.
        print('diff', torch.sum(torch.abs(sim1 - sim2)))
        # All text positions treated as valid for the loss.
        txt_masks = txt_feats.new_ones(txt_feats.size(0), txt_feats.size(1))
        losses = criterion.forward_loss(masked_feats, region_masks, txt_feats,
                                        txt_masks, config.loss_reduction_mode)
        print('losses', losses.size())
        break  # one batch is enough for a smoke test
def test_text_encoder(config):
    """Run TextEncoder on a single batch and print the output shapes."""
    dataset = vg(config, 'test')
    wrapped = region_loader(dataset)
    batches = DataLoader(wrapped,
                         batch_size=3 * config.batch_size,
                         shuffle=True,
                         num_workers=config.num_workers,
                         collate_fn=region_collate_fn)
    encoder = TextEncoder(config)
    for sample in batches:
        word_inds = sample['sent_inds'].long()
        word_msks = sample['sent_msks'].float()
        bsize, slen, fsize = word_inds.size()
        print('sent_inds', word_inds.size())
        print('sent_msks', word_msks.size())
        # Fold the (batch, turn) dimensions together before encoding.
        flat_inds = word_inds.view(bsize * slen, fsize)
        flat_msks = word_msks.view(bsize * slen, fsize)
        f1, f2, h = encoder(flat_inds, flat_msks)
        print(f1.size(), f2.size(), h.size())
        break
def test_region_encoder(config):
    """Run RegionEncoder on a single batch and print the output shapes."""
    dataset = vg(config, 'test')
    wrapped = region_loader(dataset)
    batches = DataLoader(wrapped,
                         batch_size=3 * config.batch_size,
                         shuffle=True,
                         num_workers=config.num_workers,
                         collate_fn=region_collate_fn)
    encoder = RegionEncoder(config)
    for sample in batches:
        feats = sample['region_feats'].float()
        clses = sample['region_clses'].long()
        print('region_feats', feats.size())
        print('region_clses', clses.size())
        img_feats, masked_feats, mm = encoder(feats, clses)
        print('img_feats', img_feats.size())
        if config.subspace_alignment_mode > 0:
            # Subspace outputs are only meaningful when alignment is enabled.
            print('masked_feats', masked_feats.size())
            print('mm', mm.size())
        break
def test_region_loader(config):
    """Visual sanity check of the region data loader.

    For a few batches, renders per-turn region layouts over the (squared)
    source image and saves one figure per example under
    <model_dir>/test_region_loader/<image_index>.png.
    """
    db = vg(config, 'train')
    # db = coco(config, 'train')
    loaddb = region_loader(db)
    loader = DataLoader(loaddb,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        collate_fn=region_collate_fn)
    output_dir = osp.join(config.model_dir, 'test_region_loader')
    maybe_create(output_dir)
    start = time()
    plt.switch_backend('agg')  # headless rendering
    for cnt, batched in enumerate(loader):
        print('scene_inds', batched['scene_inds'])
        sent_inds = batched['sent_inds'].long()
        sent_msks = batched['sent_msks'].long()
        widths = batched['widths']
        heights = batched['heights']
        captions = batched['captions']
        region_boxes = batched['region_boxes'].float()
        region_feats = batched['region_feats'].float()
        region_clses = batched['region_clses'].long()
        region_masks = batched['region_masks'].long()
        print('sent_inds', sent_inds.size())
        print('sent_msks', sent_msks.size())
        print('region_boxes', region_boxes.size())
        print('region_feats', region_feats.size())
        print('region_clses', region_clses.size())
        print('region_masks', region_masks.size())
        print('clses', torch.min(region_clses), torch.max(region_clses))
        print('widths', widths)
        print('heights', heights)
        for i in range(len(sent_inds)):
            # print('####')
            # print(len(captions), len(captions[0]))
            entry = {}
            image_index = batched['image_inds'][i]
            entry['width'] = widths[i]
            entry['height'] = heights[i]
            # Only the first `nr` region rows are valid for this example.
            nr = torch.sum(region_masks[i])
            entry['region_boxes'] = xyxys_to_xywhs(
                region_boxes[i, :nr].cpu().data.numpy())
            color = cv2.imread(db.color_path_from_index(image_index),
                               cv2.IMREAD_COLOR)
            color, _, _ = create_squared_image(color)
            out_path = osp.join(output_dir, '%d.png' % image_index)
            # [:, :, ::-1]: BGR (OpenCV) -> RGB for rendering.
            layouts = db.render_regions_as_output(
                entry,
                bg=cv2.resize(
                    color,
                    (config.visu_size[0], config.visu_size[0]))[:, :, ::-1])
            fig = plt.figure(figsize=(32, 16))
            # 3x5 grid: up to 14 per-turn layouts plus the source image.
            for j in range(min(14, len(layouts))):
                plt.subplot(3, 5, j + 1)
                if j < config.max_turns:
                    # Title: caption plus the raw token ids and mask values.
                    plt.title(
                        captions[i][j] + '\n' +
                        ' '.join([str(x.data.item())
                                  for x in sent_inds[i, j]]) + '\n'
                        + ' '.join([str(x.data.item())
                                    for x in sent_msks[i, j]]))
                plt.imshow(layouts[j].astype(np.uint8))
                plt.axis('off')
            # Last cell: the original image (BGR -> RGB).
            plt.subplot(3, 5, 15)
            plt.imshow(color[:, :, ::-1])
            plt.axis('off')
            fig.savefig(out_path, bbox_inches='tight')
            plt.close(fig)
        print('------------------')
        if cnt == 2:
            break  # three batches are enough for a visual check
    print("Time", time() - start)
def test(self, test_db):
    """Run inference over the full test set, dump image features to disk,
    evaluate retrieval metrics, and save metrics/caches under model_dir.

    Returns:
        (losses, metrics, caches_results): per-batch loss array, the
        evaluation metrics dict, and the cached retrieval results.
    """
    ##################################################################
    ## LOG
    ##################################################################
    logz.configure_output_dir(self.cfg.model_dir)
    logz.save_config(self.cfg)

    start = time()
    test_loaddb = region_loader(test_db)
    test_loader = DataLoader(test_loaddb,
                             batch_size=self.cfg.batch_size,
                             shuffle=False,
                             num_workers=self.cfg.num_workers,
                             collate_fn=region_collate_fn)
    # During RL finetuning sample with mode 0; otherwise use the configured
    # exploration mode.
    sample_mode = 0 if self.cfg.rl_finetune > 0 else self.cfg.explore_mode
    all_txt_feats, all_img_feats, all_img_masks, losses = [], [], [], []
    self.net.eval()
    for cnt, batched in enumerate(test_loader):
        ##################################################################
        ## Batched data
        ##################################################################
        scene_inds, sent_inds, sent_msks, region_feats, region_masks, region_clses = self.batch_data(
            batched)
        ##################################################################
        ## Inference one step
        ##################################################################
        with torch.no_grad():
            # Source-region inputs are None at test time.
            img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
                self.net(scene_inds, sent_inds, sent_msks,
                         None, None, None,
                         region_feats, region_clses, region_masks,
                         sample_mode=sample_mode)
            # All text positions are treated as valid for the loss.
            txt_masks = txt_feats.new_ones(txt_feats.size(0),
                                           txt_feats.size(1))
            batch_losses = self.net.final_loss(img_feats, masked_feats,
                                               region_masks, txt_feats,
                                               txt_masks, sample_logits,
                                               sample_indices)
            loss = torch.sum(torch.mean(batch_losses, -1))
        losses.append(loss.cpu().data.item())
        all_txt_feats.append(txt_feats)
        all_img_masks.append(region_masks)
        if self.cfg.subspace_alignment_mode > 0:
            # Evaluate on subspace-masked features when alignment is on.
            all_img_feats.append(masked_feats)
        else:
            all_img_feats.append(img_feats)
        ##################################################################
        ## Print info
        ##################################################################
        if cnt % self.cfg.log_per_steps == 0:
            print('Iter %07d:' % (cnt))
            tmp_losses = np.stack(losses, 0)
            print('mean loss: ', np.mean(tmp_losses))
            print('-------------------------')
            torch.cuda.empty_cache()

    losses = np.array(losses)
    # Concatenate the per-batch features over the whole test set.
    all_img_feats = torch.cat(all_img_feats, 0)
    all_img_masks = torch.cat(all_img_masks, 0)
    all_txt_feats = torch.cat(all_txt_feats, 0)
    all_txt_masks = all_txt_feats.new_ones(all_txt_feats.size(0),
                                           all_txt_feats.size(1))
    # print('all_img_feats', all_img_feats.size())
    all_img_feats_np = all_img_feats.cpu().data.numpy()
    all_img_masks_np = all_img_masks.cpu().data.numpy()
    # Dump the image features so they can be reused without re-running the net.
    with open(
            osp.join(self.cfg.model_dir,
                     'img_features_%d.pkl' % self.cfg.n_feature_dim),
            'wb') as fid:
        pickle.dump({
            'feats': all_img_feats_np,
            'masks': all_img_masks_np
        }, fid, pickle.HIGHEST_PROTOCOL)
    ##################################################################
    ## Evaluation
    ##################################################################
    print('Evaluating the per-turn performance, may take a while.')
    metrics, caches_results = self.net.evaluate(all_img_feats, all_img_masks,
                                                all_txt_feats)
    with open(osp.join(self.cfg.model_dir, 'test_metrics.json'), 'w') as fp:
        json.dump(metrics, fp, indent=4, sort_keys=True)
    with open(osp.join(self.cfg.model_dir, 'test_caches.pkl'), 'wb') as fid:
        pickle.dump(caches_results, fid, pickle.HIGHEST_PROTOCOL)
    visualize(self.cfg.exp_name, metrics,
              osp.join(self.cfg.model_dir, 'evaluation.png'))
    return losses, metrics, caches_results
def train(self, train_db, val_db, test_db):
    """Main training loop.

    Trains and validates once per epoch, logs metrics via logz, and saves a
    checkpoint when the validation criterion improves: lowest validation
    loss in the default mode, highest mean recall (R@1/5/10) when RL
    finetuning or COCO mode is active.
    """
    ##################################################################
    ## LOG
    ##################################################################
    logz.configure_output_dir(self.cfg.model_dir)
    logz.save_config(self.cfg)

    ##################################################################
    ## Main loop
    ##################################################################
    start = time()
    min_val_loss = 1000.0
    max_val_recall = -1.0
    train_loaddb = region_loader(train_db)
    val_loaddb = region_loader(val_db)  #TODO
    train_loader = DataLoader(train_loaddb,
                              batch_size=self.cfg.batch_size,
                              shuffle=True,
                              num_workers=self.cfg.num_workers,
                              collate_fn=region_collate_fn)
    val_loader = DataLoader(val_loaddb,
                            batch_size=self.cfg.batch_size,
                            shuffle=False,
                            num_workers=self.cfg.num_workers,
                            collate_fn=region_collate_fn)

    for epoch in range(self.epoch, self.cfg.n_epochs):
        ##################################################################
        ## Training
        ##################################################################
        if self.cfg.coco_mode >= 0:
            # Pick a random turn to train on this epoch (COCO mode).
            self.cfg.coco_mode = np.random.randint(0, self.cfg.max_turns)
        torch.cuda.empty_cache()
        train_losses = self.train_epoch(train_loaddb, train_loader, epoch)

        ##################################################################
        ## Validation
        ##################################################################
        if self.cfg.coco_mode >= 0:
            # Always validate on the first turn for comparability.
            self.cfg.coco_mode = 0
        torch.cuda.empty_cache()
        val_losses, val_metrics, caches_results = self.validate_epoch(
            val_loaddb, val_loader, epoch)

        #################################################################
        # Logging
        #################################################################
        # update optim scheduler
        current_val_loss = np.mean(val_losses)
        self.optimizer.update(current_val_loss, epoch)
        logz.log_tabular("Time", time() - start)
        logz.log_tabular("Iteration", epoch)
        logz.log_tabular("TrainAverageLoss", np.mean(train_losses))
        logz.log_tabular("ValAverageLoss", current_val_loss)
        # Average the 5 retrieval numbers (R@1, R@5, R@10, medr, meanr)
        # over all entries of val_metrics.
        mmm = np.zeros((5, ), dtype=np.float64)
        for k, v in val_metrics.items():
            mmm = mmm + np.array(v)
        mmm /= len(val_metrics)
        logz.log_tabular("t2i_R1", mmm[0])
        logz.log_tabular("t2i_R5", mmm[1])
        logz.log_tabular("t2i_R10", mmm[2])
        logz.log_tabular("t2i_medr", mmm[3])
        logz.log_tabular("t2i_meanr", mmm[4])
        logz.dump_tabular()
        current_val_recall = np.mean(mmm[:3])

        ##################################################################
        ## Checkpoint
        ##################################################################
        if self.cfg.rl_finetune == 0 and self.cfg.coco_mode < 0:
            # Default mode: checkpoint on best (lowest) validation loss.
            if min_val_loss > current_val_loss:
                min_val_loss = current_val_loss
                self.save_checkpoint(epoch)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_metrics_%d.json' % epoch), 'w') as fp:
                    json.dump(val_metrics, fp, indent=4, sort_keys=True)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_top5_inds_%d.pkl' % epoch),
                        'wb') as fid:
                    pickle.dump(caches_results, fid,
                                pickle.HIGHEST_PROTOCOL)
        else:
            # RL-finetune / COCO mode: checkpoint on best mean recall.
            if max_val_recall < current_val_recall:
                max_val_recall = current_val_recall
                self.save_checkpoint(epoch)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_metrics_%d.json' % epoch), 'w') as fp:
                    json.dump(val_metrics, fp, indent=4, sort_keys=True)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_top5_inds_%d.pkl' % epoch),
                        'wb') as fid:
                    pickle.dump(caches_results, fid,
                                pickle.HIGHEST_PROTOCOL)