Example #1
    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        # Switch to eval mode and disable autograd while evaluating.
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)

                if single_batch:
                    break
        # Restore training mode before returning.
        self.model.train()

        return report, meter
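Example #1 relies on a Meter to accumulate per-batch statistics through _update_meter. As a rough illustration of that contract (a minimal sketch, not the framework's actual Meter API), a running-average meter could look like this:

    class SimpleMeter:
        """Minimal running-average meter; illustrative only."""

        def __init__(self):
            self.totals = {}
            self.counts = {}

        def update(self, values):
            # values: dict mapping metric name -> float for one batch
            for key, val in values.items():
                self.totals[key] = self.totals.get(key, 0.0) + float(val)
                self.counts[key] = self.counts.get(key, 0) + 1

        def averages(self):
            return {k: self.totals[k] / self.counts[k] for k in self.totals}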
Example #2
    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)

                if single_batch:
                    break
            self.model.train()

        return report, meter
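This version handles two inference pitfalls: torch.no_grad() skips building autograd graphs, and self.model.eval() switches layers such as dropout and batch norm to inference behavior. A self-contained sketch of why eval() matters:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.ones(1, 4)

    net.train()
    print(net(x))  # stochastic: dropout zeroes roughly half the activations

    net.eval()
    with torch.no_grad():  # also avoids tracking gradients
        print(net(x))      # deterministic inference output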
Example #3
    def evaluate(self, loader, use_tqdm=False, single_batch=False):
        meter = Meter()

        with torch.no_grad():
            self.model.eval()
            # Respect the caller's flag; in distributed runs, also consider:
            # disable_tqdm = not use_tqdm or not is_main_process()
            disable_tqdm = not use_tqdm
            for batch in tqdm(loader, disable=disable_tqdm):
                report = self._forward_pass(batch)
                self._update_meter(report, meter, eval_mode=True)

                if single_batch:
                    break
            self.model.train()

        return report, meter
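The commented-out line hints at silencing progress bars on non-main ranks in distributed runs. is_main_process() is not defined in these snippets; a common implementation (an assumption here, written directly against torch.distributed) is:

    import torch.distributed as dist

    def is_main_process():
        # True when not running distributed, or when this process is rank 0.
        if not dist.is_available() or not dist.is_initialized():
            return True
        return dist.get_rank() == 0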
Example #4
    def load_extras(self):
        self.checkpoint = Checkpoint(self)
        self.meter = Meter()

        self.training_parameters = self.config.training_parameters

        monitored_metric = self.training_parameters.monitored_metric
        metric_minimize = self.training_parameters.metric_minimize
        should_early_stop = self.training_parameters.should_early_stop
        patience = self.training_parameters.patience

        self.log_interval = self.training_parameters.log_interval
        self.snapshot_interval = self.training_parameters.snapshot_interval
        self.test_interval = self.training_parameters.test_interval
        self.max_iterations = self.training_parameters.max_iterations
        self.should_clip_gradients = self.training_parameters.clip_gradients
        self.max_epochs = self.training_parameters.max_epochs

        self.early_stopping = EarlyStopping(
            self.model,
            self.checkpoint,
            monitored_metric,
            patience=patience,
            minimize=metric_minimize,
            should_stop=should_early_stop,
        )
        self.current_epoch = 0
        self.current_iteration = 0

        self.checkpoint.load_state_dict()

        self.not_debug = self.training_parameters.logger_level != "debug"

        self.lr_scheduler = None

        # TODO: Allow custom scheduler
        if self.training_parameters.lr_scheduler is True:
            # LambdaLR multiplies the base lr by the factor returned here.
            def scheduler_func(i_iter):
                return lr_lambda_update(i_iter, self.config)

            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=scheduler_func
            )
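lr_lambda_update is referenced but not defined in these examples. LambdaLR multiplies the base learning rate by whatever factor the function returns for the current iteration, so a warmup-then-step-decay factor might look like the following sketch (the config field names warmup_iterations, warmup_factor, lr_steps, and lr_ratio are assumptions, not necessarily the project's actual keys):

    from bisect import bisect_right

    def lr_lambda_update(i_iter, config):
        # Hypothetical warmup + step-decay multiplier for LambdaLR.
        params = config.training_parameters
        if i_iter <= params.warmup_iterations:
            # Linear ramp from warmup_factor up to 1.0.
            alpha = i_iter / float(params.warmup_iterations)
            return params.warmup_factor * (1.0 - alpha) + alpha
        # Decay by lr_ratio once per lr_steps boundary already passed.
        return params.lr_ratio ** bisect_right(params.lr_steps, i_iter)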
Example #5
    def evaluate_full(self, loader, use_tqdm=False):
        meter = Meter()

        # Hardcode the metric set for now.
        metrics = ['accuracy']

        with torch.no_grad():
            self.model.eval()
            # Accumulate per-batch scores and targets across the whole split.
            tot_preds = []
            tot_targets = []
            for batch in tqdm(loader, disable=not use_tqdm):
                report = self._forward_pass(batch)
                tot_preds.append(report.scores)
                tot_targets.append(report.targets)

            tot_preds = torch.cat(tot_preds, dim=0)
            tot_targets = torch.cat(tot_targets, dim=0)

            model_output = {"scores": tot_preds}
            sample = Sample({"targets": tot_targets})
            sample_list = SampleList([sample])
            sample_list.add_field('dataset_type', report.dataset_type)
            sample_list.add_field('dataset_name', report.dataset_name)

            metric_fn = Metrics(metrics)
            full_met = metric_fn(sample_list, model_output)
            self.writer.write(full_met)

            if report.dataset_type == 'test':
                # No early stopping on the test split; restore train mode.
                self.model.train()
                return

            meter.update(full_met)
            stop = self.early_stopping(self.current_iteration, meter)

            should_break = False
            if stop:
                self.writer.write("Early stopping activated")
                should_break = True

            self.model.train()

        return should_break
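As a sanity check on what Metrics(['accuracy']) computes over the concatenated tensors, a plain-PyTorch accuracy (a sketch; the framework's metric may handle details such as soft targets differently) would be:

    import torch

    def simple_accuracy(scores, targets):
        # scores: (N, num_classes) logits;
        # targets: (N, num_classes) one-hot or (N,) class indices.
        preds = scores.argmax(dim=1)
        true = targets.argmax(dim=1) if targets.dim() > 1 else targets
        return (preds == true).float().mean().item()

    # e.g. simple_accuracy(tot_preds, tot_targets)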