def evaluate(self, loader, use_tqdm=False, single_batch=False):
    """Run the model over ``loader`` in eval mode and accumulate metrics.

    The model is switched to eval mode and autograd is disabled for the
    duration of the loop; training mode is restored afterwards, even if a
    forward pass raises.

    Args:
        loader: Iterable of batches; each batch is passed to
            ``self._forward_pass``.
        use_tqdm: Show a tqdm progress bar when True.
        single_batch: When ``True`` (identity check), stop after the first
            batch.

    Returns:
        tuple: ``(report, meter)`` where ``report`` is what
        ``self._forward_pass`` returned for the last processed batch
        (``None`` if ``loader`` yielded nothing) and ``meter`` is a fresh
        ``Meter`` updated via ``self._update_meter`` for every batch.
    """
    meter = Meter()
    report = None  # stays None if the loader yields no batches
    self.model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)
                if single_batch is True:
                    break
    finally:
        # Always hand the model back in training mode.
        self.model.train()
    return report, meter
def evaluate(self, loader, use_tqdm=False, single_batch=False):
    """Run the model over ``loader`` under ``torch.no_grad`` in eval mode.

    Training mode is restored afterwards, including when a forward pass
    raises.

    Args:
        loader: Iterable of batches; each batch is passed to
            ``self._forward_pass``.
        use_tqdm: Show a tqdm progress bar when True.
        single_batch: When ``True`` (identity check), stop after the first
            batch.

    Returns:
        tuple: ``(report, meter)`` where ``report`` is the value returned by
        ``self._forward_pass`` for the last processed batch (``None`` if
        ``loader`` yielded nothing) and ``meter`` is a fresh ``Meter`` updated
        via ``self._update_meter`` for every batch.
    """
    meter = Meter()
    report = None  # stays None if the loader yields no batches
    self.model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)
                if single_batch is True:
                    break
    finally:
        # Restore training mode on both the normal and the exception path.
        self.model.train()
    return report, meter
def evaluate(self, loader, use_tqdm=False, single_batch=False):
    """Run the model over ``loader`` under ``torch.no_grad`` in eval mode.

    Training mode is restored afterwards, including when a forward pass
    raises.

    Args:
        loader: Iterable of batches; each batch is passed to
            ``self._forward_pass``.
        use_tqdm: Show a tqdm progress bar when True.
        single_batch: When ``True`` (identity check), stop after the first
            batch.

    Returns:
        tuple: ``(report, meter)`` where ``report`` is the value returned by
        ``self._forward_pass`` for the last processed batch (``None`` if
        ``loader`` yielded nothing) and ``meter`` is a fresh ``Meter`` updated
        via ``self._update_meter`` for every batch.
    """
    meter = Meter()
    report = None  # stays None if the loader yields no batches
    # Honour ``use_tqdm``. NOTE(review): in distributed runs you may also want
    # to hide the bar on non-main ranks (e.g. ``or not is_main_process()``) —
    # confirm that helper is importable here before adding it.
    disable_tqdm = not use_tqdm
    self.model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)
                if single_batch is True:
                    break
    finally:
        # Restore training mode on both the normal and the exception path.
        self.model.train()
    return report, meter
def load_extras(self):
    """Set up checkpointing, metric tracking, early stopping and LR scheduling.

    Settings come from ``self.config.training_parameters``. The iteration and
    epoch counters are zeroed *before* ``self.checkpoint.load_state_dict()``
    runs, and the learning-rate scheduler is only built when
    ``training_parameters.lr_scheduler is True``.
    """
    self.checkpoint = Checkpoint(self)
    self.meter = Meter()

    params = self.config.training_parameters
    self.training_parameters = params

    # Only needed to construct EarlyStopping below.
    monitored_metric = params.monitored_metric
    metric_minimize = params.metric_minimize
    should_early_stop = params.should_early_stop
    patience = params.patience

    self.log_interval = params.log_interval
    self.snapshot_interval = params.snapshot_interval
    self.test_interval = params.test_interval
    self.max_iterations = params.max_iterations
    self.should_clip_gradients = params.clip_gradients
    self.max_epochs = params.max_epochs

    self.early_stopping = EarlyStopping(
        self.model,
        self.checkpoint,
        monitored_metric,
        patience=patience,
        minimize=metric_minimize,
        should_stop=should_early_stop,
    )

    # Keep this order: counters are reset first, then the checkpoint is loaded
    # (presumably so a resumed run can restore them — TODO confirm).
    self.current_epoch = 0
    self.current_iteration = 0
    self.checkpoint.load_state_dict()

    # Read through ``self`` after loading, exactly as before.
    self.not_debug = self.training_parameters.logger_level != "debug"

    self.lr_scheduler = None
    # TODO: Allow custom scheduler
    if self.training_parameters.lr_scheduler is True:
        def lr_lambda(step):
            # ``self.config`` is looked up at call time, not captured here.
            return lr_lambda_update(step, self.config)

        self.lr_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lr_lambda
        )
def evaluate_full(self, loader, use_tqdm=False, metrics=None):
    """Evaluate on the whole of ``loader`` at once and drive early stopping.

    Unlike per-batch evaluation, every batch's ``report.scores`` and
    ``report.targets`` are concatenated along dim 0 and the metrics are
    computed a single time over the full set. The result is written with
    ``self.writer.write``.

    The model is put in eval mode under ``torch.no_grad`` and is always
    switched back to training mode before returning (including on the
    ``test`` early return and on exceptions).

    Args:
        loader: Iterable of batches; each batch is passed to
            ``self._forward_pass``.
        use_tqdm: Show a tqdm progress bar when True.
        metrics: List of metric names handed to ``Metrics``. Defaults to
            ``["accuracy"]``.

    Returns:
        ``None`` when the last batch's ``report.dataset_type`` is ``"test"``
        (no early-stopping check is made). Otherwise ``True`` if
        ``self.early_stopping`` returned ``True`` (after writing
        ``"Early stopping activated"``), else ``False``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if metrics is None:
        metrics = ["accuracy"]

    meter = Meter()
    self.model.eval()
    try:
        with torch.no_grad():
            all_scores = []
            all_targets = []
            report = None
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                all_scores.append(report.scores)
                all_targets.append(report.targets)

            if report is None:
                raise ValueError("evaluate_full() received an empty loader")

            model_output = {"scores": torch.cat(all_scores, dim=0)}
            sample_list = SampleList(
                [Sample({"targets": torch.cat(all_targets, dim=0)})]
            )
            # Dataset identity is taken from the last batch; this assumes every
            # batch in ``loader`` belongs to the same dataset — TODO confirm.
            sample_list.add_field("dataset_type", report.dataset_type)
            sample_list.add_field("dataset_name", report.dataset_name)

            full_met = Metrics(metrics)(sample_list, model_output)

        self.writer.write(full_met)

        # Test splits have no targets worth early-stopping on.
        if report.dataset_type == "test":
            return None

        meter.update(full_met)
        stop = self.early_stopping(self.current_iteration, meter)
        should_break = False
        if stop is True:
            self.writer.write("Early stopping activated")
            should_break = True
        return should_break
    finally:
        # Previously skipped on the "test" early return; now always restored.
        self.model.train()